Tensorium
Loading...
Searching...
No Matches
Matrix.hpp
Go to the documentation of this file.
1#pragma once
2
5#include "../SIMD/CPU_id.hpp"
6#include "../SIMD/SIMD.hpp"
8#include "Vector.hpp"
9#include <cassert>
10#include <cmath>
11#include <immintrin.h>
12#include <iostream>
13#include <vector>
14
15namespace tensorium {
27template <typename K, bool RowMajor = false> class Matrix {
28 public:
29 size_t rows, cols;
31 size_t block_size;
36 Matrix(size_t r, size_t c)
37 : rows(r),
38 cols(c),
39 data(r * c, K()),
41
42 inline size_t index(size_t i, size_t j) const {
43 if constexpr (RowMajor)
44 return i * cols + j;
45 else
46 return j * rows + i;
47 }
49 using reg = typename Simd::reg;
50 size_t simd_width = Simd::width;
52 size_t size() const { return rows * cols; }
55 K &operator()(size_t i, size_t j) { return data[index(i, j)]; }
56
57 const K &operator()(size_t i, size_t j) const { return data[index(i, j)]; }
58
60 void print() const {
61 for (size_t i = 0; i < rows; ++i) {
62 std::cout << "[ ";
63 for (size_t j = 0; j < cols; ++j)
64 std::cout << operator()(i, j) << " ";
65 std::cout << "]\n";
66 }
67 }
69 void swap_rows(size_t i, size_t j) {
70 assert(i < rows && j < rows);
71 for (size_t k = 0; k < cols; ++k) {
72 MathsUtils::_swap((*this)(i, k), (*this)(j, k));
73 }
74 }
82 template <typename T> Vector<T> operator*(const Vector<T> &v) const {
83 assert(cols == v.size() && "Matrix-Vector size mismatch");
84 Vector<T> result(rows);
85 for (auto &x : result)
86 x = T(0);
87
88 for (size_t i = 0; i < rows; ++i) {
89 for (size_t j = 0; j < cols; ++j) {
90 result[i] += (*this)(i, j) * v[j];
91 }
92 }
93 return result;
94 }
96 inline void add(const Matrix &m) {
97 if (rows != m.rows || cols != m.cols)
98 throw std::invalid_argument("Matrix sizes do not match");
99
101 using reg = typename Simd::reg;
102 const size_t simd_width = Simd::width;
103
104 size_t n = size();
105 size_t i = 0;
106
107 _mm_prefetch((const char *)&m.data[0], _MM_HINT_T0);
108
109 for (; i + 2 * simd_width - 1 < n; i += 2 * simd_width) {
110 reg a0 = Simd::load(&data[i]);
111 reg b0 = Simd::load(&m.data[i]);
112 a0 = Simd::add(a0, b0);
113 Simd::store(&data[i], a0);
114
115 reg a1 = Simd::load(&data[i + simd_width]);
116 reg b1 = Simd::load(&m.data[i + simd_width]);
117 a1 = Simd::add(a1, b1);
118 Simd::store(&data[i + simd_width], a1);
119 }
120
121 for (; i < n; ++i)
122 data[i] += m.data[i];
123 }
125 inline void sub(const Matrix &m) {
126 if (rows != m.rows || cols != m.cols)
127 throw std::invalid_argument("Matrix sizes do not match");
129 using reg = typename Simd::reg;
130 const size_t simd_width = Simd::width;
131
132 size_t n = size();
133 size_t i = 0;
134
135 _mm_prefetch((const char *)&m.data[0], _MM_HINT_T0);
136 for (; i + 15 < n; i += 16) {
137 reg a0 = Simd::load(&data[i]);
138 reg b0 = Simd::load(&m.data[i]);
139 a0 = Simd::sub(a0, b0);
140 Simd::store(&data[i], a0);
141
142 reg a1 = Simd::load(&data[i + simd_width]);
143 reg b1 = Simd::load(&m.data[i + simd_width]);
144 a1 = Simd::sub(a1, b1);
145 Simd::store(&data[i + simd_width], a1);
146 }
147 for (; i < size(); ++i) {
148 data[i] -= m.data[i];
149 }
150 }
152 inline void scl(K a) {
153 size_t n = size();
154 size_t i = 0;
156 using reg = typename Simd::reg;
157 const size_t simd_width = Simd::width;
158 _mm_prefetch((const char *)&data[0], _MM_HINT_T0);
159 reg scalar = Simd::set1(a);
160
161 for (; i + 15 < n; i += 16) {
162 reg v0 = Simd::load(&data[i]);
163 v0 = Simd::mul(v0, scalar);
164 Simd::store(&data[i], v0);
165
166 reg v1 = Simd::load(&data[i + simd_width]);
167 v1 = Simd::mul(v1, scalar);
168 Simd::store(&data[i + simd_width], v1);
169 }
170
171 for (; i < n; ++i)
172 data[i] *= a;
173 }
174
176 inline void lerp(const Matrix<K> &A, const Matrix<K> &B, K alpha) {
177 if (A.rows != B.rows || A.cols != B.cols || rows != A.rows || cols != A.cols)
178 throw std::invalid_argument("Matrix size mismatch for lerp");
179
181 using reg = typename Simd::reg;
182 const size_t simd_width = Simd::width;
183
184 size_t n = size();
185 size_t i = 0;
186
187 reg alpha_vec = Simd::set1(alpha);
188 reg one_minus_alpha_vec = Simd::set1(K(1) - alpha);
189
190 for (; i + 2 * simd_width - 1 < n; i += 2 * simd_width) {
191 reg a0 = Simd::load(&A.data[i]);
192 reg b0 = Simd::load(&B.data[i]);
193 reg r0 = Simd::fmadd(one_minus_alpha_vec, a0, Simd::mul(alpha_vec, b0));
194 Simd::store(&data[i], r0);
195
196 reg a1 = Simd::load(&A.data[i + simd_width]);
197 reg b1 = Simd::load(&B.data[i + simd_width]);
198 reg r1 = Simd::fmadd(one_minus_alpha_vec, a1, Simd::mul(alpha_vec, b1));
199 Simd::store(&data[i + simd_width], r1);
200 }
201
202 for (; i < n; ++i) {
203 data[i] = (K(1) - alpha) * A.data[i] + alpha * B.data[i];
204 }
205 }
212 inline Matrix _mul_mat(const Matrix<K> &mat) const {
213 if (cols != mat.rows)
214 throw std::invalid_argument("Matrix dimensions do not match for multiplication");
215
216 Matrix<K> result(rows, mat.cols);
217
218 const K *A = data.data(); // Already column-major (this)
219 const K *B = mat.data.data(); // Already column-major (rhs)
220 K *C = result.data.data(); // Output (also column-major)
221
223 kernel.matmul(const_cast<K *>(A), const_cast<K *>(B), C,
224 static_cast<int>(rows), // M
225 static_cast<int>(mat.cols), // N
226 static_cast<int>(cols) // K
227 );
228
229 return result;
230 }
236 template <typename T> inline Vector<T> mul_vec(const Vector<T> &x) const {
238 using reg = typename Simd::reg;
239 constexpr size_t W = Simd::width;
240
241 assert(cols == x.size());
242
243 Vector<T> result(rows, T(0));
244
245 alignas(64) T buffer[W];
246
247 for (size_t i = 0; i < rows; ++i) {
248 reg acc = Simd::zero();
249 size_t j = 0;
250
251 for (; j + W <= cols; j += W) {
252 for (size_t w = 0; w < W; ++w)
253 buffer[w] = (*this)(i, j + w);
254
255 reg A_vec = Simd::load(buffer);
256 reg x_vec = Simd::load(&x[j]);
257 acc = Simd::fmadd(A_vec, x_vec, acc);
258 }
259
260 T sum = Simd::horizontal_add(acc);
261
262 for (; j < cols; ++j)
263 sum += (*this)(i, j) * x[j];
264
265 result[i] = sum;
266 }
267
268 return result;
269 }
270
272 inline Matrix<K> transpose() const {
273 Matrix<K> result(cols, rows);
274
275 for (size_t i = 0; i < rows; ++i)
276 for (size_t j = 0; j < cols; ++j)
277 result(j, i) = (*this)(i, j);
278
279 return result;
280 }
281
283 inline Matrix<K> trace() const {
284 if (rows != cols) {
285 throw std::invalid_argument("Matrix is not square");
286 }
287
288 Matrix<K> result(1, 1);
289 result(0, 0) = K(0);
290
291 for (size_t i = 0; i < rows; ++i) {
292 result(0, 0) += operator()(i, i);
293 }
294
295 return result;
296 }
303 inline Matrix<K> inverse() const {
304 if (rows != cols)
305 throw std::invalid_argument("Matrix must be square");
306
307 const auto n = rows;
308 Matrix<K> M(n, n);
309 Matrix<K> Inv(n, n);
310
311 for (auto i = decltype(n)(0); i < n; ++i) {
312 for (auto j = decltype(n)(0); j < n; ++j) {
313 M(i, j) = operator()(i, j);
314 Inv(i, j) = (i == j) ? K(1) : K(0);
315 }
316 }
317
319 for (auto i = decltype(n)(0); i < n; ++i) {
320 auto piv = i;
321 auto maxv = MathsUtils::_abs(M(i, i));
322 for (auto r = i + 1; r < n; ++r) {
323 auto v = MathsUtils::_abs(M(r, i));
324 if (v > maxv) {
325 maxv = v;
326 piv = r;
327 }
328 }
329 if (maxv < static_cast<K>(1e-6))
330 throw std::runtime_error("Matrix is singular or nearly singular.");
331
332 if (piv != i) {
333 M.swap_rows(i, piv);
334 Inv.swap_rows(i, piv);
335 }
336
337 auto diag = M(i, i);
338 auto diag_inv = K(1) / diag;
339 for (auto j = 0u; j < n; ++j) {
340 M(i, j) *= diag_inv;
341 Inv(i, j) *= diag_inv;
342 }
343
344#pragma omp parallel for schedule(dynamic, UNROLL)
345 for (auto j = 0u; j < n; ++j) {
346 if (j != i) {
347 auto f = M(j, i);
348 for (auto k = 0u; k < n; ++k) {
349 M(j, k) -= f * M(i, k);
350 Inv(j, k) -= f * Inv(i, k);
351 }
352 }
353 }
354 }
355
356 return Inv;
357 }
365 inline K det() const {
366 if (rows != cols)
367 throw std::invalid_argument("Matrix must be square");
368
369 const size_t n = rows;
370 Matrix<K> M(n, n);
372 const size_t simd_width = SimdT::width;
373
374 for (size_t i = 0; i < n; ++i)
375 for (size_t j = 0; j < n; ++j)
376 M(i, j) = operator()(i, j);
377
378 K det_sign = K(1);
379
380 for (size_t i = 0; i < n; ++i) {
381 size_t piv = i;
382 auto maxv = MathsUtils::_abs(M(i, i));
383 for (size_t r = i + 1; r < n; ++r) {
384 auto v = MathsUtils::_abs(M(r, i));
385 if (v > maxv) {
386 maxv = v;
387 piv = r;
388 }
389 }
390 if (maxv < static_cast<K>(1e-12))
391 return K(0);
392
393 if (piv != i) {
394 M.swap_rows(i, piv);
395 det_sign = -det_sign;
396 }
397
398 for (size_t j = i + 1; j < n; ++j) {
399 auto f = M(j, i) / M(i, i);
400 M(j, i) = K(0);
401
402 auto f_vec = SimdT::set1(-f);
403 size_t k = i + 1;
404
405 for (; k + simd_width - 1 < n; k += simd_width) {
406 auto mjk = SimdT::load(&M(j, k));
407 auto mik = SimdT::load(&M(i, k));
408 mjk = SimdT::fmadd(f_vec, mik, mjk);
409 SimdT::store(&M(j, k), mjk);
410 }
411 for (; k < n; ++k) {
412 M(j, k) -= f * M(i, k);
413 }
414 }
415 }
416
417 K det = det_sign;
418 for (size_t i = 0; i < n; ++i)
419 det *= M(i, i);
420
421 return det;
422 }
428 inline size_t rank(K eps = K(1e-6)) const {
429 Matrix<K> M(*this);
430 const size_t m = rows;
431 const size_t n = cols;
432 size_t r = 0;
433
434 for (size_t col = 0; col < n; ++col) {
435 size_t pivot_row = r;
436 for (size_t i = r; i < m; ++i) {
437 if (MathsUtils::_abs(M(i, col)) > MathsUtils::_abs(M(pivot_row, col)))
438 pivot_row = i;
439 }
440
441 if (MathsUtils::_abs(M(pivot_row, col)) <= eps)
442 continue;
443
444 if (pivot_row != r)
445 M.swap_rows(pivot_row, r);
446
447 for (size_t i = r + 1; i < m; ++i) {
448 auto f = M(i, col) / M(r, col);
449 M(i, col) = 0;
450 for (size_t j = col + 1; j < n; ++j)
451 M(i, j) -= f * M(r, j);
452 }
453
454 ++r;
455 }
456
457 return r;
458 }
459};
460} // namespace tensorium
std::vector< K, AlignedAllocator< K, ALIGN > > aligned_vector
Type alias for a std::vector with aligned memory allocation.
Definition Allocator.hpp:111
size_t detect_optimal_block_size()
Definition CPU_id.hpp:18
static void _swap(T &a, T &b)
Definition MathsUtils.hpp:26
static double _abs(double a)
Definition MathsUtils.hpp:32
Definition GemmKernel_bigger.hpp:16
void matmul(T *A, T *B, T *C, int M, int N, int K)
Definition GemmKernel_bigger.hpp:828
High-performance aligned matrix class with SIMD support.
Definition Matrix.hpp:27
Matrix< K > transpose() const
Returns the transpose of the matrix (column-major layout)
Definition Matrix.hpp:272
const K & operator()(size_t i, size_t j) const
Definition Matrix.hpp:57
Matrix(size_t r, size_t c)
Construct a matrix of size r × c, initialized with zeros.
Definition Matrix.hpp:36
Matrix _mul_mat(const Matrix< K > &mat) const
Multiply matrix by another matrix using optimized SIMD path.
Definition Matrix.hpp:212
size_t rows
Definition Matrix.hpp:29
size_t rank(K eps=K(1e-6)) const
Compute the numerical rank of the matrix.
Definition Matrix.hpp:428
void sub(const Matrix &m)
In-place matrix subtraction: this -= m.
Definition Matrix.hpp:125
Vector< T > operator*(const Vector< T > &v) const
Multiply matrix by a vector (naïve fallback)
Definition Matrix.hpp:82
void scl(K a)
In-place scalar multiplication: this *= a.
Definition Matrix.hpp:152
void lerp(const Matrix< K > &A, const Matrix< K > &B, K alpha)
Linearly interpolate between two matrices: this = (1 - α) * A + α * B.
Definition Matrix.hpp:176
Vector< T > mul_vec(const Vector< T > &x) const
Multiply matrix by a vector using SIMD.
Definition Matrix.hpp:236
size_t simd_width
Definition Matrix.hpp:50
K det() const
Compute the determinant using Gaussian elimination.
Definition Matrix.hpp:365
bool iscolumn
Definition Matrix.hpp:32
Matrix< K > trace() const
Returns the trace of a square matrix as a 1×1 matrix.
Definition Matrix.hpp:283
void print() const
Print the matrix to stdout.
Definition Matrix.hpp:60
size_t block_size
Definition Matrix.hpp:31
K & operator()(size_t i, size_t j)
Element access (mutable)
Definition Matrix.hpp:55
aligned_vector< K > data
Definition Matrix.hpp:30
size_t index(size_t i, size_t j) const
Definition Matrix.hpp:42
typename Simd::reg reg
Definition Matrix.hpp:49
void add(const Matrix &m)
In-place matrix addition: this += m.
Definition Matrix.hpp:96
size_t cols
Definition Matrix.hpp:29
Matrix< K > inverse() const
Compute the inverse of the matrix using Gauss–Jordan elimination.
Definition Matrix.hpp:303
void swap_rows(size_t i, size_t j)
Swap two rows of the matrix.
Definition Matrix.hpp:69
size_t size() const
Return the total number of elements.
Definition Matrix.hpp:52
Aligned, SIMD-optimized mathematical vector class for scientific computing.
Definition Vector.hpp:26
size_t size() const
Definition Vector.hpp:76
Definition Derivate.hpp:24
Definition SIMD.hpp:177