Skip to content

Commit

Permalink
vectorized numerics
Browse files Browse the repository at this point in the history
  • Loading branch information
petiaccja committed Apr 15, 2021
1 parent c7ceff7 commit 222a798
Show file tree
Hide file tree
Showing 6 changed files with 271 additions and 164 deletions.
2 changes: 1 addition & 1 deletion include/dspbb/Filtering/WindowFunctions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ void HammingWindow(SignalView<T, Domain> out) {

std::iota(out.begin(), out.end(), U(0.0));
out *= preSize;
Cos(out);
Cos(out, out);
out *= U(-0.46);
out += U(0.54);
}
Expand Down
216 changes: 93 additions & 123 deletions include/dspbb/Math/Functions.hpp
Original file line number Diff line number Diff line change
@@ -1,165 +1,135 @@
#pragma once

#include "../Primitives/Signal.hpp"
#include "../Primitives/SignalView.hpp"
#include "../Utility/Algorithm.hpp"

#include <complex>
#include <type_traits>
#include <dspbb/Primitives/Signal.hpp>
#include <dspbb/Primitives/SignalTraits.hpp>
#include <dspbb/Primitives/SignalView.hpp>
#include <dspbb/Utility/Algorithm.hpp>
#include <dspbb/Vectorization/ComplexFunctions.hpp>
#include <dspbb/Vectorization/MathFunctions.hpp>
#include <type_traits>


namespace dspbb {


//------------------------------------------------------------------------------
// Complex number functions
//------------------------------------------------------------------------------
// Generates the out-parameter overload NAME(out, in).
// The overload is enabled when `out` satisfies is_mutable_signal_v and both arguments satisfy
// is_same_domain_v. It runs UnaryOperationVectorized over out.Length() elements, reading from
// in.Data(), writing to out.Data(), and evaluating math_functions::FUNC for each value.
#define DSPBB_IMPL_FUNCTION_2_PARAM(NAME, FUNC) \
    template <class SignalT, \
              class SignalU, \
              std::enable_if_t<is_mutable_signal_v<SignalT> && is_same_domain_v<std::decay_t<SignalT>, std::decay_t<SignalU>>, int> = 0> \
    auto NAME(SignalT&& out, const SignalU& in) { \
        return UnaryOperationVectorized(out.Data(), \
                                        in.Data(), \
                                        out.Length(), \
                                        [](auto sample) { return math_functions::FUNC(sample); }); \
    }

// Generates the value-returning overload NAME(signal) for any type accepted by is_signal_like_v.
// The result is allocated as an owning Signal with the input's (non-const) sample type and
// domain -- the same way DSPBB_IMPL_COMPLEX_FUNCTION_1_PARAM builds its result -- instead of as
// SignalT itself, so the input type does not need to be constructible from a length.
// The result is filled by the NAME(out, in) overload and returned by value.
#define DSPBB_IMPL_FUNCTION_1_PARAM(NAME) \
    template <class SignalT, std::enable_if_t<is_signal_like_v<std::decay_t<SignalT>>, int> = 0> \
    auto NAME(const SignalT& signal) { \
        using SampleT = std::remove_const_t<typename signal_traits<SignalT>::type>; \
        Signal<SampleT, signal_traits<SignalT>::domain> r(signal.Size()); \
        NAME(r, signal); \
        return r; \
    }

// Passes `signal` to Apply together with a lambda that evaluates std::abs on one sample;
// returns whatever Apply returns.
template <class SignalT>
auto Abs(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample x) { return std::abs(x); });
}
// Defines both overloads of NAME in one go: the out-parameter form generated by
// DSPBB_IMPL_FUNCTION_2_PARAM (which evaluates math_functions::FUNC) and the value-returning
// form generated by DSPBB_IMPL_FUNCTION_1_PARAM.
#define DSPBB_IMPL_FUNCTION(NAME, FUNC) \
DSPBB_IMPL_FUNCTION_2_PARAM(NAME, FUNC) \
DSPBB_IMPL_FUNCTION_1_PARAM(NAME)

// Enabled only when the signal's value_type satisfies is_complex_v.
// Passes `signal` to Apply together with a lambda that evaluates std::arg on one sample;
// returns whatever Apply returns.
template <class SignalT, std::enable_if_t<is_complex_v<typename std::decay_t<SignalT>::value_type>, int> = 0>
auto Arg(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample z) { return std::arg(z); });
}

// Overload for signals whose value_type does NOT satisfy is_complex_v.
// The per-sample lambda handed to Apply returns its argument unchanged.
template <class SignalT, std::enable_if_t<!is_complex_v<typename std::decay_t<SignalT>::value_type>, int> = 0>
auto Real(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample x) { return x; });
}

// Overload for signals whose value_type satisfies is_complex_v.
// Passes `signal` to Apply together with a lambda that evaluates std::real on one sample;
// returns whatever Apply returns.
template <class SignalT, std::enable_if_t<is_complex_v<typename std::decay_t<SignalT>::value_type>, int> = 0>
auto Real(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample z) { return std::real(z); });
}

// Enabled only when the signal's value_type satisfies is_complex_v.
// Passes `signal` to Apply together with a lambda that evaluates std::imag on one sample;
// returns whatever Apply returns.
template <class SignalT, std::enable_if_t<is_complex_v<typename std::decay_t<SignalT>::value_type>, int> = 0>
auto Imag(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample z) { return std::imag(z); });
}
//------------------------------------------------------------------------------
// Complex number functions
//------------------------------------------------------------------------------

// Generates three overloads of NAME that write into an existing output signal:
//   1. NAME(out, in, int, std::complex<T>) -- selected when the tag argument is a std::complex<T>;
//      runs the strided UnaryOperationVectorized with complex_functions::VECOP<T> as the vector
//      step (stride VECOP<T>::stride) and complex_functions::OP<T> as the scalar step.
//   2. NAME(out, in, int, ...) -- variadic fallback for every other tag type; applies
//      math_functions::FUNC per value.
//   3. NAME(out, in) -- public entry point; forwards to 1 or 2 by passing a dummy 0 and a
//      value-initialised signal_traits<SignalU>::type as the tag.
// All three share the same constraint: `out` satisfies is_mutable_signal_v and both signals
// satisfy is_same_domain_v.
#define DSPBB_IMPL_COMPLEX_FUNCTION_2_PARAM(NAME, VECOP, OP, FUNC) \
    template <class SignalT, \
              class SignalU, \
              class T, \
              std::enable_if_t<is_mutable_signal_v<SignalT> && is_same_domain_v<std::decay_t<SignalT>, std::decay_t<SignalU>>, int> = 0> \
    auto NAME(SignalT&& out, const SignalU& in, int, std::complex<T>) { \
        return UnaryOperationVectorized(out.Data(), \
                                        in.Data(), \
                                        out.Length(), \
                                        complex_functions::VECOP<T>::stride, \
                                        complex_functions::VECOP<T>{}, \
                                        complex_functions::OP<T>{}); \
    } \
    template <class SignalT, \
              class SignalU, \
              std::enable_if_t<is_mutable_signal_v<SignalT> && is_same_domain_v<std::decay_t<SignalT>, std::decay_t<SignalU>>, int> = 0> \
    auto NAME(SignalT&& out, const SignalU& in, int, ...) { \
        return UnaryOperationVectorized(out.Data(), \
                                        in.Data(), \
                                        out.Length(), \
                                        [](auto sample) { return math_functions::FUNC(sample); }); \
    } \
    template <class SignalT, \
              class SignalU, \
              std::enable_if_t<is_mutable_signal_v<SignalT> && is_same_domain_v<std::decay_t<SignalT>, std::decay_t<SignalU>>, int> = 0> \
    auto NAME(SignalT&& out, const SignalU& in) { \
        return NAME(std::forward<SignalT>(out), in, 0, typename signal_traits<SignalU>::type{}); \
    }

// Generates the value-returning overload NAME(signal) for any type accepted by is_signal_like_v.
// The element type of the result is whatever std::FUNC yields for the input's sample type
// (signal_traits<SignalT>::type); the result is a Signal in the input's domain with
// signal.Size() elements, filled by the NAME(out, in) overload.
#define DSPBB_IMPL_COMPLEX_FUNCTION_1_PARAM(NAME, FUNC) \
    template <class SignalT, std::enable_if_t<is_signal_like_v<std::decay_t<SignalT>>, int> = 0> \
    auto NAME(const SignalT& signal) { \
        using InputSample = typename signal_traits<SignalT>::type; \
        using ResultSample = decltype(std::FUNC(std::declval<InputSample>())); \
        Signal<ResultSample, signal_traits<SignalT>::domain> result(signal.Size()); \
        NAME(result, signal); \
        return result; \
    }

// Defines every overload of NAME in one go: the out-parameter overloads generated by
// DSPBB_IMPL_COMPLEX_FUNCTION_2_PARAM and the value-returning overload generated by
// DSPBB_IMPL_COMPLEX_FUNCTION_1_PARAM.
//   VECOP / OP : names of the vector and scalar functors in complex_functions
//   FUNC       : name of the function looked up in math_functions and in std
#define DSPBB_IMPL_COMPLEX_FUNCTION(NAME, VECOP, OP, FUNC) \
DSPBB_IMPL_COMPLEX_FUNCTION_2_PARAM(NAME, VECOP, OP, FUNC) \
DSPBB_IMPL_COMPLEX_FUNCTION_1_PARAM(NAME, FUNC)

// Instantiate Abs, Arg, Real and Imag. Each line names the public function, its vector and
// scalar functors from complex_functions, and the math_functions / std function it wraps.
DSPBB_IMPL_COMPLEX_FUNCTION(Abs, AbsVec, Abs, abs)
DSPBB_IMPL_COMPLEX_FUNCTION(Arg, ArgVec, Arg, arg)
DSPBB_IMPL_COMPLEX_FUNCTION(Real, RealVec, Real, real)
DSPBB_IMPL_COMPLEX_FUNCTION(Imag, ImagVec, Imag, imag)

//------------------------------------------------------------------------------
// Exponential functions
//------------------------------------------------------------------------------

// Passes `signal` to Apply together with a lambda that evaluates std::log on one sample;
// returns whatever Apply returns.
template <class SignalT>
auto Log(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample x) { return std::log(x); });
}

// Passes `signal` to Apply together with a lambda that evaluates std::log2 on one sample;
// returns whatever Apply returns.
template <class SignalT>
auto Log2(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample x) { return std::log2(x); });
}

// Passes `signal` to Apply together with a lambda that evaluates std::log10 on one sample;
// returns whatever Apply returns.
template <class SignalT>
auto Log10(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample x) { return std::log10(x); });
}

// Passes `signal` to Apply together with a lambda that evaluates std::exp on one sample;
// returns whatever Apply returns.
template <class SignalT>
auto Exp(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample x) { return std::exp(x); });
}

// Instantiate the exponential/logarithmic family. Each line pairs the public function name
// with the math_functions function it applies per value.
DSPBB_IMPL_FUNCTION(Log, log)
DSPBB_IMPL_FUNCTION(Log2, log2)
DSPBB_IMPL_FUNCTION(Log10, log10)
DSPBB_IMPL_FUNCTION(Exp, exp)

//------------------------------------------------------------------------------
// Polynomial functions
//------------------------------------------------------------------------------


// Passes `signal` to Apply together with a two-argument lambda computing std::pow(base, exponent);
// `power` is handed to Apply as the trailing argument. Returns whatever Apply returns.
template <class SignalT>
auto Pow(SignalT&& signal, typename std::decay_t<SignalT>::value_type power) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal),
                 [](Sample base, Sample exponent) { return std::pow(base, exponent); },
                 power);
}

template <class SignalT>
auto Sqrt(SignalT&& signal) {
return Apply(std::forward<SignalT>(signal), [](typename std::decay_t<SignalT>::value_type v) { return std::sqrt(v); });
// Out-parameter power function.
// Enabled when `out` satisfies is_mutable_signal_v and both signals satisfy is_same_domain_v.
// Runs UnaryOperationVectorized over out.Length() elements, reading from in.Data(), writing to
// out.Data(), and evaluating math_functions::pow(value, power) with the captured exponent.
template <class SignalT, class SignalU, std::enable_if_t<is_mutable_signal_v<SignalT> && is_same_domain_v<std::decay_t<SignalT>, std::decay_t<SignalU>>, int> = 0>
auto Pow(SignalT&& out, const SignalU& in, std::remove_const_t<typename std::decay_t<SignalU>::value_type> power) {
    return UnaryOperationVectorized(out.Data(),
                                    in.Data(),
                                    out.Length(),
                                    [power](auto base) { return math_functions::pow(base, power); });
}

template <class SignalT>
auto Cbrt(SignalT&& signal) {
return Apply(std::forward<SignalT>(signal), [](typename std::decay_t<SignalT>::value_type v) { return std::cbrt(v); });
// Value-returning power function for any type accepted by is_signal_like_v.
// Creates a SignalT with signal.Size() elements, fills it through the Pow(out, in, power)
// overload and returns it by value.
// NOTE(review): the result is declared as SignalT itself, whereas the complex-function helper
// in this file allocates an explicit Signal<..., domain>. Confirm that constructing SignalT
// from a size is intended for every signal-like type this overload accepts.
template <class SignalT, std::enable_if_t<is_signal_like_v<std::decay_t<SignalT>>, int> = 0>
auto Pow(const SignalT& signal, std::remove_const_t<typename std::decay_t<SignalT>::value_type> power) {
    SignalT result(signal.Size());
    Pow(result, signal, power);
    return result;
}

// Instantiate the root functions. Each line pairs the public function name with the
// math_functions function it applies per value.
DSPBB_IMPL_FUNCTION(Sqrt, sqrt)
DSPBB_IMPL_FUNCTION(Cbrt, cbrt)

//------------------------------------------------------------------------------
// Trigonometric functions
//------------------------------------------------------------------------------

// Passes `signal` to Apply together with a lambda that evaluates std::sin on one sample;
// returns whatever Apply returns.
template <class SignalT>
auto Sin(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample x) { return std::sin(x); });
}

// Passes `signal` to Apply together with a lambda that evaluates std::cos on one sample;
// returns whatever Apply returns.
template <class SignalT>
auto Cos(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample x) { return std::cos(x); });
}

// Passes `signal` to Apply together with a lambda that evaluates std::tan on one sample;
// returns whatever Apply returns.
template <class SignalT>
auto Tan(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample x) { return std::tan(x); });
}

// Passes `signal` to Apply together with a lambda that evaluates std::asin on one sample;
// returns whatever Apply returns.
template <class SignalT>
auto Asin(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample x) { return std::asin(x); });
}

// Passes `signal` to Apply together with a lambda that evaluates std::acos on one sample;
// returns whatever Apply returns.
template <class SignalT>
auto Acos(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample x) { return std::acos(x); });
}

// Passes `signal` to Apply together with a lambda that evaluates std::atan on one sample;
// returns whatever Apply returns.
template <class SignalT>
auto Atan(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample x) { return std::atan(x); });
}
// Instantiate the trigonometric family. Each line pairs the public function name with the
// math_functions function it applies per value.
DSPBB_IMPL_FUNCTION(Sin, sin)
DSPBB_IMPL_FUNCTION(Cos, cos)
DSPBB_IMPL_FUNCTION(Tan, tan)
DSPBB_IMPL_FUNCTION(Asin, asin)
DSPBB_IMPL_FUNCTION(Acos, acos)
DSPBB_IMPL_FUNCTION(Atan, atan)


//------------------------------------------------------------------------------
// Hyperbolic functions
//------------------------------------------------------------------------------

// Passes `signal` to Apply together with a lambda that evaluates std::sinh on one sample;
// returns whatever Apply returns.
template <class SignalT>
auto Sinh(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample x) { return std::sinh(x); });
}

// Passes `signal` to Apply together with a lambda that evaluates std::cosh on one sample;
// returns whatever Apply returns.
template <class SignalT>
auto Cosh(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample x) { return std::cosh(x); });
}


// Passes `signal` to Apply together with a lambda that evaluates std::tanh on one sample;
// returns whatever Apply returns.
template <class SignalT>
auto Tanh(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample x) { return std::tanh(x); });
}

// Passes `signal` to Apply together with a lambda that evaluates std::asinh on one sample;
// returns whatever Apply returns.
template <class SignalT>
auto Asinh(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample x) { return std::asinh(x); });
}

// Passes `signal` to Apply together with a lambda that evaluates std::acosh on one sample;
// returns whatever Apply returns.
template <class SignalT>
auto Acosh(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample x) { return std::acosh(x); });
}

// Passes `signal` to Apply together with a lambda that evaluates std::atanh on one sample;
// returns whatever Apply returns.
template <class SignalT>
auto Atanh(SignalT&& signal) {
    using Sample = typename std::decay_t<SignalT>::value_type;
    return Apply(std::forward<SignalT>(signal), [](Sample x) { return std::atanh(x); });
}
// Instantiate the hyperbolic family. Each line pairs the public function name with the
// math_functions function it applies per value.
DSPBB_IMPL_FUNCTION(Sinh, sinh)
DSPBB_IMPL_FUNCTION(Cosh, cosh)
DSPBB_IMPL_FUNCTION(Tanh, tanh)
DSPBB_IMPL_FUNCTION(Asinh, asinh)
DSPBB_IMPL_FUNCTION(Acosh, acosh)
DSPBB_IMPL_FUNCTION(Atanh, atanh)


} // namespace dspbb
91 changes: 91 additions & 0 deletions include/dspbb/Vectorization/ComplexFunctions.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#pragma once

#include <xsimd/xsimd.hpp>

namespace dspbb {
namespace complex_functions {

/// Vector step for Abs: processes `stride` complex values per call.
/// Loads one unaligned batch of std::complex<T> from `in`, evaluates xsimd::abs on it and
/// stores the resulting batch of T, unaligned, to `out`.
/// The functor holds no state, so the call operator is const and works on const instances.
template <class T>
struct AbsVec {
    static constexpr size_t stride = xsimd::simd_traits<std::complex<T>>::size;
    using complex_vector = xsimd::batch<std::complex<T>, stride>;
    using real_vector = xsimd::batch<T, stride>;
    void operator()(T* out, const std::complex<T>* in) const {
        complex_vector vin;
        vin.load_unaligned(in);
        const real_vector vout = xsimd::abs(vin);
        vout.store_unaligned(out);
    }
};

/// Scalar step for Abs: writes std::abs(*in) to *out.
/// The functor holds no state, so the call operator is const and works on const instances.
template <class T>
struct Abs {
    void operator()(T* out, const std::complex<T>* in) const {
        *out = std::abs(*in);
    }
};


/// Vector step for Arg: processes `stride` complex values per call.
/// Loads one unaligned batch of std::complex<T> from `in`, evaluates xsimd::arg on it and
/// stores the resulting batch of T, unaligned, to `out`.
/// The functor holds no state, so the call operator is const and works on const instances.
template <class T>
struct ArgVec {
    static constexpr size_t stride = xsimd::simd_traits<std::complex<T>>::size;
    using complex_vector = xsimd::batch<std::complex<T>, stride>;
    using real_vector = xsimd::batch<T, stride>;
    void operator()(T* out, const std::complex<T>* in) const {
        complex_vector vin;
        vin.load_unaligned(in);
        const real_vector vout = xsimd::arg(vin);
        vout.store_unaligned(out);
    }
};

/// Scalar step for Arg: writes std::arg(*in) to *out.
/// The functor holds no state, so the call operator is const and works on const instances.
template <class T>
struct Arg {
    void operator()(T* out, const std::complex<T>* in) const {
        *out = std::arg(*in);
    }
};


/// Vector step for Real: processes `stride` complex values per call.
/// Loads one unaligned batch of std::complex<T> from `in`, evaluates xsimd::real on it and
/// stores the resulting batch of T, unaligned, to `out`.
/// The functor holds no state, so the call operator is const and works on const instances.
template <class T>
struct RealVec {
    static constexpr size_t stride = xsimd::simd_traits<std::complex<T>>::size;
    using complex_vector = xsimd::batch<std::complex<T>, stride>;
    using real_vector = xsimd::batch<T, stride>;
    void operator()(T* out, const std::complex<T>* in) const {
        complex_vector vin;
        vin.load_unaligned(in);
        const real_vector vout = xsimd::real(vin);
        vout.store_unaligned(out);
    }
};

/// Scalar step for Real: writes std::real(*in) to *out.
/// The functor holds no state, so the call operator is const and works on const instances.
template <class T>
struct Real {
    void operator()(T* out, const std::complex<T>* in) const {
        *out = std::real(*in);
    }
};

/// Vector step for Imag: processes `stride` complex values per call.
/// Loads one unaligned batch of std::complex<T> from `in`, evaluates xsimd::imag on it and
/// stores the resulting batch of T, unaligned, to `out`.
/// The functor holds no state, so the call operator is const and works on const instances.
template <class T>
struct ImagVec {
    static constexpr size_t stride = xsimd::simd_traits<std::complex<T>>::size;
    using complex_vector = xsimd::batch<std::complex<T>, stride>;
    using real_vector = xsimd::batch<T, stride>;
    void operator()(T* out, const std::complex<T>* in) const {
        complex_vector vin;
        vin.load_unaligned(in);
        const real_vector vout = xsimd::imag(vin);
        vout.store_unaligned(out);
    }
};

/// Scalar step for Imag: writes std::imag(*in) to *out.
/// The functor holds no state, so the call operator is const and works on const instances.
template <class T>
struct Imag {
    void operator()(T* out, const std::complex<T>* in) const {
        *out = std::imag(*in);
    }
};

} // namespace complex_functions
} // namespace dspbb
15 changes: 15 additions & 0 deletions include/dspbb/Vectorization/Kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,5 +148,20 @@ void UnaryOperationVectorized(T* out, T* in, size_t length, Op op) {
UnaryOperation(out, in, length - vlength, op);
}

/// Applies a unary operation from `in` to `out` in two phases.
///
/// The first (length / stride) * stride elements are handled by `vop`, which is called once per
/// group of `stride` consecutive elements with pointers to the start of the group. The remaining
/// elements are handled one at a time by `op`.
///
/// @param out     Start of the output range; must hold `length` elements.
/// @param in      Start of the input range; must hold `length` elements.
/// @param length  Number of elements to process.
/// @param stride  Number of elements consumed by one call to `vop`. A stride of 0 disables the
///                vector phase, so every element goes through `op`.
/// @param vop     Callable as vop(R* out, T* in); processes `stride` elements.
/// @param op      Callable as op(R* out, T* in); processes one element.
template <class R, class T, class Op, class VecOp>
void UnaryOperationVectorized(R* out, T* in, size_t length, size_t stride, VecOp vop, Op op) {
    // Guard the division: a zero stride would divide by zero and could never advance the loop.
    const size_t vlength = stride != 0 ? (length / stride) * stride : 0;

    const R* const vlast = out + vlength;
    const R* const last = out + length;

    for (; out < vlast; out += stride, in += stride) {
        vop(out, in);
    }
    for (; out < last; ++out, ++in) {
        op(out, in);
    }
}


} // namespace dspbb
Loading

0 comments on commit 222a798

Please sign in to comment.