Tensorium
Loading...
Searching...
No Matches
MatrixKernel.hpp
Go to the documentation of this file.
1#pragma once
2
3#include "../Matrix.hpp"
4
5namespace tensorium {
6
17template <typename K> class MatrixKernel : public Matrix<K, true> {
18 public:
19 using Matrix<K, true>::rows;
20 using Matrix<K, true>::cols;
21 using Matrix<K, true>::data;
22 using Matrix<K, true>::operator();
23
25 using reg = typename Simd::reg;
30 MatrixKernel(const Matrix<K, true> &m) : Matrix<K, true>(m) {}
35 MatrixKernel(const Matrix<K, false> &m) : Matrix<K, true>(m.rows, m.cols) {
36 for (size_t i = 0; i < m.rows; ++i)
37 for (size_t j = 0; j < m.cols; ++j)
38 (*this)(i, j) = m(i, j);
39 }
40
44 MatrixKernel(size_t r, size_t c) : Matrix<K, true>(r, c) {}
50 inline Matrix<K> mul_mat2x2(const MatrixKernel<K> &B) const {
52 using reg = typename Simd::reg;
53
54 Matrix<K> C(2, 2);
55
56 reg b_col0 = Simd::loadu(&B.data[0]);
57 reg b_col1 = Simd::loadu(&B.data[2]);
58
59 K b00 = Simd::extract(b_col0, 0);
60 K b10 = Simd::extract(b_col0, 1);
61 K b01 = Simd::extract(b_col1, 0);
62 K b11 = Simd::extract(b_col1, 1);
63
64 C(0, 0) = (*this)(0, 0) * b00 + (*this)(0, 1) * b10;
65 C(1, 0) = (*this)(1, 0) * b00 + (*this)(1, 1) * b10;
66 C(0, 1) = (*this)(0, 0) * b01 + (*this)(0, 1) * b11;
67 C(1, 1) = (*this)(1, 0) * b01 + (*this)(1, 1) * b11;
68
69 return C;
70 }
76 inline Matrix<K> mul_mat3x3(const MatrixKernel<K> &B) const {
78 using reg = typename Simd::reg;
79
80 Matrix<K> result(3, 3);
81
82 for (int i = 0; i < 3; ++i) {
83 reg ai0 = Simd::set1((*this)(i, 0));
84 reg ai1 = Simd::set1((*this)(i, 1));
85 reg ai2 = Simd::set1((*this)(i, 2));
86
87 alignas(32) K brow0[4] = {B(0, 0), B(0, 1), B(0, 2), K(0)};
88 alignas(32) K brow1[4] = {B(1, 0), B(1, 1), B(1, 2), K(0)};
89 alignas(32) K brow2[4] = {B(2, 0), B(2, 1), B(2, 2), K(0)};
90
91 reg b0 = Simd::loadu(brow0);
92 reg b1 = Simd::loadu(brow1);
93 reg b2 = Simd::loadu(brow2);
94
95 reg acc = Simd::mul(ai0, b0);
96 acc = Simd::fmadd(ai1, b1, acc);
97 acc = Simd::fmadd(ai2, b2, acc);
98
99 result(i, 0) = Simd::extract(acc, 0);
100 result(i, 1) = Simd::extract(acc, 1);
101 result(i, 2) = Simd::extract(acc, 2);
102 }
103
104 return result;
105 }
106
112 inline Matrix<K> mul_mat4x4(const MatrixKernel<K> &B) const {
114 using reg = typename Simd::reg;
115
116 Matrix<K> result(4, 4);
117
118 reg brow0 = Simd::loadu(&B.data[0 * 4 + 0]);
119 reg brow1 = Simd::loadu(&B.data[1 * 4 + 0]);
120 reg brow2 = Simd::loadu(&B.data[2 * 4 + 0]);
121 reg brow3 = Simd::loadu(&B.data[3 * 4 + 0]);
122
123 {
124 const K *a = &data[0 * 4];
125 reg a0 = Simd::set1(a[0]);
126 reg a1 = Simd::set1(a[1]);
127 reg a2 = Simd::set1(a[2]);
128 reg a3 = Simd::set1(a[3]);
129
130 reg acc0 = Simd::mul(a0, brow0);
131 acc0 = Simd::fmadd(a1, brow1, acc0);
132 acc0 = Simd::fmadd(a2, brow2, acc0);
133 acc0 = Simd::fmadd(a3, brow3, acc0);
134
135 Simd::storeu(&result.data[0 * 4 + 0], acc0);
136 }
137
138 {
139 const K *a = &data[1 * 4];
140 reg a0 = Simd::set1(a[0]);
141 reg a1 = Simd::set1(a[1]);
142 reg a2 = Simd::set1(a[2]);
143 reg a3 = Simd::set1(a[3]);
144
145 reg acc1 = Simd::mul(a0, brow0);
146 acc1 = Simd::fmadd(a1, brow1, acc1);
147 acc1 = Simd::fmadd(a2, brow2, acc1);
148 acc1 = Simd::fmadd(a3, brow3, acc1);
149
150 Simd::storeu(&result.data[1 * 4 + 0], acc1);
151 }
152
153 {
154 const K *a = &data[2 * 4];
155 reg a0 = Simd::set1(a[0]);
156 reg a1 = Simd::set1(a[1]);
157 reg a2 = Simd::set1(a[2]);
158 reg a3 = Simd::set1(a[3]);
159
160 reg acc2 = Simd::mul(a0, brow0);
161 acc2 = Simd::fmadd(a1, brow1, acc2);
162 acc2 = Simd::fmadd(a2, brow2, acc2);
163 acc2 = Simd::fmadd(a3, brow3, acc2);
164
165 Simd::storeu(&result.data[2 * 4 + 0], acc2);
166 }
167
168 {
169 const K *a = &data[3 * 4];
170 reg a0 = Simd::set1(a[0]);
171 reg a1 = Simd::set1(a[1]);
172 reg a2 = Simd::set1(a[2]);
173 reg a3 = Simd::set1(a[3]);
174
175 reg acc3 = Simd::mul(a0, brow0);
176 acc3 = Simd::fmadd(a1, brow1, acc3);
177 acc3 = Simd::fmadd(a2, brow2, acc3);
178 acc3 = Simd::fmadd(a3, brow3, acc3);
179
180 Simd::storeu(&result.data[3 * 4 + 0], acc3);
181 }
182
183 return result;
184 }
185
191 inline Matrix<K> mul_mat8x8(const MatrixKernel<K> &B) const {
193 using reg = typename Simd::reg;
194
195 Matrix<K> result(8, 8);
196
197 reg col[8];
198 for (int j = 0; j < 8; ++j)
199 col[j] = Simd::loadu(&B.data[j * 8]);
200
201 for (int i = 0; i < 8; ++i) {
202 const K *a = &data[i * 8];
203 reg a0 = Simd::set1(a[0]);
204 reg a1 = Simd::set1(a[1]);
205 reg a2 = Simd::set1(a[2]);
206 reg a3 = Simd::set1(a[3]);
207 reg a4 = Simd::set1(a[4]);
208 reg a5 = Simd::set1(a[5]);
209 reg a6 = Simd::set1(a[6]);
210 reg a7 = Simd::set1(a[7]);
211
212 reg acc = Simd::mul(a0, col[0]);
213 acc = Simd::fmadd(a1, col[1], acc);
214 acc = Simd::fmadd(a2, col[2], acc);
215 acc = Simd::fmadd(a3, col[3], acc);
216 acc = Simd::fmadd(a4, col[4], acc);
217 acc = Simd::fmadd(a5, col[5], acc);
218 acc = Simd::fmadd(a6, col[6], acc);
219 acc = Simd::fmadd(a7, col[7], acc);
220
221 Simd::storeu(&result.data[i * 8], acc);
222 }
223 return result;
224 }
225
232 inline Matrix<K> mul_mat16x16(const MatrixKernel<K> &B) const {
234 using reg = typename Simd::reg;
235
236 Matrix<K> result(16, 16);
237
238 reg row_lo[16], row_hi[16];
239 for (int k = 0; k < 16; ++k) {
240 row_lo[k] = Simd::loadu(&B.data[k * 16 + 0]);
241 row_hi[k] = Simd::loadu(&B.data[k * 16 + 8]);
242 }
243
244 for (int i = 0; i < 16; ++i) {
245 const K *a = &data[i * 16];
246
247 reg acc_lo = Simd::mul(Simd::set1(a[0]), row_lo[0]);
248 reg acc_hi = Simd::mul(Simd::set1(a[0]), row_hi[0]);
249
250 for (int k = 1; k < 16; ++k) {
251 reg ak = Simd::set1(a[k]);
252 acc_lo = Simd::fmadd(ak, row_lo[k], acc_lo);
253 acc_hi = Simd::fmadd(ak, row_hi[k], acc_hi);
254 }
255
256 Simd::storeu(&result.data[i * 16 + 0], acc_lo);
257 Simd::storeu(&result.data[i * 16 + 8], acc_hi);
258 }
259
260 return result;
261 }
268 inline Matrix<K> mul_mat32x32(const MatrixKernel<K> &B) const {
270 using reg = typename Simd::reg;
271
272 Matrix<K> result(32, 32);
273
274 reg brow[32][2];
275#pragma unroll(2)
276 for (int k = 0; k < 32; ++k) {
277 brow[k][0] = Simd::loadu(&B.data[k * 32 + 0]);
278 brow[k][1] = Simd::loadu(&B.data[k * 32 + 16]);
279 }
280#pragma unroll(2)
281 for (int i = 0; i < 32; ++i) {
282 const K *a = &data[i * 32];
283 reg acc0 = Simd::mul(Simd::set1(a[0]), brow[0][0]);
284 reg acc1 = Simd::mul(Simd::set1(a[0]), brow[0][1]);
285
286 for (int k = 1; k < 32; ++k) {
287 reg ak = Simd::set1(a[k]);
288 acc0 = Simd::fmadd(ak, brow[k][0], acc0);
289 acc1 = Simd::fmadd(ak, brow[k][1], acc1);
290 }
291
292 Simd::storeu(&result.data[i * 32 + 0], acc0);
293 Simd::storeu(&result.data[i * 32 + 16], acc1);
294 }
295
296 return result;
297 }
298
306 inline Matrix<K> mul_mat64x64(const MatrixKernel<K> &B) const {
308 using reg = typename Simd::reg;
309
310 Matrix<K> result(64, 64);
311
312 reg brow_lo[64], brow_hi[64], brow_32[64], brow_48[64];
313
314 for (int k = 0; k < 64; ++k) {
315 const K *b = &B.data[k * 64];
316 brow_lo[k] = Simd::loadu(b + 0);
317 brow_hi[k] = Simd::loadu(b + 8);
318 brow_32[k] = Simd::loadu(b + 16);
319 brow_48[k] = Simd::loadu(b + 24);
320 }
321
322 for (int i = 0; i < 64; ++i) {
323 const K *a = &data[i * 64];
324
325 reg acc0 = Simd::mul(Simd::set1(a[0]), brow_lo[0]);
326 reg acc1 = Simd::mul(Simd::set1(a[0]), brow_hi[0]);
327 reg acc2 = Simd::mul(Simd::set1(a[0]), brow_32[0]);
328 reg acc3 = Simd::mul(Simd::set1(a[0]), brow_48[0]);
329
330 for (int k = 1; k < 64; ++k) {
331 reg ak = Simd::set1(a[k]);
332 acc0 = Simd::fmadd(ak, brow_lo[k], acc0);
333 acc1 = Simd::fmadd(ak, brow_hi[k], acc1);
334 acc2 = Simd::fmadd(ak, brow_32[k], acc2);
335 acc3 = Simd::fmadd(ak, brow_48[k], acc3);
336 }
337
338 K *r = &result.data[i * 64];
339 Simd::storeu(r + 0, acc0);
340 Simd::storeu(r + 8, acc1);
341 Simd::storeu(r + 16, acc2);
342 Simd::storeu(r + 24, acc3);
343 }
344
345 return result;
346 }
347};
348} // namespace tensorium
MatrixKernel provides specialized SIMD-accelerated matrix multiplication routines for statically-sized square matrices (2×2 up to 64×64).
Definition MatrixKernel.hpp:17
Matrix< K > mul_mat8x8(const MatrixKernel< K > &B) const
Multiply two 8×8 matrices using SIMD.
Definition MatrixKernel.hpp:191
Matrix< K > mul_mat16x16(const MatrixKernel< K > &B) const
Multiply two 16×16 matrices using SIMD with FMADD accumulation. This function splits each row into two SIMD registers.
Definition MatrixKernel.hpp:232
MatrixKernel(const Matrix< K, false > &m)
Construct a MatrixKernel from a row-major matrix by copying elements.
Definition MatrixKernel.hpp:35
Matrix< K > mul_mat4x4(const MatrixKernel< K > &B) const
Multiply two 4×4 matrices using SIMD.
Definition MatrixKernel.hpp:112
Matrix< K > mul_mat2x2(const MatrixKernel< K > &B) const
Multiply two 2×2 matrices using SIMD.
Definition MatrixKernel.hpp:50
Matrix< K > mul_mat3x3(const MatrixKernel< K > &B) const
Multiply two 3×3 matrices using SIMD.
Definition MatrixKernel.hpp:76
MatrixKernel(size_t r, size_t c)
Construct an empty column-major matrix kernel of size (r × c).
Definition MatrixKernel.hpp:44
Matrix< K > mul_mat64x64(const MatrixKernel< K > &B) const
Multiply two 64×64 matrices using SIMD. Each row is split into 4 SIMD registers (4×16 elements)....
Definition MatrixKernel.hpp:306
MatrixKernel(const Matrix< K, true > &m)
Construct a MatrixKernel from a column-major matrix.
Definition MatrixKernel.hpp:30
typename Simd::reg reg
Definition MatrixKernel.hpp:25
Matrix< K > mul_mat32x32(const MatrixKernel< K > &B) const
Multiply two 32×32 matrices using SIMD. Each row is split into two registers (16 elements each).
Definition MatrixKernel.hpp:268
High-performance aligned matrix class with SIMD support.
Definition Matrix.hpp:27
size_t rows
Definition Matrix.hpp:29
aligned_vector< K > data
Definition Matrix.hpp:30
size_t cols
Definition Matrix.hpp:29
Definition Derivate.hpp:24
Definition SIMD.hpp:177