25template <
typename K, std::
size_t Rank>
class Tensor {
41 Tensor(
const std::array<size_t, Rank> &dims)
46 for (int64_t
i = Rank - 2;
i >= 0; --
i) {
51 for (
size_t i = 0;
i < Rank; ++
i)
58 std::array<size_t, 4> idx = {
i, j, k, l};
65 for (int64_t
i = Rank - 2;
i >= 0; --
i)
70 void resize(
const std::array<size_t, Rank> &dims) {
75 for (
size_t i = 0;
i < Rank; ++
i)
83 void resize(
size_t d0,
size_t d1) {
resize(std::array<size_t, 2>{d0, d1}); }
87 void resize(
size_t d0,
size_t d1,
size_t d2) {
88 static_assert(Rank == 3,
"Rank mismatch in resize()");
91 data.resize(d0 * d1 * d2, K(0));
99 const K &
operator()(
const std::array<size_t, Rank> &indices)
const {
104 const K &
operator()(
size_t i,
size_t j,
size_t k,
size_t l)
const {
105 std::array<size_t, 4> idx = {
i, j, k, l};
110 std::array<size_t, 4> idx = {
i, j, k, l};
115 static_assert(Rank == 3,
"Rank mismatch in operator()");
119 static_assert(Rank == 3,
"Rank mismatch in operator()");
130 std::cout <<
"Tensor shape: (";
131 for (
size_t i = 0;
i < Rank; ++
i) {
143 std::cout << std::setw(10) << std::setprecision(4) << std::fixed << (*this)({
i, j})
161 flatten_index_simd(
const size_t *indices,
const size_t *
strides)
const {
163 using reg =
typename Simd::reg;
164 constexpr size_t W = Simd::width /
sizeof(size_t);
169 for (;
i +
W - 1 < Rank;
i +=
W) {
170 reg idx = Simd::load(&indices[
i]);
172 reg prod = Simd::mul(idx, str);
173 acc += detail::reduce_sum(prod);
175 for (;
i < Rank; ++
i)
189 flatten_index(
const std::array<size_t, Rank> &indices)
const {
190 return flatten_index_simd(indices.data(),
strides.data());
207 static_assert(Rank >= 2,
"Cannot contract tensor of rank < 2");
208 assert(
i < Rank &&
j < Rank &&
i !=
j);
209 assert(t.dimensions[
i] == t.dimensions[
j]);
213 using reg =
typename SimdValue::reg;
214 constexpr size_t W = SimdValue::width;
218 for (
size_t d = 0;
d < Rank; ++
d) {
219 if (
d !=
i &&
d !=
j)
224 std::array<size_t, Rank>
indices{};
227 for (
size_t flat = 0;
flat < result.data.size(); ++
flat) {
230 for (
size_t d = 0;
d < Rank; ++
d) {
231 if (
d ==
i ||
d ==
j) {
239 const size_t dim = t.dimensions[
i];
241 reg acc = SimdValue::zero();
243 for (;
k +
W - 1 < dim;
k +=
W) {
244 alignas(64)
size_t k_vec[
W];
245 for (
size_t w = 0;
w <
W; ++
w) {
252 for (
size_t w = 0;
w <
W; ++
w)
259 K sum = detail::reduce_sum(
acc);
260 for (;
k < dim; ++
k) {
267 for (
size_t d = 0;
d < Rank; ++
d) {
268 if (
d !=
i &&
d !=
j)
294 using reg =
typename Simd::reg;
295 constexpr size_t W = Simd::width /
sizeof(
K);
297 for (
size_t i = 0;
i < rows; ++
i) {
299 for (;
j +
W - 1 < cols;
j +=
W) {
300 reg vec = Simd::load(&(*
this)({
i,
j}));
304 for (
size_t k = 0;
k <
W; ++
k)
307 for (;
j < cols; ++
j)
308 result({
j,
i}) = (*
this)({
i,
j});
329 template <
size_t R1,
size_t R2>
333 using reg =
typename Simd::reg;
334 constexpr size_t W = Simd::width /
sizeof(
K);
335 constexpr size_t R =
R1 +
R2;
339 std::array<size_t, R>
shape;
340 for (
size_t i = 0;
i <
R1; ++
i)
342 for (
size_t i = 0;
i <
R2; ++
i)
348 std::fill(result.
data.begin(), result.
data.end(),
K(0));
349#pragma omp parallel for collapse(2)
366 std::array<size_t, R1>
idx_A;
380 for (
size_t w = 0;
w <
W; ++
w) {
381 std::array<size_t, R2>
idx_B;
382 std::array<size_t, R>
idx_C;
389#pragma unroll(R1 + R2 - 1)
390 for (
size_t i = 0;
i <
R1; ++
i)
392#pragma unroll(R1 + R2 - 1)
393 for (
size_t i = 0;
i <
R2; ++
i)
401 std::array<size_t, R2>
idx_B;
402 std::array<size_t, R>
idx_C;
409 for (
size_t i = 0;
i <
R1; ++
i)
411 for (
size_t i = 0;
i <
R2; ++
i)
std::vector< K, AlignedAllocator< K, ALIGN > > aligned_vector
Type alias for a std::vector with aligned memory allocation.
Definition Allocator.hpp:111
Multi-dimensional tensor class with fixed rank and SIMD support.
Definition Tensor.hpp:25
size_t block_size
Definition Tensor.hpp:32
void resize(size_t d0, size_t d1, size_t d2)
Definition Tensor.hpp:87
void resize(const std::array< size_t, Rank > &dims)
Resize 2D tensor.
Definition Tensor.hpp:70
K & operator()(const std::array< size_t, Rank > &indices)
Definition Tensor.hpp:94
void update_strides()
Definition Tensor.hpp:63
size_t total_size
Definition Tensor.hpp:30
K value_type
Definition Tensor.hpp:27
void print() const
Print a 2D tensor (Rank == 2)
Definition Tensor.hpp:140
constexpr size_t W
Definition Tensor.hpp:164
__attribute__((always_inline, hot, flatten)) inline size_t flatten_index(const std __attribute__((always_inline, hot, flatten)) Tensor< K
Convert multi-index to flat index.
__attribute__((always_inline, hot, flatten)) inline size_t flatten_index_simd(const size_t *indices
Convert a multi-index into a flattened linear index using SIMD.
K & operator()(size_t i, size_t j, size_t k, size_t l)
Definition Tensor.hpp:109
__attribute__((always_inline, hot, flatten)) inline size_t flatten_index(const std transpose_simd() const
Definition Tensor.hpp:287
size_t acc
Definition Tensor.hpp:166
Tensor()
Default constructor.
Definition Tensor.hpp:35
K & operator()(size_t i, size_t j, size_t k)
Definition Tensor.hpp:114
typename Simd::reg reg
Definition Tensor.hpp:163
size_t flatten_index(size_t i, size_t j, size_t k, size_t l) const
Definition Tensor.hpp:57
std::array< size_t, Rank > strides
Definition Tensor.hpp:33
std::array< size_t, Rank > shape() const
Definition Tensor.hpp:62
aligned_vector< K > data
Definition Tensor.hpp:31
const K & operator()(size_t i, size_t j) const
Definition Tensor.hpp:113
Tensor(const std::array< size_t, Rank > &dims)
Construct tensor with given dimensions.
Definition Tensor.hpp:41
std::array< size_t, Rank > dimensions
Dimensions of the tensor (e.g., {4,4,4,4})
Definition Tensor.hpp:29
void resize(size_t d0, size_t d1)
Resize 2D tensor.
Definition Tensor.hpp:83
const K & operator()(size_t i, size_t j, size_t k) const
Definition Tensor.hpp:118
const K & operator()(const std::array< size_t, Rank > &indices) const
Definition Tensor.hpp:99
size_t i
Definition Tensor.hpp:167
void fill(K value)
Fill tensor with a constant value.
Definition Tensor.hpp:125
void print_shape() const
Print the shape (dimensions) of the tensor.
Definition Tensor.hpp:129
return acc
Definition Tensor.hpp:178
static Tensor< K, R1+R2 > tensor_product(const Tensor< K, R1 > &A, const Tensor< K, R2 > &B)
Compute the tensor (outer) product of two tensors.
Definition Tensor.hpp:330
K & operator()(size_t i, size_t j)
Definition Tensor.hpp:108
const K & operator()(size_t i, size_t j, size_t k, size_t l) const
Definition Tensor.hpp:104
Definition Derivate.hpp:24
Tensor< K, Rank - 2 > contract_tensor(const Tensor< K, Rank > &T)
Definition Functional.hpp:327