Tensorium
Loading...
Searching...
No Matches
Spectral.hpp
Go to the documentation of this file.
1#pragma once
2
5#include "../SIMD/CPU_id.hpp"
6#include "../SIMD/SIMD.hpp"
7#include "Matrix.hpp"
8#include "Tensor.hpp"
9#include "Vector.hpp"
10#include <array>
11#include <cassert>
12#include <cmath>
13#include <iomanip>
14#include <iostream>
15#include <numbers>
16#include <stdexcept>
17#include <vector>
18
19namespace tensorium {
27template <typename T> class SpectralFFT {
28 public:
31 using C = std::complex<T>;
38 static inline void forward(CVectorT &a) { transform_impl(a, false); }
39
40 static void forward_3D(Tensor<std::complex<T>, 3> &a) {
41 const auto shape = a.shape();
42 const size_t NX = shape[0], NY = shape[1], NZ = shape[2];
43 using CVector = Vector<std::complex<T>>;
44
45#pragma omp parallel for collapse(2)
46 for (size_t i = 0; i < NX; ++i) {
47 for (size_t j = 0; j < NY; ++j) {
48 CVector sliceZ(NZ);
49 for (size_t k = 0; k < NZ; ++k)
50 sliceZ(k) = a(i, j, k);
51
52 forward(sliceZ);
53
54 for (size_t k = 0; k < NZ; ++k)
55 a(i, j, k) = sliceZ(k);
56 }
57 }
58
59#pragma omp parallel for collapse(2)
60 for (size_t i = 0; i < NX; ++i) {
61 for (size_t k = 0; k < NZ; ++k) {
62 CVector sliceY(NY);
63 for (size_t j = 0; j < NY; ++j)
64 sliceY(j) = a(i, j, k);
65
66 forward(sliceY);
67
68 for (size_t j = 0; j < NY; ++j)
69 a(i, j, k) = sliceY(j);
70 }
71 }
72
73#pragma omp parallel for collapse(2)
74 for (size_t j = 0; j < NY; ++j) {
75 for (size_t k = 0; k < NZ; ++k) {
76 CVector sliceX(NX);
77 for (size_t i = 0; i < NX; ++i)
78 sliceX(i) = a(i, j, k);
79
80 forward(sliceX);
81
82 for (size_t i = 0; i < NX; ++i)
83 a(i, j, k) = sliceX(i);
84 }
85 }
86 }
87
93 static inline void backward(CVectorT &a) { transform_impl(a, true); }
94
95 static void backward_3D(Tensor<std::complex<T>, 3> &a) {
96 const auto shape = a.shape();
97 const size_t NX = shape[0], NY = shape[1], NZ = shape[2];
98 using CVector = Vector<std::complex<T>>;
99
100#pragma omp parallel for collapse(2)
101 for (size_t j = 0; j < NY; ++j) {
102 for (size_t k = 0; k < NZ; ++k) {
103 CVector sliceX(NX);
104 for (size_t i = 0; i < NX; ++i)
105 sliceX(i) = a(i, j, k);
106
107 backward(sliceX);
108
109 for (size_t i = 0; i < NX; ++i)
110 a(i, j, k) = sliceX(i);
111 }
112 }
113
114#pragma omp parallel for collapse(2)
115 for (size_t i = 0; i < NX; ++i) {
116 for (size_t k = 0; k < NZ; ++k) {
117 CVector sliceY(NY);
118 for (size_t j = 0; j < NY; ++j)
119 sliceY(j) = a(i, j, k);
120
121 backward(sliceY);
122
123 for (size_t j = 0; j < NY; ++j)
124 a(i, j, k) = sliceY(j);
125 }
126 }
127
128#pragma omp parallel for collapse(2)
129 for (size_t i = 0; i < NX; ++i) {
130 for (size_t j = 0; j < NY; ++j) {
131 CVector sliceZ(NZ);
132 for (size_t k = 0; k < NZ; ++k)
133 sliceZ(k) = a(i, j, k);
134
135 backward(sliceZ);
136
137 for (size_t k = 0; k < NZ; ++k)
138 a(i, j, k) = sliceZ(k);
139 }
140 }
141 const T norm = 1.0 / (NX * NY * NZ);
142#pragma omp parallel for collapse(3)
143 for (size_t i = 0; i < NX; ++i)
144 for (size_t j = 0; j < NY; ++j)
145 for (size_t k = 0; k < NZ; ++k)
146 a(i, j, k) *= norm;
147 }
148
149 private:
155 static void transform_impl(CVectorT &a, bool inverse) {
156 const std::size_t N = a.size();
157 if (N <= 1 || (N & (N - 1)))
158 throw std::invalid_argument("SpectralFFT: size must be power of two");
159
160 bit_reverse(a);
161
162 std::vector<C> twiddles(N / 2);
163 const T sign = inverse ? T(+1) : T(-1);
164 constexpr T pi = T(3.141592653589793238462643383279502884L);
165 for (std::size_t k = 0; k < N / 2; ++k)
166 twiddles[k] = {std::cos(sign * 2 * pi * k / N), std::sin(sign * 2 * pi * k / N)};
167
168 for (std::size_t len = 2; len <= N; len <<= 1) {
169 const std::size_t half = len >> 1;
170 const std::size_t step = N / len;
171
172#pragma omp parallel for schedule(static) if (len >= 64)
173 for (std::size_t i = 0; i < N; i += len) {
174 for (std::size_t j = 0; j < half; ++j) {
175 C t = a[i + j + half] * twiddles[j * step];
176 C u = a[i + j];
177 a[i + j] = u + t;
178 a[i + j + half] = u - t;
179 }
180 }
181 }
182
183 if (inverse) {
184 const T invN = T(1) / T(N);
185#pragma omp parallel for schedule(static)
186 for (std::size_t i = 0; i < N; ++i)
187 a[i] *= invN;
188 }
189 }
194 static void bit_reverse(CVectorT &a) {
195 const std::size_t N = a.size();
196 for (std::size_t i = 1, j = 0; i < N; ++i) {
197 std::size_t bit = N >> 1;
198 for (; j & bit; bit >>= 1)
199 j ^= bit;
200 j ^= bit;
201 if (i < j)
202 std::swap(a[i], a[j]);
203 }
204 }
205};
213template <typename T> class SpectalChebyshev {
214 public:
217
227 static void compute(const VectorT &X, T h, Tensor2D &result) {
228 const size_t dim = 4;
229 result.resize(dim, dim);
230 result.fill(T(0));
231
232 for (size_t i = 0; i < dim; ++i) {
233 for (size_t j = 0; j < dim; ++j) {
234 result(i, j) = std::cos(X(i) * X(j)) * h;
235 }
236 }
237 }
238};
239} // namespace tensorium
static FrontendPluginRegistry::Add< TensoriumPluginAction > X("tensorium-dispatch", "Handle #pragma tensorium directives")
Register the plugin under the name "tensorium-dispatch".
Placeholder Chebyshev spectral method class.
Definition Spectral.hpp:213
static void compute(const VectorT &X, T h, Tensor2D &result)
Dummy computation using Chebyshev-like cosine weights.
Definition Spectral.hpp:227
Fast Fourier Transform (FFT) implementation using Cooley–Tukey algorithm.
Definition Spectral.hpp:27
static void forward_3D(Tensor< std::complex< T >, 3 > &a)
Definition Spectral.hpp:40
static void transform_impl(CVectorT &a, bool inverse)
Internal FFT implementation (shared by forward/backward)
Definition Spectral.hpp:155
static void backward_3D(Tensor< std::complex< T >, 3 > &a)
Definition Spectral.hpp:95
std::complex< T > C
Definition Spectral.hpp:31
static void backward(CVectorT &a)
Perform inverse FFT (in-place)
Definition Spectral.hpp:93
static void bit_reverse(CVectorT &a)
Bit-reversal permutation step.
Definition Spectral.hpp:194
static void forward(CVectorT &a)
Perform forward FFT (in-place)
Definition Spectral.hpp:38
void resize(const std::array< size_t, Rank > &dims)
Resize 2D tensor.
Definition Tensor.hpp:70
void fill(K value)
Fill tensor with a constant value.
Definition Tensor.hpp:125
size_t size() const
Definition Vector.hpp:76
Definition Derivate.hpp:24