-
Notifications
You must be signed in to change notification settings - Fork 9
/
pocketfft_demo.cc
79 lines (73 loc) · 2.05 KB
/
pocketfft_demo.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <cmath>
#include <complex>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <vector>
#include "pocketfft_hdronly.h"
using namespace std;
using namespace pocketfft;
// Overwrites every element of v with a pseudo-random complex number whose
// real and imaginary parts are each drand48()-0.5.
//
// Two values are drawn per element, strictly in the order "real part first,
// then imaginary part". The draws are kept in separate statements on purpose:
// written as two arguments of a single constructor call, their evaluation
// order would be unspecified and the generated data could differ between
// compilers.
//
// The generator is not seeded here; callers wanting a specific sequence must
// call srand48() beforehand.
template<typename T> void crand(vector<complex<T>> &v)
{
  for (auto &elem : v)
  {
    const T re = static_cast<T>(drand48()-0.5);
    const T im = static_cast<T>(drand48()-0.5);
    elem = complex<T>(re, im);
  }
}
// Relative L2 error of v2 with respect to the reference v1:
//
//   sqrt( sum_i |v1[i]-v2[i]|^2  /  sum_i |v1[i]|^2 )
//
// Both sums run over the indices of v1 and are accumulated in long double.
// The element types only need to provide real() and imag().
//
// Throws std::invalid_argument if v2 has fewer elements than v1 (indexing it
// would read out of bounds). Trailing elements of v2 beyond v1.size() are
// ignored.
//
// The division is not guarded: if the sum of |v1[i]|^2 is zero (for example
// when v1 is empty) the result is whatever sqrt(x/0) yields.
template<typename T1, typename T2> long double l2err
  (const vector<T1> &v1, const vector<T2> &v2)
{
  if (v2.size() < v1.size())
    throw invalid_argument("l2err: v2 has fewer elements than v1");

  long double sum1=0, sum2=0;
  for (size_t i=0; i<v1.size(); ++i)
  {
    const long double dr = v1[i].real()-v2[i].real(),
                      di = v1[i].imag()-v2[i].imag();
    const long double re = v1[i].real(),
                      im = v1[i].imag();
    // Accumulate the squared magnitudes directly; taking a square root only
    // to square it again would cost time and add rounding error.
    sum1 += dr*dr + di*di;
    sum2 += re*re + im*im;
  }
  return sqrt(sum1/sum2);
}
int main()
{
for (size_t len=1; len<8192; ++len)
{
shape_t shape{len};
stride_t stridef(shape.size()), strided(shape.size()), stridel(shape.size());
size_t tmpf=sizeof(complex<float>),
tmpd=sizeof(complex<double>),
tmpl=sizeof(complex<long double>);
for (int i=shape.size()-1; i>=0; --i)
{
stridef[i]=tmpf;
tmpf*=shape[i];
strided[i]=tmpd;
tmpd*=shape[i];
stridel[i]=tmpl;
tmpl*=shape[i];
}
size_t ndata=1;
for (size_t i=0; i<shape.size(); ++i)
ndata*=shape[i];
vector<complex<float>> dataf(ndata);
vector<complex<double>> datad(ndata);
vector<complex<long double>> datal(ndata);
crand(dataf);
for (size_t i=0; i<ndata; ++i)
{
datad[i] = dataf[i];
datal[i] = dataf[i];
}
shape_t axes;
for (size_t i=0; i<shape.size(); ++i)
axes.push_back(i);
auto resl = datal;
auto resd = datad;
auto resf = dataf;
c2c(shape, stridel, stridel, axes, FORWARD,
datal.data(), resl.data(), 1.L);
c2c(shape, strided, strided, axes, FORWARD,
datad.data(), resd.data(), 1.);
c2c(shape, stridef, stridef, axes, FORWARD,
dataf.data(), resf.data(), 1.f);
// c2c(shape, stridel, stridel, axes, POCKETFFT_BACKWARD,
// resl.data(), resl.data(), 1.L/ndata);
cout << l2err(resl, resf) << endl;
}
}