Tensorium
Loading...
Searching...
No Matches
GemmKernel_bigger.hpp
Go to the documentation of this file.
1#pragma once
2
4#include "../Matrix.hpp"
5#include <algorithm>
6#include <cstdlib>
7#include <cstring>
8#include <immintrin.h>
/*
 * This GEMM kernel is based on Aman Salykov's version, with improvements to
 * the OMP scheduling and the block sizes.
 */
14
15namespace tensorium {
16template <typename T> class GemmKernelBigger {
17 public:
19 using reg = typename Simd::reg;
20 static constexpr int SimdWidth = Simd::width;
21 static constexpr int TileRows = SimdWidth * 2;
22 static constexpr int TileCols = 6;
23 static constexpr int NThreads = 16;
24
25 static constexpr int BlockDepth = 256;
26 static constexpr int BlockRows = 384;
27 static constexpr int BlockCols = 512;
28
    // Sliding all-ones / all-zeros byte table used by build_masks(): reading
    // bytes starting at mask[16 - mr] (AVX2 path) yields mr leading -1 bytes
    // followed by zeros, which sign-extend into per-lane masks for a partial
    // row tail of mr (< 16) elements.
    // NOTE(review): the AVX-512 branch of build_masks() indexes mask[32 - mr]
    // and mask[32 - mr + 16]; for mr <= 16 the second 16-byte read starts at
    // or past the end of this 32-entry table -- verify the AVX-512 path and/or
    // the intended table size (64 bytes for a 32-row tile?).
    static inline int8_t mask[32] __attribute__((aligned(64))) = {
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
32
33 inline void fma_loop_00(T *blockA_packed, T *blockB_packed, reg *C_accum_00, reg *C_accum_01,
34 reg *a0_packFloat8, reg *a1_packFloat8, reg *b_packFloat8, int kc) {
35
36 for (int p = 0; p < kc; p++) {
37 *a0_packFloat8 = Simd::loadu(blockA_packed);
38 *a1_packFloat8 = Simd::loadu(blockA_packed + 8);
39
40 *b_packFloat8 = Simd::broadcast(blockB_packed);
41 *C_accum_00 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_00);
42 *C_accum_01 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_01);
43
44 blockA_packed += 16;
45 blockB_packed += 6;
46 }
47 }
48
49 inline void fma_loop_01(T *blockA_packed, T *blockB_packed, reg *C_accum_00, reg *C_accum_01,
50 reg *C_accum_10, reg *C_accum_11, reg *a0_packFloat8,
51 reg *a1_packFloat8, reg *b_packFloat8, int kc) {
52
53 for (int p = 0; p < kc; p++) {
54 *a0_packFloat8 = Simd::loadu(blockA_packed);
55 *a1_packFloat8 = Simd::loadu(blockA_packed + 8);
56
57 *b_packFloat8 = Simd::broadcast(blockB_packed);
58 *C_accum_00 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_00);
59 *C_accum_01 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_01);
60
61 *b_packFloat8 = Simd::broadcast(blockB_packed + 1);
62 *C_accum_10 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_10);
63 *C_accum_11 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_11);
64
65 blockA_packed += 16;
66 blockB_packed += 6;
67 }
68 }
69
70 inline void fma_loop_02(T *blockA_packed, T *blockB_packed, reg *C_accum_00, reg *C_accum_01,
71 reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21,
72 reg *a0_packFloat8, reg *a1_packFloat8, reg *b_packFloat8, int kc) {
73
74 for (int p = 0; p < kc; p++) {
75 *a0_packFloat8 = Simd::loadu(blockA_packed);
76 *a1_packFloat8 = Simd::loadu(blockA_packed + 8);
77
78 *b_packFloat8 = Simd::broadcast(blockB_packed);
79 *C_accum_00 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_00);
80 *C_accum_01 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_01);
81
82 *b_packFloat8 = Simd::broadcast(blockB_packed + 1);
83 *C_accum_10 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_10);
84 *C_accum_11 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_11);
85
86 *b_packFloat8 = Simd::broadcast(blockB_packed + 2);
87 *C_accum_20 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_20);
88 *C_accum_21 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_21);
89
90 blockA_packed += 16;
91 blockB_packed += 6;
92 }
93 }
94
95 inline void fma_loop_03(T *blockA_packed, T *blockB_packed, reg *C_accum_00, reg *C_accum_01,
96 reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21,
97 reg *C_accum_30, reg *C_accum_31, reg *a0_packFloat8,
98 reg *a1_packFloat8, reg *b_packFloat8, int kc) {
99
100 for (int p = 0; p < kc; p++) {
101 *a0_packFloat8 = Simd::loadu(blockA_packed);
102 *a1_packFloat8 = Simd::loadu(blockA_packed + 8);
103
104 *b_packFloat8 = Simd::broadcast(blockB_packed);
105 *C_accum_00 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_00);
106 *C_accum_01 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_01);
107
108 *b_packFloat8 = Simd::broadcast(blockB_packed + 1);
109 *C_accum_10 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_10);
110 *C_accum_11 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_11);
111
112 *b_packFloat8 = Simd::broadcast(blockB_packed + 2);
113 *C_accum_20 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_20);
114 *C_accum_21 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_21);
115
116 *b_packFloat8 = Simd::broadcast(blockB_packed + 3);
117 *C_accum_30 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_30);
118 *C_accum_31 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_31);
119
120 blockA_packed += 16;
121 blockB_packed += 6;
122 }
123 }
124
125 inline void fma_loop_04(T *blockA_packed, T *blockB_packed, reg *C_accum_00, reg *C_accum_01,
126 reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21,
127 reg *C_accum_30, reg *C_accum_31, reg *C_accum_40, reg *C_accum_41,
128 reg *a0_packFloat8, reg *a1_packFloat8, reg *b_packFloat8, int kc) {
129
130 for (int p = 0; p < kc; p++) {
131 *a0_packFloat8 = Simd::loadu(blockA_packed);
132 *a1_packFloat8 = Simd::loadu(blockA_packed + 8);
133
134 *b_packFloat8 = Simd::broadcast(blockB_packed);
135 *C_accum_00 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_00);
136 *C_accum_01 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_01);
137
138 *b_packFloat8 = Simd::broadcast(blockB_packed + 1);
139 *C_accum_10 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_10);
140 *C_accum_11 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_11);
141
142 *b_packFloat8 = Simd::broadcast(blockB_packed + 2);
143 *C_accum_20 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_20);
144 *C_accum_21 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_21);
145
146 *b_packFloat8 = Simd::broadcast(blockB_packed + 3);
147 *C_accum_30 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_30);
148 *C_accum_31 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_31);
149
150 *b_packFloat8 = Simd::broadcast(blockB_packed + 4);
151 *C_accum_40 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_40);
152 *C_accum_41 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_41);
153
154 blockA_packed += 16;
155 blockB_packed += 6;
156 }
157 }
158
159 inline void fma_loop_05(T *blockA_packed, T *blockB_packed, reg *C_accum_00, reg *C_accum_01,
160 reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21,
161 reg *C_accum_30, reg *C_accum_31, reg *C_accum_40, reg *C_accum_41,
162 reg *C_accum_50, reg *C_accum_51, reg *a0_packFloat8,
163 reg *a1_packFloat8, reg *b_packFloat8, int kc) {
164
165 for (int p = 0; p < kc; p++) {
166 *a0_packFloat8 = Simd::loadu(blockA_packed);
167 *a1_packFloat8 = Simd::loadu(blockA_packed + 8);
168
169 *b_packFloat8 = Simd::broadcast(blockB_packed);
170 *C_accum_00 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_00);
171 *C_accum_01 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_01);
172
173 *b_packFloat8 = Simd::broadcast(blockB_packed + 1);
174 *C_accum_10 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_10);
175 *C_accum_11 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_11);
176
177 *b_packFloat8 = Simd::broadcast(blockB_packed + 2);
178 *C_accum_20 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_20);
179 *C_accum_21 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_21);
180
181 *b_packFloat8 = Simd::broadcast(blockB_packed + 3);
182 *C_accum_30 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_30);
183 *C_accum_31 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_31);
184
185 *b_packFloat8 = Simd::broadcast(blockB_packed + 4);
186 *C_accum_40 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_40);
187 *C_accum_41 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_41);
188
189 *b_packFloat8 = Simd::broadcast(blockB_packed + 5);
190 *C_accum_50 = Simd::fmadd(*a0_packFloat8, *b_packFloat8, *C_accum_50);
191 *C_accum_51 = Simd::fmadd(*a1_packFloat8, *b_packFloat8, *C_accum_51);
192
193 blockA_packed += 16;
194 blockB_packed += 6;
195 }
196 }
197
    // Builds two per-lane integer masks for a partial row tail of mr (< 16)
    // elements: packed_mask_0 covers rows 0..7, packed_mask_1 rows 8..15.
    // Bytes are pulled from the sliding `mask` table so that the first mr
    // lanes sign-extend to all-ones and the rest to zero; the results feed
    // Simd::maskload / Simd::maskstore.
    inline static void build_masks(__m256i *packed_mask_0, __m256i *packed_mask_1, int mr) {
#if defined(__AVX512F__)
        // NOTE(review): `mask` has 32 entries, so for mr <= 16 the load at
        // &mask[32 - mr + 16] starts at or past the end of the table (OOB
        // read), and _mm512_castsi512_si256 keeps only the low 8 of the 16
        // widened lanes. This branch looks inconsistent with the 32-entry
        // table / __m256i outputs -- confirm the intended AVX-512 layout.
        __m128i m0 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(&mask[32 - mr]));
        __m128i m1 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(&mask[32 - mr + 16]));

        __m512i p0 = _mm512_cvtepi8_epi32(m0);
        __m512i p1 = _mm512_cvtepi8_epi32(m1);

        *packed_mask_0 = _mm512_castsi512_si256(p0);
        *packed_mask_1 = _mm512_castsi512_si256(p1);

#elif defined(__AVX2__)
        // 8 bytes starting at mask[16 - mr] -> mr leading -1 bytes; widen each
        // byte to a 32-bit lane mask.
        __m128i m0 = _mm_loadl_epi64(reinterpret_cast<const __m128i *>(&mask[16 - mr]));
        __m128i m1 = _mm_loadl_epi64(reinterpret_cast<const __m128i *>(&mask[16 - mr + 8]));

        *packed_mask_0 = _mm256_cvtepi8_epi32(m0);
        *packed_mask_1 = _mm256_cvtepi8_epi32(m1);
#else
#    error "AVX2 or AVX-512 required"
#endif
    }
219
220 inline void maskload_accum_00(T *C, reg *C_accum_00, reg *C_accum_01, __m256i packed_mask_0,
221 __m256i packed_mask_1, int M) {
222 *C_accum_00 = Simd::maskload(C, packed_mask_0);
223 *C_accum_01 = Simd::maskload(&C[8], packed_mask_1);
224 }
225
226 inline void maskload_accum_01(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
227 reg *C_accum_11, __m256i packed_mask_0, __m256i packed_mask_1,
228 int M) {
229 *C_accum_00 = Simd::maskload(C, packed_mask_0);
230 *C_accum_01 = Simd::maskload(&C[8], packed_mask_1);
231 *C_accum_10 = Simd::maskload(&C[M], packed_mask_0);
232 *C_accum_11 = Simd::maskload(&C[M + 8], packed_mask_1);
233 }
234
235 inline void maskload_accum_02(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
236 reg *C_accum_11, reg *C_accum_20, reg *C_accum_21,
237 __m256i packed_mask_0, __m256i packed_mask_1, int M) {
238 *C_accum_00 = Simd::maskload(C, packed_mask_0);
239 *C_accum_01 = Simd::maskload(&C[8], packed_mask_1);
240 *C_accum_10 = Simd::maskload(&C[M], packed_mask_0);
241 *C_accum_11 = Simd::maskload(&C[M + 8], packed_mask_1);
242 *C_accum_20 = Simd::maskload(&C[2 * M], packed_mask_0);
243 *C_accum_21 = Simd::maskload(&C[2 * M + 8], packed_mask_1);
244 }
245
246 inline void maskload_accum_03(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
247 reg *C_accum_11, reg *C_accum_20, reg *C_accum_21,
248 reg *C_accum_30, reg *C_accum_31, __m256i packed_mask_0,
249 __m256i packed_mask_1, int M) {
250 *C_accum_00 = Simd::maskload(C, packed_mask_0);
251 *C_accum_01 = Simd::maskload(&C[8], packed_mask_1);
252 *C_accum_10 = Simd::maskload(&C[M], packed_mask_0);
253 *C_accum_11 = Simd::maskload(&C[M + 8], packed_mask_1);
254 *C_accum_20 = Simd::maskload(&C[2 * M], packed_mask_0);
255 *C_accum_21 = Simd::maskload(&C[2 * M + 8], packed_mask_1);
256 *C_accum_30 = Simd::maskload(&C[3 * M], packed_mask_0);
257 *C_accum_31 = Simd::maskload(&C[3 * M + 8], packed_mask_1);
258 }
259
260 inline void maskload_accum_04(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
261 reg *C_accum_11, reg *C_accum_20, reg *C_accum_21,
262 reg *C_accum_30, reg *C_accum_31, reg *C_accum_40,
263 reg *C_accum_41, __m256i packed_mask_0, __m256i packed_mask_1,
264 int M) {
265 *C_accum_00 = Simd::maskload(C, packed_mask_0);
266 *C_accum_01 = Simd::maskload(&C[8], packed_mask_1);
267 *C_accum_10 = Simd::maskload(&C[M], packed_mask_0);
268 *C_accum_11 = Simd::maskload(&C[M + 8], packed_mask_1);
269 *C_accum_20 = Simd::maskload(&C[2 * M], packed_mask_0);
270 *C_accum_21 = Simd::maskload(&C[2 * M + 8], packed_mask_1);
271 *C_accum_30 = Simd::maskload(&C[3 * M], packed_mask_0);
272 *C_accum_31 = Simd::maskload(&C[3 * M + 8], packed_mask_1);
273 *C_accum_40 = Simd::maskload(&C[4 * M], packed_mask_0);
274 *C_accum_41 = Simd::maskload(&C[4 * M + 8], packed_mask_1);
275 }
276
277 inline void maskload_accum_05(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
278 reg *C_accum_11, reg *C_accum_20, reg *C_accum_21,
279 reg *C_accum_30, reg *C_accum_31, reg *C_accum_40,
280 reg *C_accum_41, reg *C_accum_50, reg *C_accum_51,
281 __m256i packed_mask_0, __m256i packed_mask_1, int M) {
282 *C_accum_00 = Simd::maskload(C, packed_mask_0);
283 *C_accum_01 = Simd::maskload(&C[8], packed_mask_1);
284 *C_accum_10 = Simd::maskload(&C[M], packed_mask_0);
285 *C_accum_11 = Simd::maskload(&C[M + 8], packed_mask_1);
286 *C_accum_20 = Simd::maskload(&C[2 * M], packed_mask_0);
287 *C_accum_21 = Simd::maskload(&C[2 * M + 8], packed_mask_1);
288 *C_accum_30 = Simd::maskload(&C[3 * M], packed_mask_0);
289 *C_accum_31 = Simd::maskload(&C[3 * M + 8], packed_mask_1);
290 *C_accum_40 = Simd::maskload(&C[4 * M], packed_mask_0);
291 *C_accum_41 = Simd::maskload(&C[4 * M + 8], packed_mask_1);
292 *C_accum_50 = Simd::maskload(&C[5 * M], packed_mask_0);
293 *C_accum_51 = Simd::maskload(&C[5 * M + 8], packed_mask_1);
294 }
295
296 inline void load_accum_00(T *C, reg *C_accum_00, reg *C_accum_01, int M) {
297 *C_accum_00 = Simd::loadu(C);
298 *C_accum_01 = Simd::loadu(&C[8]);
299 }
300
301 inline void load_accum_01(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
302 reg *C_accum_11, int M) {
303 *C_accum_00 = Simd::loadu(C);
304 *C_accum_01 = Simd::loadu(&C[8]);
305 *C_accum_10 = Simd::loadu(&C[M]);
306 *C_accum_11 = Simd::loadu(&C[M + 8]);
307 }
308
309 inline void load_accum_02(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
310 reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, int M) {
311 *C_accum_00 = Simd::loadu(C);
312 *C_accum_01 = Simd::loadu(&C[8]);
313 *C_accum_10 = Simd::loadu(&C[M]);
314 *C_accum_11 = Simd::loadu(&C[M + 8]);
315 *C_accum_20 = Simd::loadu(&C[2 * M]);
316 *C_accum_21 = Simd::loadu(&C[2 * M + 8]);
317 }
318
319 inline void load_accum_03(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
320 reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30,
321 reg *C_accum_31, int M) {
322 *C_accum_00 = Simd::loadu(C);
323 *C_accum_01 = Simd::loadu(&C[8]);
324 *C_accum_10 = Simd::loadu(&C[M]);
325 *C_accum_11 = Simd::loadu(&C[M + 8]);
326 *C_accum_20 = Simd::loadu(&C[2 * M]);
327 *C_accum_21 = Simd::loadu(&C[2 * M + 8]);
328 *C_accum_30 = Simd::loadu(&C[3 * M]);
329 *C_accum_31 = Simd::loadu(&C[3 * M + 8]);
330 }
331
332 inline void load_accum_04(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
333 reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30,
334 reg *C_accum_31, reg *C_accum_40, reg *C_accum_41, int M) {
335 *C_accum_00 = Simd::loadu(C);
336 *C_accum_01 = Simd::loadu(&C[8]);
337 *C_accum_10 = Simd::loadu(&C[M]);
338 *C_accum_11 = Simd::loadu(&C[M + 8]);
339 *C_accum_20 = Simd::loadu(&C[2 * M]);
340 *C_accum_21 = Simd::loadu(&C[2 * M + 8]);
341 *C_accum_30 = Simd::loadu(&C[3 * M]);
342 *C_accum_31 = Simd::loadu(&C[3 * M + 8]);
343 *C_accum_40 = Simd::loadu(&C[4 * M]);
344 *C_accum_41 = Simd::loadu(&C[4 * M + 8]);
345 }
346
347 inline void load_accum_05(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
348 reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30,
349 reg *C_accum_31, reg *C_accum_40, reg *C_accum_41, reg *C_accum_50,
350 reg *C_accum_51, int M) {
351 *C_accum_00 = Simd::loadu(C);
352 *C_accum_01 = Simd::loadu(&C[8]);
353 *C_accum_10 = Simd::loadu(&C[M]);
354 *C_accum_11 = Simd::loadu(&C[M + 8]);
355 *C_accum_20 = Simd::loadu(&C[2 * M]);
356 *C_accum_21 = Simd::loadu(&C[2 * M + 8]);
357 *C_accum_30 = Simd::loadu(&C[3 * M]);
358 *C_accum_31 = Simd::loadu(&C[3 * M + 8]);
359 *C_accum_40 = Simd::loadu(&C[4 * M]);
360 *C_accum_41 = Simd::loadu(&C[4 * M + 8]);
361 *C_accum_50 = Simd::loadu(&C[5 * M]);
362 *C_accum_51 = Simd::loadu(&C[5 * M + 8]);
363 }
364
365 inline void store_accum_00(T *C, reg *C_accum_00, reg *C_accum_01, int M) {
366 Simd::storeu(C, *C_accum_00);
367 Simd::storeu(&C[8], *C_accum_01);
368 }
369
370 inline void store_accum_01(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
371 reg *C_accum_11, int M) {
372 Simd::storeu(C, *C_accum_00);
373 Simd::storeu(&C[8], *C_accum_01);
374 Simd::storeu(&C[M], *C_accum_10);
375 Simd::storeu(&C[M + 8], *C_accum_11);
376 }
377
378 inline void store_accum_02(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
379 reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, int M) {
380 Simd::storeu(C, *C_accum_00);
381 Simd::storeu(&C[8], *C_accum_01);
382 Simd::storeu(&C[M], *C_accum_10);
383 Simd::storeu(&C[M + 8], *C_accum_11);
384 Simd::storeu(&C[2 * M], *C_accum_20);
385 Simd::storeu(&C[2 * M + 8], *C_accum_21);
386 }
387
388 inline void store_accum_03(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
389 reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30,
390 reg *C_accum_31, int M) {
391 Simd::storeu(C, *C_accum_00);
392 Simd::storeu(&C[8], *C_accum_01);
393 Simd::storeu(&C[M], *C_accum_10);
394 Simd::storeu(&C[M + 8], *C_accum_11);
395 Simd::storeu(&C[2 * M], *C_accum_20);
396 Simd::storeu(&C[2 * M + 8], *C_accum_21);
397 Simd::storeu(&C[3 * M], *C_accum_30);
398 Simd::storeu(&C[3 * M + 8], *C_accum_31);
399 }
400
401 inline void store_accum_04(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
402 reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30,
403 reg *C_accum_31, reg *C_accum_40, reg *C_accum_41, int M) {
404 Simd::storeu(C, *C_accum_00);
405 Simd::storeu(&C[8], *C_accum_01);
406 Simd::storeu(&C[M], *C_accum_10);
407 Simd::storeu(&C[M + 8], *C_accum_11);
408 Simd::storeu(&C[2 * M], *C_accum_20);
409 Simd::storeu(&C[2 * M + 8], *C_accum_21);
410 Simd::storeu(&C[3 * M], *C_accum_30);
411 Simd::storeu(&C[3 * M + 8], *C_accum_31);
412 Simd::storeu(&C[4 * M], *C_accum_40);
413 Simd::storeu(&C[4 * M + 8], *C_accum_41);
414 }
415
416 inline void store_accum_05(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
417 reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30,
418 reg *C_accum_31, reg *C_accum_40, reg *C_accum_41, reg *C_accum_50,
419 reg *C_accum_51, int M) {
420 Simd::storeu(C, *C_accum_00);
421 Simd::storeu(&C[8], *C_accum_01);
422 Simd::storeu(&C[M], *C_accum_10);
423 Simd::storeu(&C[M + 8], *C_accum_11);
424 Simd::storeu(&C[2 * M], *C_accum_20);
425 Simd::storeu(&C[2 * M + 8], *C_accum_21);
426 Simd::storeu(&C[3 * M], *C_accum_30);
427 Simd::storeu(&C[3 * M + 8], *C_accum_31);
428 Simd::storeu(&C[4 * M], *C_accum_40);
429 Simd::storeu(&C[4 * M + 8], *C_accum_41);
430 Simd::storeu(&C[5 * M], *C_accum_50);
431 Simd::storeu(&C[5 * M + 8], *C_accum_51);
432 }
433
434 inline void maskstore_accum_00(T *C, reg *C_accum_00, reg *C_accum_01, __m256i packed_mask_0,
435 __m256i packed_mask_1, int M) {
436 Simd::maskstore(C, packed_mask_0, *C_accum_00);
437 Simd::maskstore(&C[8], packed_mask_1, *C_accum_01);
438 }
439
440 inline void maskstore_accum_01(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
441 reg *C_accum_11, __m256i packed_mask_0, __m256i packed_mask_1,
442 int M) {
443 Simd::maskstore(C, packed_mask_0, *C_accum_00);
444 Simd::maskstore(&C[8], packed_mask_1, *C_accum_01);
445 Simd::maskstore(&C[M], packed_mask_0, *C_accum_10);
446 Simd::maskstore(&C[M + 8], packed_mask_1, *C_accum_11);
447 }
448
449 inline void maskstore_accum_02(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
450 reg *C_accum_11, reg *C_accum_20, reg *C_accum_21,
451 __m256i packed_mask_0, __m256i packed_mask_1, int M) {
452 Simd::maskstore(C, packed_mask_0, *C_accum_00);
453 Simd::maskstore(&C[8], packed_mask_1, *C_accum_01);
454 Simd::maskstore(&C[M], packed_mask_0, *C_accum_10);
455 Simd::maskstore(&C[M + 8], packed_mask_1, *C_accum_11);
456 Simd::maskstore(&C[2 * M], packed_mask_0, *C_accum_20);
457 Simd::maskstore(&C[2 * M + 8], packed_mask_1, *C_accum_21);
458 }
459
460 inline void maskstore_accum_03(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
461 reg *C_accum_11, reg *C_accum_20, reg *C_accum_21,
462 reg *C_accum_30, reg *C_accum_31, __m256i packed_mask_0,
463 __m256i packed_mask_1, int M) {
464 Simd::maskstore(C, packed_mask_0, *C_accum_00);
465 Simd::maskstore(&C[8], packed_mask_1, *C_accum_01);
466 Simd::maskstore(&C[M], packed_mask_0, *C_accum_10);
467 Simd::maskstore(&C[M + 8], packed_mask_1, *C_accum_11);
468 Simd::maskstore(&C[2 * M], packed_mask_0, *C_accum_20);
469 Simd::maskstore(&C[2 * M + 8], packed_mask_1, *C_accum_21);
470 Simd::maskstore(&C[3 * M], packed_mask_0, *C_accum_30);
471 Simd::maskstore(&C[3 * M + 8], packed_mask_1, *C_accum_31);
472 }
473
474 inline void maskstore_accum_04(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
475 reg *C_accum_11, reg *C_accum_20, reg *C_accum_21,
476 reg *C_accum_30, reg *C_accum_31, reg *C_accum_40,
477 reg *C_accum_41, __m256i packed_mask_0, __m256i packed_mask_1,
478 int M) {
479 Simd::maskstore(C, packed_mask_0, *C_accum_00);
480 Simd::maskstore(&C[8], packed_mask_1, *C_accum_01);
481 Simd::maskstore(&C[M], packed_mask_0, *C_accum_10);
482 Simd::maskstore(&C[M + 8], packed_mask_1, *C_accum_11);
483 Simd::maskstore(&C[2 * M], packed_mask_0, *C_accum_20);
484 Simd::maskstore(&C[2 * M + 8], packed_mask_1, *C_accum_21);
485 Simd::maskstore(&C[3 * M], packed_mask_0, *C_accum_30);
486 Simd::maskstore(&C[3 * M + 8], packed_mask_1, *C_accum_31);
487 Simd::maskstore(&C[4 * M], packed_mask_0, *C_accum_40);
488 Simd::maskstore(&C[4 * M + 8], packed_mask_1, *C_accum_41);
489 }
490
491 inline void maskstore_accum_05(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10,
492 reg *C_accum_11, reg *C_accum_20, reg *C_accum_21,
493 reg *C_accum_30, reg *C_accum_31, reg *C_accum_40,
494 reg *C_accum_41, reg *C_accum_50, reg *C_accum_51,
495 __m256i packed_mask_0, __m256i packed_mask_1, int M) {
496 Simd::maskstore(C, packed_mask_0, *C_accum_00);
497 Simd::maskstore(&C[8], packed_mask_1, *C_accum_01);
498 Simd::maskstore(&C[M], packed_mask_0, *C_accum_10);
499 Simd::maskstore(&C[M + 8], packed_mask_1, *C_accum_11);
500 Simd::maskstore(&C[2 * M], packed_mask_0, *C_accum_20);
501 Simd::maskstore(&C[2 * M + 8], packed_mask_1, *C_accum_21);
502 Simd::maskstore(&C[3 * M], packed_mask_0, *C_accum_30);
503 Simd::maskstore(&C[3 * M + 8], packed_mask_1, *C_accum_31);
504 Simd::maskstore(&C[4 * M], packed_mask_0, *C_accum_40);
505 Simd::maskstore(&C[4 * M + 8], packed_mask_1, *C_accum_41);
506 Simd::maskstore(&C[5 * M], packed_mask_0, *C_accum_50);
507 Simd::maskstore(&C[5 * M + 8], packed_mask_1, *C_accum_51);
508 }
509
    // Accumulating micro-kernel: C[0:mr, 0:nr] += A_packed * B_packed over a
    // depth of kc, where mr <= 16 rows and nr in [1, 6] columns. Existing C
    // values are loaded first, updated, then written back. M appears to be the
    // leading dimension of C (stride between consecutive tile columns, in
    // elements) -- each accumulator pair (j0, j1) maps to &C[j * M] /
    // &C[j * M + 8]. Partial row tails (mr != 16) go through build_masks()
    // and the masked load/store variants. NOTE(review): nr outside [1, 6]
    // silently does nothing (no default case) -- presumably guaranteed by the
    // caller; confirm.
    void kernel_16x6_load_accum(T *blockA_packed, T *blockB_packed, T *C, int mr, int nr, int kc,
                                int M) {
        // 12 accumulators = 6 tile columns x 2 SIMD registers per 16-row strip.
        reg C_accum_00 = {};
        reg C_accum_01 = {};
        reg C_accum_10 = {};
        reg C_accum_11 = {};
        reg C_accum_20 = {};
        reg C_accum_21 = {};
        reg C_accum_30 = {};
        reg C_accum_31 = {};
        reg C_accum_40 = {};
        reg C_accum_41 = {};
        reg C_accum_50 = {};
        reg C_accum_51 = {};

        // Scratch registers threaded through the fma loops.
        reg b_packFloat8 = {};
        reg a0_packFloat8 = {};
        reg a1_packFloat8 = {};
        __m256i packed_mask_0 = {};
        __m256i packed_mask_1 = {};

        if (mr != 16) {
            // Partial row tail: masked loads/stores keep C accesses in bounds.
            build_masks(&packed_mask_0, &packed_mask_1, mr);
            switch (nr) {
            case 1:
                maskload_accum_00(C, &C_accum_00, &C_accum_01, packed_mask_0, packed_mask_1, M);
                fma_loop_00(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &a0_packFloat8,
                            &a1_packFloat8, &b_packFloat8, kc);
                maskstore_accum_00(C, &C_accum_00, &C_accum_01, packed_mask_0, packed_mask_1, M);
                break;
            case 2:
                maskload_accum_01(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11,
                                  packed_mask_0, packed_mask_1, M);
                fma_loop_01(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &a0_packFloat8, &a1_packFloat8, &b_packFloat8, kc);
                maskstore_accum_01(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11,
                                   packed_mask_0, packed_mask_1, M);
                break;
            case 3:
                maskload_accum_02(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11,
                                  &C_accum_20, &C_accum_21, packed_mask_0, packed_mask_1, M);
                fma_loop_02(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &C_accum_20, &C_accum_21, &a0_packFloat8, &a1_packFloat8,
                            &b_packFloat8, kc);
                maskstore_accum_02(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11,
                                   &C_accum_20, &C_accum_21, packed_mask_0, packed_mask_1, M);
                break;
            case 4:
                maskload_accum_03(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11,
                                  &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31, packed_mask_0,
                                  packed_mask_1, M);
                fma_loop_03(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31,
                            &a0_packFloat8, &a1_packFloat8, &b_packFloat8, kc);
                maskstore_accum_03(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11,
                                   &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31,
                                   packed_mask_0, packed_mask_1, M);
                break;
            case 5:
                maskload_accum_04(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11,
                                  &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31, &C_accum_40,
                                  &C_accum_41, packed_mask_0, packed_mask_1, M);
                fma_loop_04(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31,
                            &C_accum_40, &C_accum_41, &a0_packFloat8, &a1_packFloat8, &b_packFloat8,
                            kc);
                maskstore_accum_04(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11,
                                   &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31, &C_accum_40,
                                   &C_accum_41, packed_mask_0, packed_mask_1, M);
                break;
            case 6:
                maskload_accum_05(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11,
                                  &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31, &C_accum_40,
                                  &C_accum_41, &C_accum_50, &C_accum_51, packed_mask_0,
                                  packed_mask_1, M);
                fma_loop_05(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31,
                            &C_accum_40, &C_accum_41, &C_accum_50, &C_accum_51, &a0_packFloat8,
                            &a1_packFloat8, &b_packFloat8, kc);
                maskstore_accum_05(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11,
                                   &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31, &C_accum_40,
                                   &C_accum_41, &C_accum_50, &C_accum_51, packed_mask_0,
                                   packed_mask_1, M);
                break;
            }
        } else {
            // Full 16-row strip: plain unaligned loads/stores.
            switch (nr) {
            case 1:
                load_accum_00(C, &C_accum_00, &C_accum_01, M);
                fma_loop_00(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &a0_packFloat8,
                            &a1_packFloat8, &b_packFloat8, kc);
                store_accum_00(C, &C_accum_00, &C_accum_01, M);
                break;
            case 2:
                load_accum_01(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11, M);
                fma_loop_01(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &a0_packFloat8, &a1_packFloat8, &b_packFloat8, kc);
                store_accum_01(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11, M);
                break;
            case 3:
                load_accum_02(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11, &C_accum_20,
                              &C_accum_21, M);
                fma_loop_02(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &C_accum_20, &C_accum_21, &a0_packFloat8, &a1_packFloat8,
                            &b_packFloat8, kc);
                store_accum_02(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11, &C_accum_20,
                               &C_accum_21, M);
                break;
            case 4:
                load_accum_03(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11, &C_accum_20,
                              &C_accum_21, &C_accum_30, &C_accum_31, M);
                fma_loop_03(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31,
                            &a0_packFloat8, &a1_packFloat8, &b_packFloat8, kc);
                store_accum_03(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11, &C_accum_20,
                               &C_accum_21, &C_accum_30, &C_accum_31, M);
                break;
            case 5:
                load_accum_04(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11, &C_accum_20,
                              &C_accum_21, &C_accum_30, &C_accum_31, &C_accum_40, &C_accum_41, M);
                fma_loop_04(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31,
                            &C_accum_40, &C_accum_41, &a0_packFloat8, &a1_packFloat8, &b_packFloat8,
                            kc);
                store_accum_04(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11, &C_accum_20,
                               &C_accum_21, &C_accum_30, &C_accum_31, &C_accum_40, &C_accum_41, M);

                break;
            case 6:
                load_accum_05(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11, &C_accum_20,
                              &C_accum_21, &C_accum_30, &C_accum_31, &C_accum_40, &C_accum_41,
                              &C_accum_50, &C_accum_51, M);
                fma_loop_05(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31,
                            &C_accum_40, &C_accum_41, &C_accum_50, &C_accum_51, &a0_packFloat8,
                            &a1_packFloat8, &b_packFloat8, kc);
                store_accum_05(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11, &C_accum_20,
                               &C_accum_21, &C_accum_30, &C_accum_31, &C_accum_40, &C_accum_41,
                               &C_accum_50, &C_accum_51, M);
                break;
            }
        }
    }
653
    // Overwriting micro-kernel: C[0:mr, 0:nr] = A_packed * B_packed over a
    // depth of kc (first K-block of a panel: accumulators start at zero and C
    // is not read, only written). Same tiling, masking and M semantics as
    // kernel_16x6_load_accum; the only difference is the absence of the
    // (mask)load_accum_* calls before the fma loops.
    void kernel_16x6_zero_init_accum(T *blockA_packed, T *blockB_packed, T *C, int mr, int nr,
                                     int kc, int M) {
        // 12 accumulators = 6 tile columns x 2 SIMD registers, zero-initialized.
        reg C_accum_00 = {};
        reg C_accum_01 = {};
        reg C_accum_10 = {};
        reg C_accum_11 = {};
        reg C_accum_20 = {};
        reg C_accum_21 = {};
        reg C_accum_30 = {};
        reg C_accum_31 = {};
        reg C_accum_40 = {};
        reg C_accum_41 = {};
        reg C_accum_50 = {};
        reg C_accum_51 = {};

        // Scratch registers threaded through the fma loops.
        reg b_packFloat8 = {};
        reg a0_packFloat8 = {};
        reg a1_packFloat8 = {};
        __m256i packed_mask_0 = {};
        __m256i packed_mask_1 = {};

        if (mr != 16) {
            // Partial row tail: masks are only needed for the final stores.
            build_masks(&packed_mask_0, &packed_mask_1, mr);
            switch (nr) {
            case 1:
                fma_loop_00(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &a0_packFloat8,
                            &a1_packFloat8, &b_packFloat8, kc);
                maskstore_accum_00(C, &C_accum_00, &C_accum_01, packed_mask_0, packed_mask_1, M);
                break;
            case 2:
                fma_loop_01(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &a0_packFloat8, &a1_packFloat8, &b_packFloat8, kc);
                maskstore_accum_01(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11,
                                   packed_mask_0, packed_mask_1, M);
                break;
            case 3:
                fma_loop_02(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &C_accum_20, &C_accum_21, &a0_packFloat8, &a1_packFloat8,
                            &b_packFloat8, kc);
                maskstore_accum_02(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11,
                                   &C_accum_20, &C_accum_21, packed_mask_0, packed_mask_1, M);
                break;
            case 4:
                fma_loop_03(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31,
                            &a0_packFloat8, &a1_packFloat8, &b_packFloat8, kc);
                maskstore_accum_03(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11,
                                   &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31,
                                   packed_mask_0, packed_mask_1, M);
                break;
            case 5:
                fma_loop_04(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31,
                            &C_accum_40, &C_accum_41, &a0_packFloat8, &a1_packFloat8, &b_packFloat8,
                            kc);
                maskstore_accum_04(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11,
                                   &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31, &C_accum_40,
                                   &C_accum_41, packed_mask_0, packed_mask_1, M);
                break;
            case 6:
                fma_loop_05(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31,
                            &C_accum_40, &C_accum_41, &C_accum_50, &C_accum_51, &a0_packFloat8,
                            &a1_packFloat8, &b_packFloat8, kc);
                maskstore_accum_05(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11,
                                   &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31, &C_accum_40,
                                   &C_accum_41, &C_accum_50, &C_accum_51, packed_mask_0,
                                   packed_mask_1, M);
                break;
            }
        } else {
            // Full 16-row strip: plain unaligned stores.
            switch (nr) {
            case 1:
                fma_loop_00(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &a0_packFloat8,
                            &a1_packFloat8, &b_packFloat8, kc);
                store_accum_00(C, &C_accum_00, &C_accum_01, M);
                break;
            case 2:
                fma_loop_01(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &a0_packFloat8, &a1_packFloat8, &b_packFloat8, kc);
                store_accum_01(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11, M);
                break;
            case 3:
                fma_loop_02(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &C_accum_20, &C_accum_21, &a0_packFloat8, &a1_packFloat8,
                            &b_packFloat8, kc);
                store_accum_02(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11, &C_accum_20,
                               &C_accum_21, M);
                break;
            case 4:
                fma_loop_03(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31,
                            &a0_packFloat8, &a1_packFloat8, &b_packFloat8, kc);
                store_accum_03(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11, &C_accum_20,
                               &C_accum_21, &C_accum_30, &C_accum_31, M);
                break;
            case 5:
                fma_loop_04(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31,
                            &C_accum_40, &C_accum_41, &a0_packFloat8, &a1_packFloat8, &b_packFloat8,
                            kc);
                store_accum_04(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11, &C_accum_20,
                               &C_accum_21, &C_accum_30, &C_accum_31, &C_accum_40, &C_accum_41, M);

                break;
            case 6:
                fma_loop_05(blockA_packed, blockB_packed, &C_accum_00, &C_accum_01, &C_accum_10,
                            &C_accum_11, &C_accum_20, &C_accum_21, &C_accum_30, &C_accum_31,
                            &C_accum_40, &C_accum_41, &C_accum_50, &C_accum_51, &a0_packFloat8,
                            &a1_packFloat8, &b_packFloat8, kc);
                store_accum_05(C, &C_accum_00, &C_accum_01, &C_accum_10, &C_accum_11, &C_accum_20,
                               &C_accum_21, &C_accum_30, &C_accum_31, &C_accum_40, &C_accum_41,
                               &C_accum_50, &C_accum_51, M);
                break;
            }
        }
    }
771
#ifndef NTHREADS
# define NTHREADS 8
#endif

// Cache-blocking sizes in elements: MC rows of A, NC columns of B, KC depth.
// MC/NC are rounded so every thread gets a whole number of 16-row / 6-column
// micro-panels; note the integer division, so NTHREADS values that do not
// divide 40 (resp. 800) shrink the effective block.
// NOTE(review): the class constant NThreads is 16 while this macro defaults
// to 8 - confirm which one actually drives the build.
#define MC (16 * (40 / NTHREADS) * NTHREADS)
#define NC (6 * (800 / NTHREADS) * NTHREADS)
#define KC 500

#ifndef OMP_SCHEDULE
# define OMP_SCHEDULE auto
#endif
// Plain min macro: arguments are parenthesized but evaluated twice; callers
// only pass simple lvalues/constants. (Leading-underscore name at namespace
// scope is technically reserved - NOTE(review): consider renaming.)
#define _min(x, y) ((x) < (y) ? (x) : (y))
// Parallel-for helper. OMP_SCHEDULE / NTHREADS appear inside the _Pragma
// string literal; per the OpenMP spec the tokens following "omp" undergo
// macro replacement, so they expand as intended on conforming compilers.
#define PRAGMA_OMP_PARALLEL_FOR \
    _Pragma("omp parallel for schedule(OMP_SCHEDULE) num_threads(NTHREADS)")

// Static packing scratch buffers, 64-byte aligned for AVX: one pair per
// template instantiation T, shared by ALL GemmKernelBigger<T> instances.
// NOTE(review): concurrent matmul() calls on the same T would race on these
// buffers - confirm single-GEMM-at-a-time usage.
static T blockA_packed[MC * KC] __attribute__((aligned(64)));
static T blockB_packed[NC * KC] __attribute__((aligned(64)));
789
790 void pack_panelB(T *B, T *blockB_packed, int nr, int kc, int K) {
791 for (int p = 0; p < kc; p++) {
792 for (int j = 0; j < nr; j++) {
793 *blockB_packed++ = B[j * K + p];
794 }
795 for (int j = nr; j < 6; j++) {
796 *blockB_packed++ = 0;
797 }
798 }
799 }
800
801 void pack_blockB(T *B, T *blockB_packed, int nc, int kc, int K) {
802#pragma omp for schedule(dynamic)
803 for (int j = 0; j < nc; j += 6) {
804 int nr = _min(6, nc - j);
805 pack_panelB(&B[j * K], &blockB_packed[j * kc], nr, kc, K);
806 }
807 }
808
809 void pack_panelA(T *A, T *blockA_packed, int mr, int kc, int M) {
810 for (int p = 0; p < kc; p++) {
811 for (int i = 0; i < mr; i++) {
812 *blockA_packed++ = A[p * M + i];
813 }
814 for (int i = mr; i < 16; i++) {
815 *blockA_packed++ = 0;
816 }
817 }
818 }
819
// Packs an mc x kc block of column-major A (leading dimension M) into
// contiguous 16-row panels at stride 16*kc; the bottom ragged panel is
// zero-padded by pack_panelA.
// NOTE(review): original line 821 is absent from this listing - presumably
// PRAGMA_OMP_PARALLEL_FOR over the panel loop; confirm against the source.
void pack_blockA(T *A, T *blockA_packed, int mc, int kc, int M) {
    for (int i = 0; i < mc; i += 16) {
        int mr = _min(16, mc - i);
        pack_panelA(&A[i], &blockA_packed[i * kc], mr, kc, M);
    }
}
827
// Blocked GEMM over column-major operands (index expressions show ld(A)=M,
// ld(B)=K, ld(C)=M): C = A * B, M x N = (M x K) * (K x N).
// Structure: the first KC-deep slice of K runs the zero-init micro-kernel,
// overwriting C; every further KC slice runs the load-accumulate kernel on
// top of it, so no prior contents of C survive (beta = 0 semantics).
void matmul(T *A, T *B, T *C, int M, int N, int K) {
    for (int j = 0; j < N; j += NC) {
        int nc = _min(NC, N - j);
        int kc = _min(KC, K); // depth of the first (zero-init) slice

        // Pack the first K-slice of this NC-wide strip of B once per j.
        pack_blockB(&B[j * K], blockB_packed, nc, kc, K);

        for (int i = 0; i < M; i += MC) {
            int mc = _min(MC, M - i);

            pack_blockA(&A[i], blockA_packed, mc, kc, M);

            // NOTE(review): original lines 839-840 are missing from this
            // listing - presumably PRAGMA_OMP_PARALLEL_FOR over jr; confirm.
            for (int jr = 0; jr < nc; jr += 6) {
                int nr = _min(6, nc - jr);
                for (int ir = 0; ir < mc; ir += 16) {
                    int mr = _min(16, mc - ir);
                    // First slice: accumulators start at zero, C is overwritten.
                    kernel_16x6_zero_init_accum(&blockA_packed[ir * kc],
                                                &blockB_packed[jr * kc],
                                                &C[(j + jr) * M + (i + ir)], mr, nr, kc, M);
                }
            }
        }
        // Remaining K-slices: reload C and accumulate partial products.
        for (int p = kc; p < K; p += KC) {
            int cur_kc = _min(KC, K - p);
            pack_blockB(&B[j * K + p], blockB_packed, nc, cur_kc, K);

            for (int i = 0; i < M; i += MC) {
                int mc = _min(MC, M - i);

                pack_blockA(&A[i + p * M], blockA_packed, mc, cur_kc, M);

                // NOTE(review): original lines 859-860 are missing from this
                // listing - presumably the same PRAGMA_OMP_PARALLEL_FOR; confirm.
                for (int jr = 0; jr < nc; jr += 6) {
                    int nr = _min(6, nc - jr);
                    for (int ir = 0; ir < mc; ir += 16) {
                        int mr = _min(16, mc - ir);
                        kernel_16x6_load_accum(&blockA_packed[ir * cur_kc],
                                               &blockB_packed[jr * cur_kc],
                                               &C[(j + jr) * M + (i + ir)], mr, nr, cur_kc, M);
                    }
                }
            }
        }
    }
}
874};
875} // namespace tensorium
876
namespace tensorium {
// Out-of-line definitions for the static packing scratch buffers declared
// inside GemmKernelBigger<T>; 64-byte aligned to match AVX load/store use.
// One pair of buffers exists per instantiated element type T.
template <typename T> T GemmKernelBigger<T>::blockA_packed[MC * KC] __attribute__((aligned(64)));

template <typename T> T GemmKernelBigger<T>::blockB_packed[NC * KC] __attribute__((aligned(64)));
} // namespace tensorium
#define NC
Definition GemmKernel_bigger.hpp:777
#define _min(x, y)
Definition GemmKernel_bigger.hpp:783
#define PRAGMA_OMP_PARALLEL_FOR
Definition GemmKernel_bigger.hpp:784
#define MC
Definition GemmKernel_bigger.hpp:776
#define KC
Definition GemmKernel_bigger.hpp:778
Definition GemmKernel_bigger.hpp:16
static constexpr int SimdWidth
Definition GemmKernel_bigger.hpp:20
void pack_panelB(T *B, T *blockB_packed, int nr, int kc, int K)
Definition GemmKernel_bigger.hpp:790
static constexpr int TileRows
Definition GemmKernel_bigger.hpp:21
void fma_loop_04(T *blockA_packed, T *blockB_packed, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30, reg *C_accum_31, reg *C_accum_40, reg *C_accum_41, reg *a0_packFloat8, reg *a1_packFloat8, reg *b_packFloat8, int kc)
Definition GemmKernel_bigger.hpp:125
void maskstore_accum_00(T *C, reg *C_accum_00, reg *C_accum_01, __m256i packed_mask_0, __m256i packed_mask_1, int M)
Definition GemmKernel_bigger.hpp:434
void maskload_accum_04(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30, reg *C_accum_31, reg *C_accum_40, reg *C_accum_41, __m256i packed_mask_0, __m256i packed_mask_1, int M)
Definition GemmKernel_bigger.hpp:260
void store_accum_05(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30, reg *C_accum_31, reg *C_accum_40, reg *C_accum_41, reg *C_accum_50, reg *C_accum_51, int M)
Definition GemmKernel_bigger.hpp:416
void store_accum_03(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30, reg *C_accum_31, int M)
Definition GemmKernel_bigger.hpp:388
void maskload_accum_00(T *C, reg *C_accum_00, reg *C_accum_01, __m256i packed_mask_0, __m256i packed_mask_1, int M)
Definition GemmKernel_bigger.hpp:220
void fma_loop_01(T *blockA_packed, T *blockB_packed, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *a0_packFloat8, reg *a1_packFloat8, reg *b_packFloat8, int kc)
Definition GemmKernel_bigger.hpp:49
void matmul(T *A, T *B, T *C, int M, int N, int K)
Definition GemmKernel_bigger.hpp:828
void pack_panelA(T *A, T *blockA_packed, int mr, int kc, int M)
Definition GemmKernel_bigger.hpp:809
static constexpr int TileCols
Definition GemmKernel_bigger.hpp:22
void maskstore_accum_05(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30, reg *C_accum_31, reg *C_accum_40, reg *C_accum_41, reg *C_accum_50, reg *C_accum_51, __m256i packed_mask_0, __m256i packed_mask_1, int M)
Definition GemmKernel_bigger.hpp:491
void maskload_accum_01(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, __m256i packed_mask_0, __m256i packed_mask_1, int M)
Definition GemmKernel_bigger.hpp:226
void maskload_accum_03(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30, reg *C_accum_31, __m256i packed_mask_0, __m256i packed_mask_1, int M)
Definition GemmKernel_bigger.hpp:246
void fma_loop_02(T *blockA_packed, T *blockB_packed, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *a0_packFloat8, reg *a1_packFloat8, reg *b_packFloat8, int kc)
Definition GemmKernel_bigger.hpp:70
void fma_loop_05(T *blockA_packed, T *blockB_packed, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30, reg *C_accum_31, reg *C_accum_40, reg *C_accum_41, reg *C_accum_50, reg *C_accum_51, reg *a0_packFloat8, reg *a1_packFloat8, reg *b_packFloat8, int kc)
Definition GemmKernel_bigger.hpp:159
void maskload_accum_05(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30, reg *C_accum_31, reg *C_accum_40, reg *C_accum_41, reg *C_accum_50, reg *C_accum_51, __m256i packed_mask_0, __m256i packed_mask_1, int M)
Definition GemmKernel_bigger.hpp:277
void load_accum_04(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30, reg *C_accum_31, reg *C_accum_40, reg *C_accum_41, int M)
Definition GemmKernel_bigger.hpp:332
void store_accum_04(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30, reg *C_accum_31, reg *C_accum_40, reg *C_accum_41, int M)
Definition GemmKernel_bigger.hpp:401
void fma_loop_00(T *blockA_packed, T *blockB_packed, reg *C_accum_00, reg *C_accum_01, reg *a0_packFloat8, reg *a1_packFloat8, reg *b_packFloat8, int kc)
Definition GemmKernel_bigger.hpp:33
void kernel_16x6_zero_init_accum(T *blockA_packed, T *blockB_packed, T *C, int mr, int nr, int kc, int M)
Definition GemmKernel_bigger.hpp:654
void load_accum_01(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, int M)
Definition GemmKernel_bigger.hpp:301
void maskstore_accum_02(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, __m256i packed_mask_0, __m256i packed_mask_1, int M)
Definition GemmKernel_bigger.hpp:449
static void build_masks(__m256i *packed_mask_0, __m256i *packed_mask_1, int mr)
Definition GemmKernel_bigger.hpp:198
void maskstore_accum_03(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30, reg *C_accum_31, __m256i packed_mask_0, __m256i packed_mask_1, int M)
Definition GemmKernel_bigger.hpp:460
void load_accum_00(T *C, reg *C_accum_00, reg *C_accum_01, int M)
Definition GemmKernel_bigger.hpp:296
void pack_blockA(T *A, T *blockA_packed, int mc, int kc, int M)
Definition GemmKernel_bigger.hpp:820
void store_accum_02(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, int M)
Definition GemmKernel_bigger.hpp:378
void maskstore_accum_04(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30, reg *C_accum_31, reg *C_accum_40, reg *C_accum_41, __m256i packed_mask_0, __m256i packed_mask_1, int M)
Definition GemmKernel_bigger.hpp:474
void maskstore_accum_01(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, __m256i packed_mask_0, __m256i packed_mask_1, int M)
Definition GemmKernel_bigger.hpp:440
void maskload_accum_02(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, __m256i packed_mask_0, __m256i packed_mask_1, int M)
Definition GemmKernel_bigger.hpp:235
static constexpr int BlockRows
Definition GemmKernel_bigger.hpp:26
void store_accum_00(T *C, reg *C_accum_00, reg *C_accum_01, int M)
Definition GemmKernel_bigger.hpp:365
void load_accum_03(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30, reg *C_accum_31, int M)
Definition GemmKernel_bigger.hpp:319
static constexpr int BlockDepth
Definition GemmKernel_bigger.hpp:25
static T blockA_packed[MC *KC] __attribute__((aligned(64)))
typename Simd::reg reg
Definition GemmKernel_bigger.hpp:19
void kernel_16x6_load_accum(T *blockA_packed, T *blockB_packed, T *C, int mr, int nr, int kc, int M)
Definition GemmKernel_bigger.hpp:510
void load_accum_02(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, int M)
Definition GemmKernel_bigger.hpp:309
static constexpr int BlockCols
Definition GemmKernel_bigger.hpp:27
static constexpr int NThreads
Definition GemmKernel_bigger.hpp:23
void load_accum_05(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30, reg *C_accum_31, reg *C_accum_40, reg *C_accum_41, reg *C_accum_50, reg *C_accum_51, int M)
Definition GemmKernel_bigger.hpp:347
void fma_loop_03(T *blockA_packed, T *blockB_packed, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, reg *C_accum_20, reg *C_accum_21, reg *C_accum_30, reg *C_accum_31, reg *a0_packFloat8, reg *a1_packFloat8, reg *b_packFloat8, int kc)
Definition GemmKernel_bigger.hpp:95
static int8_t mask[32] __attribute__((aligned(64)))
void store_accum_01(T *C, reg *C_accum_00, reg *C_accum_01, reg *C_accum_10, reg *C_accum_11, int M)
Definition GemmKernel_bigger.hpp:370
void pack_blockB(T *B, T *blockB_packed, int nc, int kc, int K)
Definition GemmKernel_bigger.hpp:801
static T blockB_packed[NC *KC] __attribute__((aligned(64)))
Definition Derivate.hpp:24
T GemmKernelBigger< T >::blockA_packed[MC *KC] __attribute__((aligned(64)))
Definition SIMD.hpp:177