Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pure xtensor FFT implementation #2782

Merged
merged 3 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,4 @@ __pycache__

# Generated files
*.pc
.vscode/settings.json
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ set(XTENSOR_HEADERS
${XTENSOR_INCLUDE_DIR}/xtensor/xfixed.hpp
${XTENSOR_INCLUDE_DIR}/xtensor/xfunction.hpp
${XTENSOR_INCLUDE_DIR}/xtensor/xfunctor_view.hpp
${XTENSOR_INCLUDE_DIR}/xtensor/xfft.hpp
${XTENSOR_INCLUDE_DIR}/xtensor/xgenerator.hpp
${XTENSOR_INCLUDE_DIR}/xtensor/xhistogram.hpp
${XTENSOR_INCLUDE_DIR}/xtensor/xindex_view.hpp
Expand Down
1 change: 1 addition & 0 deletions docs/source/api/container_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@ xexpression API is actually implemented in ``xstrided_container`` and ``xcontain
xindex_view
xfunctor_view
xrepeat
xfft
17 changes: 17 additions & 0 deletions docs/source/xfft.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
.. Copyright (c) 2016, Johan Mabille, Sylvain Corlay and Wolf Vollprecht
Distributed under the terms of the BSD 3-Clause License.
The full license is in the file LICENSE, distributed with this software.
xfft
====

Defined in ``xtensor/xfft.hpp``

.. doxygenfunction:: xt::fft::convolve
   :project: xtensor

.. doxygenfunction:: xt::fft::fft
   :project: xtensor

.. doxygenfunction:: xt::fft::ifft
   :project: xtensor
241 changes: 241 additions & 0 deletions include/xtensor/xfft.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
#ifdef XTENSOR_USE_TBB
#include <oneapi/tbb.h>
#endif
#include <stdexcept>

#include <xtl/xcomplex.hpp>

#include <xtensor/xarray.hpp>
#include <xtensor/xaxis_slice_iterator.hpp>
#include <xtensor/xbuilder.hpp>
#include <xtensor/xcomplex.hpp>
#include <xtensor/xmath.hpp>
#include <xtensor/xnoalias.hpp>
#include <xtensor/xview.hpp>

namespace xt
{
namespace fft
{
namespace detail
{
/**
 * @brief Recursive radix-2 FFT of a 1D complex expression.
 *
 * The input is copied into a 1D tensor, split into its even- and odd-indexed
 * samples, each half is transformed by a recursive call, and the two halves
 * are merged using the factors exp(-2*pi*i*k/N) for k in [0, N/2).
 *
 * @param e 1D expression with a complex value type; its size must be a power of two
 * @return an xt::xtensor<value_type, 1> of size N holding the transform
 * @note raises std::runtime_error through XTENSOR_THROW when the size is
 *       zero or not a power of two
 */
template <
    class E,
    typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
inline auto radix2(E&& e)
{
    using namespace xt::placeholders;
    using namespace std::complex_literals;
    using value_type = typename std::decay_t<E>::value_type;
    using precision = typename value_type::value_type;

    auto N = e.size();
    // A power of two has a single bit set, so N & (N - 1) is zero; N == 0 is rejected explicitly.
    const bool size_is_power_of_two = (N != 0) && ((N & (N - 1)) == 0);
    if (!size_is_power_of_two)
    {
        // TODO: Replace implementation with dft
        XTENSOR_THROW(std::runtime_error, "FFT Implementation requires power of 2");
    }

    xt::xtensor<value_type, 1> ev = e;
    // Base case. This return has to stay ahead of the recursive calls below:
    // the deduced `auto` return type must be known before the function calls itself.
    if (N <= 1)
    {
        return ev;
    }

#ifdef XTENSOR_USE_TBB
    // Transform the two halves concurrently; each lambda writes its own tensor.
    xt::xtensor<value_type, 1> even;
    xt::xtensor<value_type, 1> odd;
    oneapi::tbb::parallel_invoke(
        [&]
        {
            even = radix2(xt::view(ev, xt::range(0, _, 2)));
        },
        [&]
        {
            odd = radix2(xt::view(ev, xt::range(1, _, 2)));
        }
    );
#else
    auto even = radix2(xt::view(ev, xt::range(0, _, 2)));
    auto odd = radix2(xt::view(ev, xt::range(1, _, 2)));
#endif

    // Merge step: X[k] = even[k] + w_k * odd[k], X[k + N/2] = even[k] - w_k * odd[k].
    auto pi = xt::numeric_constants<precision>::PI;
    auto k = xt::arange<double>(N / 2);
    auto twiddle = xt::exp(static_cast<value_type>(-2i) * pi * k / N);
    auto odd_term = twiddle * odd;

    // TODO: should be a call to stack if performance was improved
    auto spectrum = xt::xtensor<value_type, 1>::from_shape({N});
    xt::view(spectrum, xt::range(0, N / 2)) = even + odd_term;
    xt::view(spectrum, xt::range(N / 2, N)) = even - odd_term;
    return spectrum;
}

template <typename E>
auto transform_bluestein(E&& data)
{
using value_type = typename std::decay_t<E>::value_type;
using precision = typename value_type::value_type;

// Find a power-of-2 convolution length m such that m >= n * 2 + 1
const std::size_t n = data.size();
size_t m = std::ceil(std::log2(n * 2 + 1));
m = std::pow(2, m);

// Trignometric table
auto exp_table = xt::xtensor<std::complex<precision>, 1>::from_shape({n});
xt::xtensor<std::size_t, 1> i = xt::pow(xt::linspace<std::size_t>(0, n - 1, n), 2);
i %= (n * 2);

auto angles = xt::eval(precision{3.141592653589793238463} * i / n);
auto j = std::complex<precision>(0, 1);
exp_table = xt::exp(-angles * j);

// Temporary vectors and preprocessing
auto av = xt::empty<std::complex<precision>>({m});
xt::view(av, xt::range(0, n)) = data * exp_table;


auto bv = xt::empty<std::complex<precision>>({m});
xt::view(bv, xt::range(0, n)) = ::xt::conj(exp_table);
xt::view(bv, xt::range(-n + 1, xt::placeholders::_)) = xt::view(
::xt::conj(xt::flip(exp_table)),
xt::range(xt::placeholders::_, -1)
);

// Convolution
auto xv = radix2(av);
auto yv = radix2(bv);
auto spectrum_k = xv * yv;
auto complex_args = xt::conj(spectrum_k);
auto fft_res = radix2(complex_args);
auto cv = xt::conj(fft_res) / m;

return xt::eval(xt::view(cv, xt::range(0, n)) * exp_table);
}
} // namespace detail

/**
* @brief 1D FFT of an Nd array along a specified axis
* @param e an Nd expression to be transformed to the fourier domain
* @param axis the axis along which to perform the 1D FFT
* @return a transformed xarray of the specified precision
*/
template <
class E,
typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
inline auto fft(E&& e, std::ptrdiff_t axis = -1)
{
using value_type = typename std::decay_t<E>::value_type;
using precision = typename value_type::value_type;
const auto saxis = xt::normalize_axis(e.dimension(), axis);
const size_t N = e.shape(saxis);
const bool powerOfTwo = !(N == 0) && !(N & (N - 1));
xt::xarray<std::complex<precision>> out = xt::eval(e);
auto begin = xt::axis_slice_begin(out, saxis);
auto end = xt::axis_slice_end(out, saxis);
for (auto iter = begin; iter != end; iter++)
{
if (powerOfTwo)
{
xt::noalias(*iter) = detail::radix2(*iter);
}
else
{
xt::noalias(*iter) = detail::transform_bluestein(*iter);
}
}
return out;
}

/**
 * @brief 1D FFT of a non-complex Nd array along a specified axis
 *
 * The elements are cast to std::complex of the input value type and the
 * call is forwarded to the complex overload.
 *
 * @param e an Nd expression with a non-complex value type
 * @param axis the axis along which to perform the 1D FFT
 * @return the result of the complex overload for the promoted expression
 */
template <
    class E,
    typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
inline auto fft(E&& e, std::ptrdiff_t axis = -1)
{
    using real_type = typename std::decay_t<E>::value_type;
    using complex_type = std::complex<real_type>;
    return fft(xt::cast<complex_type>(e), axis);
}

/**
 * @brief 1D inverse FFT of a complex-valued Nd array along a specified axis
 *
 * Computed as conj(fft(conj(e))). No 1/n scaling is applied here, so a
 * caller that needs the normalized inverse has to divide the result by the
 * length of the transformed axis.
 *
 * @param e an Nd expression with a complex value type
 * @param axis the axis along which to perform the 1D inverse FFT (negative
 *        values are normalized with xt::normalize_axis)
 * @return the unscaled inverse transform, as returned by xt::fft::fft
 * @note raises std::runtime_error through XTENSOR_THROW when the selected axis has length zero
 */
template <
    class E,
    typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
auto ifft(E&& e, std::ptrdiff_t axis = -1)
{
    // Normalize before indexing the shape: the default axis is -1, and
    // shape() takes an unsigned index, so the raw value would wrap around
    // to an out-of-range position.
    const auto saxis = xt::normalize_axis(e.dimension(), axis);
    // check the length of the data on that axis
    const std::size_t n = e.shape(saxis);
    if (n == 0)
    {
        XTENSOR_THROW(std::runtime_error, "Cannot take the iFFT along an empty dimention");
    }
    auto complex_args = xt::conj(e);
    auto fft_res = xt::fft::fft(complex_args, axis);
    fft_res = xt::conj(fft_res);
    return fft_res;
}

/**
 * @brief 1D inverse FFT of a non-complex Nd array along a specified axis
 *
 * The elements are cast to std::complex of the input value type and the
 * call is forwarded to the complex overload.
 *
 * @param e an Nd expression with a non-complex value type
 * @param axis the axis along which to perform the 1D inverse FFT
 * @return the result of the complex overload for the promoted expression
 */
template <
    class E,
    typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
inline auto ifft(E&& e, std::ptrdiff_t axis = -1)
{
    using real_type = typename std::decay_t<E>::value_type;
    using complex_type = std::complex<real_type>;
    return ifft(xt::cast<complex_type>(e), axis);
}

/**
 * @brief performs a circular fft convolution along one axis; xvec and yvec
 * must be the same shape.
 *
 * Both inputs are transformed with fft, multiplied slice by slice along the
 * chosen axis, transformed back with ifft and divided by the axis length.
 *
 * @param xvec first array of the convolution
 * @param yvec second array of the convolution
 * @param axis axis along which to perform the convolution (negative values
 *        are normalized with xt::normalize_axis)
 * @return the convolution, with the value type produced by ifft
 * @note raises std::runtime_error through XTENSOR_THROW when the two inputs
 *       differ in rank or in the length of any axis
 */
template <typename E1, typename E2>
auto convolve(E1&& xvec, E2&& yvec, std::ptrdiff_t axis = -1)
{
    // Broadcasting is not supported: the operands must match exactly.
    if (xvec.dimension() != yvec.dimension())
    {
        XTENSOR_THROW(std::runtime_error, "Mismatched dimentions");
    }

    auto saxis = xt::normalize_axis(xvec.dimension(), axis);
    if (xvec.shape(saxis) != yvec.shape(saxis))
    {
        XTENSOR_THROW(std::runtime_error, "Mismatched lengths along slice axis");
    }

    // The slice loop below advances one iterator per operand in lock step,
    // so the remaining axes have to agree as well; otherwise the y iterator
    // would pair up with the wrong slices or run past its last slice.
    for (std::size_t d = 0; d < xvec.dimension(); ++d)
    {
        if (xvec.shape(d) != yvec.shape(d))
        {
            XTENSOR_THROW(std::runtime_error, "Mismatched shapes");
        }
    }

    const std::size_t n = xvec.shape(saxis);

    // Hand the already-normalized axis to the transforms so that no
    // downstream shape lookup is ever made with a negative axis value.
    const auto norm_axis = static_cast<std::ptrdiff_t>(saxis);

    auto xv = fft(xvec, norm_axis);
    auto yv = fft(yvec, norm_axis);

    auto begin_x = xt::axis_slice_begin(xv, saxis);
    auto end_x = xt::axis_slice_end(xv, saxis);
    auto iter_y = xt::axis_slice_begin(yv, saxis);

    // Pointwise product of the two spectra, one slice at a time.
    for (auto iter = begin_x; iter != end_x; ++iter, ++iter_y)
    {
        (*iter) = (*iter_y) * (*iter);
    }

    auto outvec = ifft(xv, norm_axis);

    // Scaling (because this FFT implementation omits it)
    outvec = outvec / n;

    return outvec;
}

}
} // namespace xt::fft
1 change: 1 addition & 0 deletions include/xtensor/xmath.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ namespace xt
XTENSOR_UNARY_MATH_FUNCTOR(isfinite);
XTENSOR_UNARY_MATH_FUNCTOR(isinf);
XTENSOR_UNARY_MATH_FUNCTOR(isnan);
XTENSOR_UNARY_MATH_FUNCTOR(conj);
}

#undef XTENSOR_UNARY_MATH_FUNCTOR
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ set(XTENSOR_TESTS
test_xdynamic_view.cpp
test_xfunctor_adaptor.cpp
test_xfixed.cpp
test_xfft.cpp
test_xhistogram.cpp
test_xpad.cpp
test_xindex_view.cpp
Expand Down
86 changes: 86 additions & 0 deletions test/test_xfft.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#include "xtensor/xarray.hpp"
#include "xtensor/xfft.hpp"

#include "test_common_macros.hpp"

namespace xt
{
// Forward FFT of a sine on a power-of-two length (8192 samples): after
// scaling by samples / 2, the magnitude at the sine's bin must equal its amplitude.
TEST(xfft, fft_power_2)
{
    size_t bin = 2;
    size_t samples = 8192;
    size_t amplitude = 10;
    auto t = xt::linspace<float>(0, static_cast<float>(samples - 1), samples);
    xt::xarray<float> signal = amplitude * xt::sin(2 * xt::numeric_constants<float>::PI * t * bin / samples);
    auto spectrum = xt::fft::fft(signal) / (samples / 2);
    REQUIRE(amplitude == doctest::Approx(std::abs(spectrum(bin))).epsilon(.0001));
}

// Inverse FFT of a sine on a power-of-two length (8 samples): after scaling
// by samples / 2, the magnitude at the sine's bin must equal its amplitude.
TEST(xfft, ifft_power_2)
{
    size_t bin = 2;
    size_t samples = 8;
    size_t amplitude = 10;
    auto t = xt::linspace<float>(0, static_cast<float>(samples - 1), samples);
    xt::xarray<float> signal = amplitude * xt::sin(2 * xt::numeric_constants<float>::PI * t * bin / samples);
    auto spectrum = xt::fft::ifft(signal) / (samples / 2);
    REQUIRE(amplitude == doctest::Approx(std::abs(spectrum(bin))).epsilon(.0001));
}

// Convolution of two length-4 (power-of-two) vectors, compared element by
// element against precomputed values using the magnitude of each result.
TEST(xfft, convolve_power_2)
{
    xt::xarray<float> lhs = {1.0, 1.0, 1.0, 5.0};
    xt::xarray<float> rhs = {5.0, 1.0, 1.0, 1.0};
    xt::xarray<float> expected = {12, 12, 12, 28};

    auto convolved = xt::fft::convolve(lhs, rhs);

    for (size_t idx = 0; idx < lhs.size(); ++idx)
    {
        REQUIRE(expected(idx) == doctest::Approx(std::abs(convolved(idx))).epsilon(.0001));
    }
}

// FFT along axis 0 of a 2D array whose transformed axis has length 10 (not a
// power of two): every column holds the same sine, so the scaled magnitude
// at the sine's bin must equal its amplitude in each checked column.
TEST(xfft, fft_n_0_axis)
{
    size_t bin = 2;
    size_t samples = 10;
    size_t amplitude = 1;
    size_t copies = 10;
    auto t = xt::linspace<float>(0, samples - 1, samples) * xt::ones<float>({copies, samples});
    xt::xarray<float> signal = amplitude * xt::sin(2 * xt::numeric_constants<float>::PI * t * bin / samples);
    // Put the sample axis first so that the transform runs along axis 0.
    signal = xt::transpose(signal);
    auto spectrum = xt::fft::fft(signal, 0) / (samples / 2.0);
    REQUIRE(amplitude == doctest::Approx(std::abs(spectrum(bin, 0))).epsilon(.0001));
    REQUIRE(amplitude == doctest::Approx(std::abs(spectrum(bin, 1))).epsilon(.0001));
}

// FFT along the default (last) axis of a 2D array whose transformed axis has
// length 15 (not a power of two): each row holds the same sine, so the scaled
// magnitude at the sine's bin must equal its amplitude in both rows.
TEST(xfft, fft_n_1_axis)
{
    size_t bin = 2;
    size_t samples = 15;
    size_t amplitude = 1;
    size_t rows = 2;
    auto t = xt::linspace<float>(0, samples - 1, samples) * xt::ones<float>({rows, samples});
    xt::xarray<float> signal = amplitude * xt::sin(2 * xt::numeric_constants<float>::PI * t * bin / samples);
    auto spectrum = xt::fft::fft(signal) / (samples / 2.0);
    REQUIRE(amplitude == doctest::Approx(std::abs(spectrum(0, bin))).epsilon(.0001));
    REQUIRE(amplitude == doctest::Approx(std::abs(spectrum(1, bin))).epsilon(.0001));
}

// Convolution of two length-5 (non-power-of-two) vectors, compared element
// by element against precomputed values using the magnitude of each result.
TEST(xfft, convolve_n)
{
    xt::xarray<float> lhs = {1.0, 1.0, 1.0, 5.0, 1.0};
    xt::xarray<float> rhs = {5.0, 1.0, 1.0, 1.0, 1.0};
    xt::xarray<size_t> expected = {13, 13, 13, 29, 13};

    auto convolved = xt::fft::convolve(lhs, rhs);

    xt::xarray<float> magnitudes = xt::abs(convolved);

    for (size_t idx = 0; idx < magnitudes.size(); ++idx)
    {
        REQUIRE(expected(idx) == doctest::Approx(magnitudes(idx)).epsilon(.0001));
    }
}
}
Loading