diff --git a/include/xtensor/xfft.hpp b/include/xtensor/xfft.hpp index 051ff7407..30ef6b01c 100644 --- a/include/xtensor/xfft.hpp +++ b/include/xtensor/xfft.hpp @@ -3,224 +3,240 @@ #include #endif #include -#include -#include + +#include + #include #include #include +#include +#include #include #include -#include -namespace xt{ -namespace fft { -namespace detail { -template ::type::value_type>::value, - bool>::type = true> -inline auto radix2(E &&e) { - using namespace xt::placeholders; - using namespace std::complex_literals; - using value_type = typename std::decay_t::value_type; - using precision = typename value_type::value_type; - auto N = e.size(); - const bool powerOfTwo = !(N == 0) && !(N & (N - 1)); - // check for power of 2 - if (!powerOfTwo || N == 0) { - // TODO: Replace implementation with dft - XTENSOR_THROW(std::runtime_error, "FFT Implementation requires power of 2"); - } - auto pi = xt::numeric_constants::PI; - xt::xtensor ev = e; - if (N <= 1) { - return ev; - } else { +namespace xt +{ + namespace fft + { + namespace detail + { + template < + class E, + typename std::enable_if::type::value_type>::value, bool>::type = true> + inline auto radix2(E&& e) + { + using namespace xt::placeholders; + using namespace std::complex_literals; + using value_type = typename std::decay_t::value_type; + using precision = typename value_type::value_type; + auto N = e.size(); + const bool powerOfTwo = !(N == 0) && !(N & (N - 1)); + // check for power of 2 + if (!powerOfTwo || N == 0) + { + // TODO: Replace implementation with dft + XTENSOR_THROW(std::runtime_error, "FFT Implementation requires power of 2"); + } + auto pi = xt::numeric_constants::PI; + xt::xtensor ev = e; + if (N <= 1) + { + return ev; + } + else + { #ifdef XTENSOR_USE_TBB - xt::xtensor even; - xt::xtensor odd; - oneapi::tbb::parallel_invoke( - [&] { even = radix2(xt::view(ev, xt::range(0, _, 2))); }, - [&] { odd = radix2(xt::view(ev, xt::range(1, _, 2))); }); + xt::xtensor even; + xt::xtensor odd; + oneapi::tbb::parallel_invoke( + [&] + { + even = radix2(xt::view(ev, xt::range(0, _, 2))); + }, + [&] + { + odd = radix2(xt::view(ev, xt::range(1, _, 2))); + } + ); #else - auto even = radix2(xt::view(ev, xt::range(0, _, 2))); - auto odd = radix2(xt::view(ev, xt::range(1, _, 2))); + auto even = radix2(xt::view(ev, xt::range(0, _, 2))); + auto odd = radix2(xt::view(ev, xt::range(1, _, 2))); #endif - auto range = xt::arange(N / 2); - auto exp = xt::exp(static_cast(-2i) * pi * range / N); - auto t = exp * odd; - auto first_half = even + t; - auto second_half = even - t; - // TODO: should be a call to stack if performance was improved - auto spectrum = xt::xtensor::from_shape({N}); - xt::view(spectrum, xt::range(0, N / 2)) = first_half; - xt::view(spectrum, xt::range(N / 2, N)) = second_half; - return spectrum; - } -} -template -auto transform_bluestein(E&& data) -{ - using value_type = typename std::decay_t::value_type; - using precision = typename value_type::value_type; - - // Find a power-of-2 convolution length m such that m >= n * 2 + 1 - const std::size_t n = data.size(); - size_t m = std::ceil(std::log2(n * 2 + 1)); - m = std::pow(2, m); - - // Trignometric table - auto exp_table = xt::xtensor, 1>::from_shape({n}); - xt::xtensor i = xt::pow(xt::linspace(0, n - 1, n), 2); - i %= (n * 2); - - auto angles = xt::eval(::xt::numeric_constants::PI * i / n); - auto j = std::complex(0, 1); - exp_table = xt::exp(-angles * j); - - // Temporary vectors and preprocessing - auto av = xt::empty>({m}); - xt::view(av, xt::range(0, n)) = data * exp_table; - - - auto bv = xt::empty>({m}); - xt::view(bv, xt::range(0, n)) = ::xt::conj(exp_table); - xt::view(bv, xt::range(-n + 1, xt::placeholders::_)) = xt::view( - ::xt::conj(xt::flip(exp_table)), - xt::range(xt::placeholders::_, -1) - ); - - // Convolution - auto xv = radix2(av); - auto yv = radix2(bv); - auto spectrum_k = xv * yv; - auto complex_args = xt::conj(spectrum_k); - auto fft_res = radix2(complex_args); - auto cv = xt::conj(fft_res) / m; - - return xt::eval(xt::view(cv, xt::range(0, n)) * exp_table); -} -} // namespace detail - -/** - * @brief 1D FFT of an Nd array along a specified axis - * @param e an Nd expression to be transformed to the fourier domain - * @param axis the axis along which to perform the 1D FFT - * @return a transformed xarray of the specified precision - */ -template ::type::value_type>::value, - bool>::type = true> -inline auto fft(E &&e, std::ptrdiff_t axis = -1) { - using value_type = typename std::decay_t::value_type; - using precision = typename value_type::value_type; - const auto saxis = xt::normalize_axis(e.dimension(), axis); - const size_t N = e.shape(saxis); - const bool powerOfTwo = !(N == 0) && !(N & (N - 1)); - xt::xarray> out = xt::eval(e); - auto begin = xt::axis_slice_begin(out, saxis); - auto end = xt::axis_slice_end(out, saxis); - for (auto iter = begin; iter != end; iter++) { - if(powerOfTwo) - { - xt::noalias(*iter) = detail::radix2(*iter); - } - else - { - xt::noalias(*iter) = detail::transform_bluestein(*iter); - } - } - return out; -} - -/** - * @breif 1D FFT of an Nd array along a specified axis - * @param e an Nd expression to be transformed to the fourier domain - * @param axis the axis along which to perform the 1D FFT - * @return a transformed xarray of the specified precision - */ -template ::type::value_type>::value, - bool>::type = true> -inline auto fft(E &&e, std::ptrdiff_t axis = -1) { - using value_type = typename std::decay::type::value_type; - return fft(xt::cast>(e), axis); -} - -template ::type::value_type>::value, - bool>::type = true> -auto ifft(E&& e, std::ptrdiff_t axis = -1) -{ - // check the length of the data on that axis - const std::size_t n = e.shape(axis); - if (n == 0) - { - XTENSOR_THROW(std::runtime_error, "Cannot take the iFFT along an empty dimention"); - } - auto complex_args = xt::conj(e); - auto fft_res = xt::fft::fft(complex_args, axis); - fft_res = xt::conj(fft_res); - return fft_res; -} - -template ::type::value_type>::value, - bool>::type = true> -inline auto ifft(E &&e, std::ptrdiff_t axis = -1) { - using value_type = typename std::decay::type::value_type; - return ifft(xt::cast>(e), axis); -} - - -/* - * @brief performs a circular fft convolution xvec and yvec must - * be the same shape. - * @param xvec first array of the convolution - * @param yvec second array of the convolution - * @param axis axis along which to perform the convolution - */ -template -auto convolve(E1&& xvec, E2&& yvec, std::ptrdiff_t axis = -1) -{ - // we could broadcast but that could get complicated??? - if (xvec.dimension() != yvec.dimension()) - { - XTENSOR_THROW(std::runtime_error, "Mismatched dimentions"); - } - - auto saxis = xt::normalize_axis(xvec.dimension(), axis); - if (xvec.shape(saxis) != yvec.shape(saxis)) - { - XTENSOR_THROW(std::runtime_error, "Mismatched lengths along slice axis"); - } - - const std::size_t n = xvec.shape(saxis); - - auto xv = fft(xvec, axis); - auto yv = fft(yvec, axis); + auto range = xt::arange(N / 2); + auto exp = xt::exp(static_cast(-2i) * pi * range / N); + auto t = exp * odd; + auto first_half = even + t; + auto second_half = even - t; + // TODO: should be a call to stack if performance was improved + auto spectrum = xt::xtensor::from_shape({N}); + xt::view(spectrum, xt::range(0, N / 2)) = first_half; + xt::view(spectrum, xt::range(N / 2, N)) = second_half; + return spectrum; + } + } + + template + auto transform_bluestein(E&& data) + { + using value_type = typename std::decay_t::value_type; + using precision = typename value_type::value_type; + + // Find a power-of-2 convolution length m such that m >= n * 2 + 1 + const std::size_t n = data.size(); + size_t m = std::ceil(std::log2(n * 2 + 1)); + m = std::pow(2, m); + + // Trignometric table + auto exp_table = xt::xtensor, 1>::from_shape({n}); + xt::xtensor i = xt::pow(xt::linspace(0, n - 1, n), 2); + i %= (n * 2); + + auto angles = xt::eval(::xt::numeric_constants::PI * i / n); + auto j = std::complex(0, 1); + exp_table = xt::exp(-angles * j); + + // Temporary vectors and preprocessing + auto av = xt::empty>({m}); + xt::view(av, xt::range(0, n)) = data * exp_table; + + + auto bv = xt::empty>({m}); + xt::view(bv, xt::range(0, n)) = ::xt::conj(exp_table); + xt::view(bv, xt::range(-n + 1, xt::placeholders::_)) = xt::view( + ::xt::conj(xt::flip(exp_table)), + xt::range(xt::placeholders::_, -1) + ); + + // Convolution + auto xv = radix2(av); + auto yv = radix2(bv); + auto spectrum_k = xv * yv; + auto complex_args = xt::conj(spectrum_k); + auto fft_res = radix2(complex_args); + auto cv = xt::conj(fft_res) / m; + + return xt::eval(xt::view(cv, xt::range(0, n)) * exp_table); + } + } // namespace detail + + /** + * @brief 1D FFT of an Nd array along a specified axis + * @param e an Nd expression to be transformed to the fourier domain + * @param axis the axis along which to perform the 1D FFT + * @return a transformed xarray of the specified precision + */ + template < + class E, + typename std::enable_if::type::value_type>::value, bool>::type = true> + inline auto fft(E&& e, std::ptrdiff_t axis = -1) + { + using value_type = typename std::decay_t::value_type; + using precision = typename value_type::value_type; + const auto saxis = xt::normalize_axis(e.dimension(), axis); + const size_t N = e.shape(saxis); + const bool powerOfTwo = !(N == 0) && !(N & (N - 1)); + xt::xarray> out = xt::eval(e); + auto begin = xt::axis_slice_begin(out, saxis); + auto end = xt::axis_slice_end(out, saxis); + for (auto iter = begin; iter != end; iter++) + { + if (powerOfTwo) + { + xt::noalias(*iter) = detail::radix2(*iter); + } + else + { + xt::noalias(*iter) = detail::transform_bluestein(*iter); + } + } + return out; + } + + /** + * @breif 1D FFT of an Nd array along a specified axis + * @param e an Nd expression to be transformed to the fourier domain + * @param axis the axis along which to perform the 1D FFT + * @return a transformed xarray of the specified precision + */ + template < + class E, + typename std::enable_if::type::value_type>::value, bool>::type = true> + inline auto fft(E&& e, std::ptrdiff_t axis = -1) + { + using value_type = typename std::decay::type::value_type; + return fft(xt::cast>(e), axis); + } + + template < + class E, + typename std::enable_if::type::value_type>::value, bool>::type = true> + auto ifft(E&& e, std::ptrdiff_t axis = -1) + { + // check the length of the data on that axis + const std::size_t n = e.shape(axis); + if (n == 0) + { + XTENSOR_THROW(std::runtime_error, "Cannot take the iFFT along an empty dimention"); + } + auto complex_args = xt::conj(e); + auto fft_res = xt::fft::fft(complex_args, axis); + fft_res = xt::conj(fft_res); + return fft_res; + } + + template < + class E, + typename std::enable_if::type::value_type>::value, bool>::type = true> + inline auto ifft(E&& e, std::ptrdiff_t axis = -1) + { + using value_type = typename std::decay::type::value_type; + return ifft(xt::cast>(e), axis); + } + + /* + * @brief performs a circular fft convolution xvec and yvec must + * be the same shape. + * @param xvec first array of the convolution + * @param yvec second array of the convolution + * @param axis axis along which to perform the convolution + */ + template + auto convolve(E1&& xvec, E2&& yvec, std::ptrdiff_t axis = -1) + { + // we could broadcast but that could get complicated??? + if (xvec.dimension() != yvec.dimension()) + { + XTENSOR_THROW(std::runtime_error, "Mismatched dimentions"); + } + + auto saxis = xt::normalize_axis(xvec.dimension(), axis); + if (xvec.shape(saxis) != yvec.shape(saxis)) + { + XTENSOR_THROW(std::runtime_error, "Mismatched lengths along slice axis"); + } + + const std::size_t n = xvec.shape(saxis); + + auto xv = fft(xvec, axis); + auto yv = fft(yvec, axis); + + auto begin_x = xt::axis_slice_begin(xv, saxis); + auto end_x = xt::axis_slice_end(xv, saxis); + auto iter_y = xt::axis_slice_begin(yv, saxis); + + for (auto iter = begin_x; iter != end_x; iter++) + { + (*iter) = (*iter_y++) * (*iter); + } + + auto outvec = ifft(xv, axis); + + // Scaling (because this FFT implementation omits it) + outvec = outvec / n; + + return outvec; + } - auto begin_x = xt::axis_slice_begin(xv, saxis); - auto end_x = xt::axis_slice_end(xv, saxis); - auto iter_y = xt::axis_slice_begin(yv, saxis); - - for (auto iter = begin_x; iter != end_x; iter++) - { - (*iter) = (*iter_y++) * (*iter); } - - auto outvec = ifft(xv, axis); - - // Scaling (because this FFT implementation omits it) - outvec = outvec / n; - - return outvec; -} - -} -} // namespace xt::fft +} // namespace xt::fft diff --git a/test/test_xfft.cpp b/test/test_xfft.cpp index e396b8421..d7fd78896 100644 --- a/test/test_xfft.cpp +++ b/test/test_xfft.cpp @@ -1,7 +1,8 @@ -#include "test_common_macros.hpp" #include "xtensor/xarray.hpp" #include "xtensor/xfft.hpp" +#include "test_common_macros.hpp" + namespace xt { TEST(xfft, fft_power_2) @@ -20,7 +21,7 @@ namespace xt size_t k = 2; size_t n = 8; size_t A = 10; - auto x = xt::linspace (0, static_cast(n - 1), n); + auto x = xt::linspace(0, static_cast(n - 1), n); xt::xarray y = A * xt::sin(2 * xt::numeric_constants::PI * x * k / n); auto res = xt::fft::ifft(y) / (n / 2); REQUIRE(A == doctest::Approx(std::abs(res(k))).epsilon(.0001)); @@ -28,9 +29,9 @@ namespace xt TEST(xfft, convolve_power_2) { - xt::xarray x = { 1.0, 1.0, 1.0, 5.0}; - xt::xarray y = { 5.0, 1.0, 1.0, 1.0}; - xt::xarray expected = { 12, 12, 12, 28 }; + xt::xarray x = {1.0, 1.0, 1.0, 5.0}; + xt::xarray y = {5.0, 1.0, 1.0, 1.0}; + xt::xarray expected = {12, 12, 12, 28}; auto result = xt::fft::convolve(x, y); @@ -60,7 +61,7 @@ namespace xt size_t n = 15; size_t A = 1; size_t dim = 2; - auto x = xt::linspace(0, n - 1, n) * xt::ones({ dim, n }); + auto x = xt::linspace(0, n - 1, n) * xt::ones({dim, n}); xt::xarray y = A * xt::sin(2 * xt::numeric_constants::PI * x * k / n); auto res = xt::fft::fft(y) / (n / 2.0); REQUIRE(A == doctest::Approx(std::abs(res(0, k))).epsilon(.0001)); @@ -69,15 +70,15 @@ namespace xt TEST(xfft, convolve_n) { - xt::xarray x = { 1.0, 1.0, 1.0, 5.0, 1.0 }; - xt::xarray y = { 5.0, 1.0, 1.0, 1.0, 1.0 }; - xt::xarray expected = { 13, 13, 13, 29, 13 }; + xt::xarray x = {1.0, 1.0, 1.0, 5.0, 1.0}; + xt::xarray y = {5.0, 1.0, 1.0, 1.0, 1.0}; + xt::xarray expected = {13, 13, 13, 29, 13}; auto result = xt::fft::convolve(x, y); xt::xarray abs = xt::abs(result); - for(size_t i = 0; i < abs.size(); i++) + for (size_t i = 0; i < abs.size(); i++) { REQUIRE(expected(i) == doctest::Approx(abs(i)).epsilon(.0001)); }