diff --git a/CMakeLists.txt b/CMakeLists.txt index 388f17fc..bd08f699 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -52,6 +52,7 @@ option(GTENSOR_BOUNDS_CHECK "Enable per access bounds checking" OFF) option(GTENSOR_ADDRESS_CHECK "Enable address checking for device spans" OFF) option(GTENSOR_SYNC_KERNELS "Enable host sync after assign and launch kernels" OFF) option(GTENSOR_ENABLE_FP16 "Enable 16-bit floating point type gt::float16_t" OFF) +option(GTENSOR_ENABLE_BF16 "Enable 16-bit floating point type gt::bfloat16_t" OFF) if (GTENSOR_ENABLE_FORTRAN) # do this early (here) since later the `enable_language(Fortran)` gives me trouble @@ -343,6 +344,13 @@ if (GTENSOR_ENABLE_FP16) INTERFACE GTENSOR_ENABLE_FP16) endif() +if (GTENSOR_ENABLE_BF16) + message(STATUS "${PROJECT_NAME}: gt::bfloat16_t is ENABLED") + message(STATUS "${PROJECT_NAME}: gt::complex_bfloat16_t is ENABLED") + target_compile_definitions(gtensor_${GTENSOR_DEVICE} + INTERFACE GTENSOR_ENABLE_BF16) +endif() + target_compile_definitions(gtensor_${GTENSOR_DEVICE} INTERFACE GTENSOR_MANAGED_MEMORY_TYPE_DEFAULT=${GTENSOR_MANAGED_MEMORY_TYPE_DEFAULT}) diff --git a/include/gtensor/bfloat16_t.h b/include/gtensor/bfloat16_t.h new file mode 100644 index 00000000..7905fb37 --- /dev/null +++ b/include/gtensor/bfloat16_t.h @@ -0,0 +1,227 @@ +#ifndef GTENSOR_BFLOAT16T_H +#define GTENSOR_BFLOAT16T_H + +#include +#include + +#if __has_include() +#include +#define GTENSOR_BF16_CUDA_HEADER +#elif 0 // TODO check if other bf16 type available +#else +#error "GTENSOR_ENABLE_BF16=ON, but no bfloat16 type available!" +#endif + +#include "macros.h" + +namespace gt +{ + +// ====================================================================== +// bfloat16_t + +class bfloat16_t +{ + +#if defined(GTENSOR_BF16_CUDA_HEADER) + using storage_type = __nv_bfloat16; +#else +#error "GTENSOR_ENABLE_BF16=ON, but no bfloat16 type available!" 
+#endif + +#if defined(GTENSOR_BF16_CUDA_HEADER) && defined(__CUDA_ARCH__) && \ + (__CUDA_ARCH__ >= 800) + using compute_type = __nv_bfloat16; +#define BFLOAT16T_ON_CUDA_DEVICE +#else + using compute_type = float; +#endif + +public: + bfloat16_t() = default; + GT_INLINE bfloat16_t(float x) : x(x){}; + GT_INLINE bfloat16_t(storage_type x) : x(x){}; + + GT_INLINE const bfloat16_t& operator=(const float f) + { + x = f; + return *this; + } + GT_INLINE compute_type Get() const { return static_cast(x); } + + // update operators [+=, -=, *=, /=] + GT_INLINE bfloat16_t operator+=(const bfloat16_t& y) + { +#if defined(BFLOAT16T_ON_CUDA_DEVICE) + x += y.Get(); +#else + x = this->Get() + y.Get(); +#endif + return *this; + } + GT_INLINE bfloat16_t operator-=(const bfloat16_t& y) + { +#if defined(BFLOAT16T_ON_CUDA_DEVICE) + x -= y.Get(); +#else + x = this->Get() - y.Get(); +#endif + return *this; + } + GT_INLINE bfloat16_t operator*=(const bfloat16_t& y) + { +#if defined(BFLOAT16T_ON_CUDA_DEVICE) + x *= y.Get(); +#else + x = this->Get() * y.Get(); +#endif + return *this; + } + GT_INLINE bfloat16_t operator/=(const bfloat16_t& y) + { +#if defined(BFLOAT16T_ON_CUDA_DEVICE) + x /= y.Get(); +#else + x = this->Get() / y.Get(); +#endif + return *this; + } + +private: + storage_type x; +}; + +// op is unary [+, -] +#define PROVIDE_BFLOAT16T_UNARY_ARITHMETIC_OPERATOR(op) \ + GT_INLINE bfloat16_t operator op(const bfloat16_t& rhs) \ + { \ + return bfloat16_t(op rhs.Get()); \ + } + +PROVIDE_BFLOAT16T_UNARY_ARITHMETIC_OPERATOR(+); +PROVIDE_BFLOAT16T_UNARY_ARITHMETIC_OPERATOR(-); + +// op is binary [+, -, *, /] +#define PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(op) \ + GT_INLINE bfloat16_t operator op(const bfloat16_t& lhs, \ + const bfloat16_t& rhs) \ + { \ + return bfloat16_t(lhs.Get() op rhs.Get()); \ + } + +PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(+); +PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(-); +PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(*); 
+PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(/); + +// op is binary [+, -, *, /] +// fp_type is [float, double] +#define PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(op, fp_type) \ + \ + GT_INLINE fp_type operator op(const bfloat16_t& lhs, const fp_type& rhs) \ + { \ + return static_cast(lhs.Get()) op rhs; \ + } \ + \ + GT_INLINE fp_type operator op(const fp_type& lhs, const bfloat16_t& rhs) \ + { \ + return lhs op static_cast(rhs.Get()); \ + } + +PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(+, float); +PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(-, float); +PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(*, float); +PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(/, float); + +PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(+, double); +PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(-, double); +PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(*, double); +PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(/, double); + +// op is binary [==, !=, <, <=, >, >=] +#define PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(op) \ + GT_INLINE bool operator op(const bfloat16_t& lhs, const bfloat16_t& rhs) \ + { \ + return lhs.Get() op rhs.Get(); \ + } + +// op is binary [==, !=, <, <=, >, >=] +// fp_type is [float, double] +#define PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(op, fp_type) \ + \ + GT_INLINE bool operator op(const bfloat16_t& lhs, const fp_type& rhs) \ + { \ + return static_cast(lhs.Get()) op rhs; \ + } \ + \ + GT_INLINE bool operator op(const fp_type& lhs, const bfloat16_t& rhs) \ + { \ + return lhs op static_cast(rhs.Get()); \ + } + +PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(==); +PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(!=); +PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(<); +PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(<=); +PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(>); +PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(>=); + +PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(==, float); +PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(!=, float); 
+PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(<, float); +PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(<=, float); +PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(>, float); +PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(>=, float); + +PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(==, double); +PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(!=, double); +PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(<, double); +PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(<=, double); +PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(>, double); +PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(>=, double); + +// op is [==, !=] +// int_type is [int] +#define PROVIDE_MIXED_INTEGRAL_BFLOAT16T_COMPARISON_OPERATOR(op, int_type) \ + \ + GT_INLINE bool operator op(const bfloat16_t& lhs, const int_type& rhs) \ + { \ + return lhs op static_cast(rhs); \ + } \ + \ + GT_INLINE bool operator op(const int_type& lhs, const bfloat16_t& rhs) \ + { \ + return static_cast(lhs) op rhs; \ + } + +PROVIDE_MIXED_INTEGRAL_BFLOAT16T_COMPARISON_OPERATOR(==, int); +PROVIDE_MIXED_INTEGRAL_BFLOAT16T_COMPARISON_OPERATOR(!=, int); + +// function is sqrt +GT_INLINE bfloat16_t sqrt(const bfloat16_t& x) +{ +#if defined(BFLOAT16T_ON_CUDA_DEVICE) + return hsqrt(x.Get()); +#else + return std::sqrt(x.Get()); +#endif +} + +std::ostream& operator<<(std::ostream& s, const bfloat16_t& h) +{ + s << static_cast(h.Get()); + return s; +} + +} // namespace gt + +#undef GTENSOR_BF16_CUDA_HEADER +#undef BFLOAT16T_ON_CUDA_DEVICE +#undef PROVIDE_BFLOAT16T_UNARY_ARITHMETIC_OPERATOR +#undef PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR +#undef PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR +#undef PROVIDE_BFLOAT16T_COMPARISON_OPERATOR +#undef PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR +#undef PROVIDE_MIXED_INTEGRAL_BFLOAT16T_COMPARISON_OPERATOR + +#endif // GTENSOR_BFLOAT16T_H diff --git a/include/gtensor/complex_bfloat16_t.h b/include/gtensor/complex_bfloat16_t.h new file mode 100644 index 00000000..07f45292 --- /dev/null +++ 
b/include/gtensor/complex_bfloat16_t.h @@ -0,0 +1,386 @@ +#ifndef GTENSOR_COMPLEX_BFLOAT16T_H +#define GTENSOR_COMPLEX_BFLOAT16T_H + +#include + +#include "bfloat16_t.h" +#include "complex.h" +#include "macros.h" + +namespace gt +{ + +// ====================================================================== +// complex_bfloat16_t [clones complex_float16_t] +// ... adapted from the C++ header, +// see e.g., https://en.cppreference.com/w/cpp/header/complex [2023/10/17] + +class complex_bfloat16_t; + +// operators: +GT_INLINE complex_bfloat16_t operator+(const complex_bfloat16_t&, + const complex_bfloat16_t&); +GT_INLINE complex_bfloat16_t operator+(const complex_bfloat16_t&, + const bfloat16_t&); +GT_INLINE complex_bfloat16_t operator+(const bfloat16_t&, + const complex_bfloat16_t&); + +GT_INLINE complex_bfloat16_t operator-(const complex_bfloat16_t&, + const complex_bfloat16_t&); +GT_INLINE complex_bfloat16_t operator-(const complex_bfloat16_t&, + const bfloat16_t&); +GT_INLINE complex_bfloat16_t operator-(const bfloat16_t&, + const complex_bfloat16_t&); + +GT_INLINE complex_bfloat16_t operator*(const complex_bfloat16_t&, + const complex_bfloat16_t&); +GT_INLINE complex_bfloat16_t operator*(const complex_bfloat16_t&, + const bfloat16_t&); +GT_INLINE complex_bfloat16_t operator*(const bfloat16_t&, + const complex_bfloat16_t&); + +GT_INLINE complex_bfloat16_t operator/(const complex_bfloat16_t&, + const complex_bfloat16_t&); +GT_INLINE complex_bfloat16_t operator/(const complex_bfloat16_t&, + const bfloat16_t&); +GT_INLINE complex_bfloat16_t operator/(const bfloat16_t&, + const complex_bfloat16_t&); + +GT_INLINE complex_bfloat16_t operator+(const complex_bfloat16_t&); +GT_INLINE complex_bfloat16_t operator-(const complex_bfloat16_t&); + +GT_INLINE bool operator==(const complex_bfloat16_t&, const complex_bfloat16_t&); +GT_INLINE bool operator==(const complex_bfloat16_t&, const bfloat16_t&); +GT_INLINE bool operator==(const bfloat16_t&, const complex_bfloat16_t&); + 
+GT_INLINE bool operator!=(const complex_bfloat16_t&, const complex_bfloat16_t&); +GT_INLINE bool operator!=(const complex_bfloat16_t&, const bfloat16_t&); +GT_INLINE bool operator!=(const bfloat16_t&, const complex_bfloat16_t&); + +template +std::basic_istream& operator>>( + std::basic_istream&, complex_bfloat16_t&); + +template +std::basic_ostream& operator<<( + std::basic_ostream&, const complex_bfloat16_t&); + +// values: +GT_INLINE bfloat16_t real(const complex_bfloat16_t&); +GT_INLINE bfloat16_t imag(const complex_bfloat16_t&); + +GT_INLINE bfloat16_t abs(const complex_bfloat16_t&); +GT_INLINE bfloat16_t norm(const complex_bfloat16_t&); + +GT_INLINE complex_bfloat16_t conj(const complex_bfloat16_t&); + +// values = delete [NOT IMPLEMENTED] +bfloat16_t arg(const complex_bfloat16_t&) = delete; +complex_bfloat16_t proj(const complex_bfloat16_t&) = delete; +complex_bfloat16_t polar(const bfloat16_t&, const bfloat16_t& = 0) = delete; + +// transcendentals = delete [NOT IMPLEMENTED] +complex_bfloat16_t acos(const complex_bfloat16_t&) = delete; +complex_bfloat16_t asin(const complex_bfloat16_t&) = delete; +complex_bfloat16_t atan(const complex_bfloat16_t&) = delete; + +complex_bfloat16_t acosh(const complex_bfloat16_t&) = delete; +complex_bfloat16_t asinh(const complex_bfloat16_t&) = delete; +complex_bfloat16_t atanh(const complex_bfloat16_t&) = delete; + +complex_bfloat16_t cos(const complex_bfloat16_t&) = delete; +complex_bfloat16_t cosh(const complex_bfloat16_t&) = delete; +complex_bfloat16_t exp(const complex_bfloat16_t&) = delete; +complex_bfloat16_t log(const complex_bfloat16_t&) = delete; +complex_bfloat16_t log10(const complex_bfloat16_t&) = delete; + +complex_bfloat16_t pow(const complex_bfloat16_t&, const bfloat16_t&) = delete; +complex_bfloat16_t pow(const complex_bfloat16_t&, + const complex_bfloat16_t&) = delete; +complex_bfloat16_t pow(const bfloat16_t&, const complex_bfloat16_t&) = delete; + +complex_bfloat16_t sin(const complex_bfloat16_t&) = delete; 
+complex_bfloat16_t sinh(const complex_bfloat16_t&) = delete; +complex_bfloat16_t sqrt(const complex_bfloat16_t&) = delete; +complex_bfloat16_t tan(const complex_bfloat16_t&) = delete; +complex_bfloat16_t tanh(const complex_bfloat16_t&) = delete; + +class complex_bfloat16_t +{ +public: + typedef bfloat16_t value_type; + GT_INLINE complex_bfloat16_t(const bfloat16_t& re = bfloat16_t(), + const bfloat16_t& im = bfloat16_t()) + : _real(re), _imag(im) + {} + complex_bfloat16_t(const complex_bfloat16_t&) = default; + template + GT_INLINE explicit complex_bfloat16_t(const complex& z) + : _real(z.real()), _imag(z.imag()) + {} + + GT_INLINE bfloat16_t real() const { return _real; } + GT_INLINE void real(bfloat16_t re) { _real = re; } + GT_INLINE bfloat16_t imag() const { return _imag; } + GT_INLINE void imag(bfloat16_t im) { _imag = im; } + + GT_INLINE complex_bfloat16_t& operator=(const bfloat16_t& x) + { + _real = x; + _imag = 0; + return *this; + } + GT_INLINE complex_bfloat16_t& operator+=(const bfloat16_t& x) + { + _real += x; + return *this; + } + GT_INLINE complex_bfloat16_t& operator-=(const bfloat16_t& x) + { + _real -= x; + return *this; + } + GT_INLINE complex_bfloat16_t& operator*=(const bfloat16_t& x) + { + _real *= x; + _imag *= x; + return *this; + } + GT_INLINE complex_bfloat16_t& operator/=(const bfloat16_t& x) + { + _real /= x; + _imag /= x; + return *this; + } + + complex_bfloat16_t& operator=(const complex_bfloat16_t&) = default; + GT_INLINE complex_bfloat16_t& operator+=(const complex_bfloat16_t& z) + { + _real += z.real(); + _imag += z.imag(); + return *this; + } + GT_INLINE complex_bfloat16_t& operator-=(const complex_bfloat16_t& z) + { + _real -= z.real(); + _imag -= z.imag(); + return *this; + } + GT_INLINE complex_bfloat16_t& operator*=(const complex_bfloat16_t& z) + { + const auto retmp{_real}; + _real = _real * z.real() - _imag * z.imag(); + _imag = _imag * z.real() + retmp * z.imag(); + return *this; + } + GT_INLINE complex_bfloat16_t& 
operator/=(const complex_bfloat16_t& z) + { + auto z_alt = conj(z); + z_alt /= norm(z); + *this *= z_alt; + return *this; + } + + template + GT_INLINE complex_bfloat16_t& operator=(const complex& z) + { + _real = z.real(); + _imag = z.imag(); + return *this; + } + template + GT_INLINE complex_bfloat16_t& operator+=(const complex& z) + { + *this += complex_bfloat16_t{z}; + return *this; + } + template + GT_INLINE complex_bfloat16_t& operator-=(const complex& z) + { + *this -= complex_bfloat16_t{z}; + return *this; + } + template + GT_INLINE complex_bfloat16_t& operator*=(const complex& z) + { + *this *= complex_bfloat16_t{z}; + return *this; + } + template + GT_INLINE complex_bfloat16_t& operator/=(const complex& z) + { + *this /= complex_bfloat16_t{z}; + return *this; + } + +private: + bfloat16_t _real; + bfloat16_t _imag; +}; + +// operators: +GT_INLINE complex_bfloat16_t operator+(const complex_bfloat16_t& lhs, + const complex_bfloat16_t& rhs) +{ + complex_bfloat16_t result{lhs}; + result += rhs; + return result; +} + +GT_INLINE complex_bfloat16_t operator+(const complex_bfloat16_t& lhs, + const bfloat16_t& rhs) +{ + complex_bfloat16_t result{lhs}; + result += rhs; + return result; +} +GT_INLINE complex_bfloat16_t operator+(const bfloat16_t& lhs, + const complex_bfloat16_t& rhs) +{ + complex_bfloat16_t result{lhs}; + result += rhs; + return result; +} + +GT_INLINE complex_bfloat16_t operator-(const complex_bfloat16_t& lhs, + const complex_bfloat16_t& rhs) +{ + complex_bfloat16_t result{lhs}; + result -= rhs; + return result; +} +GT_INLINE complex_bfloat16_t operator-(const complex_bfloat16_t& lhs, + const bfloat16_t& rhs) +{ + complex_bfloat16_t result{lhs}; + result -= rhs; + return result; +} +GT_INLINE complex_bfloat16_t operator-(const bfloat16_t& lhs, + const complex_bfloat16_t& rhs) +{ + complex_bfloat16_t result{lhs}; + result -= rhs; + return result; +} + +GT_INLINE complex_bfloat16_t operator*(const complex_bfloat16_t& lhs, + const complex_bfloat16_t& 
rhs) +{ + complex_bfloat16_t result{lhs}; + result *= rhs; + return result; +} +GT_INLINE complex_bfloat16_t operator*(const complex_bfloat16_t& lhs, + const bfloat16_t& rhs) +{ + complex_bfloat16_t result{lhs}; + result *= rhs; + return result; +} +GT_INLINE complex_bfloat16_t operator*(const bfloat16_t& lhs, + const complex_bfloat16_t& rhs) +{ + complex_bfloat16_t result{lhs}; + result *= rhs; + return result; +} + +GT_INLINE complex_bfloat16_t operator/(const complex_bfloat16_t& lhs, + const complex_bfloat16_t& rhs) +{ + complex_bfloat16_t result{lhs}; + result /= rhs; + return result; +} +GT_INLINE complex_bfloat16_t operator/(const complex_bfloat16_t& lhs, + const bfloat16_t& rhs) +{ + complex_bfloat16_t result{lhs}; + result /= rhs; + return result; +} +GT_INLINE complex_bfloat16_t operator/(const bfloat16_t& lhs, + const complex_bfloat16_t& rhs) +{ + complex_bfloat16_t result{lhs}; + result /= rhs; + return result; +} + +GT_INLINE complex_bfloat16_t operator+(const complex_bfloat16_t& z) +{ + return z; +} +GT_INLINE complex_bfloat16_t operator-(const complex_bfloat16_t& z) +{ + return complex_bfloat16_t{-z.real(), -z.imag()}; +} + +GT_INLINE bool operator==(const complex_bfloat16_t& lhs, + const complex_bfloat16_t& rhs) +{ + return lhs.real() == rhs.real() && lhs.imag() == rhs.imag(); +} +GT_INLINE bool operator==(const complex_bfloat16_t& lhs, const bfloat16_t& rhs) +{ + return lhs.real() == rhs && lhs.imag() == 0; +} +GT_INLINE bool operator==(const bfloat16_t& lhs, const complex_bfloat16_t& rhs) +{ + return lhs == rhs.real() && 0 == rhs.imag(); +} + +GT_INLINE bool operator!=(const complex_bfloat16_t& lhs, + const complex_bfloat16_t& rhs) +{ + return lhs.real() != rhs.real() || lhs.imag() != rhs.imag(); +} +GT_INLINE bool operator!=(const complex_bfloat16_t& lhs, const bfloat16_t& rhs) +{ + return lhs.real() != rhs || lhs.imag() != 0; +} +GT_INLINE bool operator!=(const bfloat16_t& lhs, const complex_bfloat16_t& rhs) +{ + return lhs != rhs.real() || 0 != 
rhs.imag(); +} + +template +std::basic_istream& operator>>( + std::basic_istream& s, complex_bfloat16_t& z) +{ + complex w; + s >> w; + z = w; + return s; +} + +template +std::basic_ostream& operator<<( + std::basic_ostream& s, const complex_bfloat16_t& z) +{ + return s << "(" << z.real() << ", " << z.imag() << ")"; +} + +// values: +GT_INLINE bfloat16_t real(const complex_bfloat16_t& z) { return z.real(); } +GT_INLINE bfloat16_t imag(const complex_bfloat16_t& z) { return z.imag(); } + +GT_INLINE bfloat16_t abs(const complex_bfloat16_t& z) +{ + auto abs2 = norm(z); + return sqrt(abs2); +} +GT_INLINE bfloat16_t norm(const complex_bfloat16_t& z) +{ + return z.real() * z.real() + z.imag() * z.imag(); +} + +GT_INLINE complex_bfloat16_t conj(const complex_bfloat16_t& z) +{ + return complex_bfloat16_t{z.real(), -z.imag()}; +} + +} // namespace gt + +#endif // GTENSOR_COMPLEX_BFLOAT16T_H diff --git a/include/gtensor/float16_t.h b/include/gtensor/float16_t.h index 826b83de..2b07f8e8 100644 --- a/include/gtensor/float16_t.h +++ b/include/gtensor/float16_t.h @@ -20,22 +20,23 @@ namespace gt // ====================================================================== // float16_t +class float16_t +{ + #if defined(GTENSOR_FP16_CUDA_HEADER) -using storage_type = __half; + using storage_type = __half; #else #error "GTENSOR_ENABLE_FP16=ON, but no 16-bit FP type available!" 
#endif #if defined(GTENSOR_FP16_CUDA_HEADER) && defined(__CUDA_ARCH__) && \ (__CUDA_ARCH__ >= 530) -using compute_type = __half; + using compute_type = __half; #define FLOAT16T_ON_CUDA_DEVICE #else -using compute_type = float; + using compute_type = float; #endif -class float16_t -{ public: float16_t() = default; GT_INLINE float16_t(float x) : x(x){}; diff --git a/include/gtensor/gtensor.h b/include/gtensor/gtensor.h index 865427e4..443dd583 100644 --- a/include/gtensor/gtensor.h +++ b/include/gtensor/gtensor.h @@ -24,6 +24,11 @@ #include "float16_t.h" #endif +#if defined(GTENSOR_ENABLE_BF16) +#include "bfloat16_t.h" +#include "complex_bfloat16_t.h" +#endif + namespace gt { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 943d455f..b23312a7 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -85,3 +85,8 @@ if (GTENSOR_ENABLE_FP16) add_gtensor_test(test_float16_t) add_gtensor_test(test_complex_float16_t) endif() + +if (GTENSOR_ENABLE_BF16) + add_gtensor_test(test_bfloat16_t) + add_gtensor_test(test_complex_bfloat16_t) +endif() diff --git a/tests/test_bfloat16_t.cxx b/tests/test_bfloat16_t.cxx new file mode 100644 index 00000000..296b0c25 --- /dev/null +++ b/tests/test_bfloat16_t.cxx @@ -0,0 +1,335 @@ +#include + +#include + +#include + +TEST(bfloat16_t, scalar_arithmetic) +{ + gt::bfloat16_t a{1.0}; + gt::bfloat16_t b{2.0}; + + gt::bfloat16_t c{0.0}; + gt::bfloat16_t ref{0.0}; + + c = a + b; + ref = 3.0; + EXPECT_EQ(c, ref); + + c = a - b; + ref = -1.0; + EXPECT_EQ(c, ref); + + c = a * b; + ref = 2.0; + EXPECT_EQ(c, ref); + + c = a / b; + ref = 0.5; + EXPECT_EQ(c, ref); +} + +TEST(bfloat16_t, update_operators) +{ + gt::bfloat16_t a{1.0}; + gt::bfloat16_t b{2.0}; + + gt::bfloat16_t c{a}; + gt::bfloat16_t ref{0.0}; + + c += b; + ref = 3.0; + EXPECT_EQ(c, ref); + + c -= b; + ref = a; + EXPECT_EQ(c, ref); + + c *= b; + ref = 2.0; + EXPECT_EQ(c, ref); + + c /= b; + ref = 1.0; + EXPECT_EQ(c, ref); +} + +TEST(bfloat16_t, unary_operators) +{ + 
gt::bfloat16_t a{2.0}; + gt::bfloat16_t b{-2.0}; + + gt::bfloat16_t c{a}; + + c = +a; + EXPECT_EQ(c, a); + + c = -a; + EXPECT_EQ(c, b); +} + +TEST(bfloat16_t, binary_comparison_operators) +{ + gt::bfloat16_t a{1.0}; + gt::bfloat16_t b{2.0}; + gt::bfloat16_t c{2.0}; + int d{2}; + + EXPECT_EQ(a, a); + EXPECT_EQ(b, b); + EXPECT_EQ(b, c); + EXPECT_EQ(b, d); + EXPECT_EQ(c, b); + EXPECT_EQ(c, c); + EXPECT_EQ(c, d); + EXPECT_EQ(d, b); + EXPECT_EQ(d, c); + + EXPECT_NE(a, b); + EXPECT_NE(a, c); + EXPECT_NE(a, d); + EXPECT_NE(b, a); + EXPECT_NE(c, a); + EXPECT_NE(d, a); + + EXPECT_LT(a, b); + EXPECT_LT(a, c); + + EXPECT_LE(a, a); + EXPECT_LE(a, b); + EXPECT_LE(a, c); + EXPECT_LE(b, b); + EXPECT_LE(b, c); + EXPECT_LE(c, b); + EXPECT_LE(c, c); + + EXPECT_GT(b, a); + EXPECT_GT(c, a); + + EXPECT_GE(a, a); + EXPECT_GE(b, a); + EXPECT_GE(b, b); + EXPECT_GE(b, c); + EXPECT_GE(c, a); + EXPECT_GE(c, b); + EXPECT_GE(c, c); +} + +TEST(bfloat16_t, sqrt) +{ + gt::bfloat16_t a{4.0}; + gt::bfloat16_t b; + gt::bfloat16_t ref{2.0}; + + b = gt::sqrt(a); + EXPECT_EQ(b, ref); +} + +template +void generic_fill_1D(gt::gtensor& x, + const gt::bfloat16_t& fill_value) +{ + auto k_x = x.to_kernel(); + + gt::launch<1, S>( + x.shape(), GT_LAMBDA(int i) { k_x(i) = fill_value; }); +} + +TEST(bfloat16_t, auto_init_host) +{ + gt::bfloat16_t fill_value{1.25}; + gt::gtensor a(gt::shape(5), fill_value); + gt::gtensor b(a.shape()); + + generic_fill_1D(b, fill_value); + + EXPECT_EQ(a, b); +} + +TEST(bfloat16_t, auto_init_device) +{ + gt::bfloat16_t fill_value{1.25}; + gt::gtensor a(gt::shape(5), fill_value); + gt::gtensor b(a.shape()); + + generic_fill_1D(b, fill_value); + + EXPECT_EQ(a, b); +} + +template +void generic_explicit_haxpy_1D(const gt::bfloat16_t& a, + const gt::gtensor& x, + gt::gtensor& y) +{ + auto k_x = x.to_kernel(); + auto k_y = y.to_kernel(); + + gt::launch<1, S>( + y.shape(), GT_LAMBDA(int i) { k_y(i) = k_y(i) + a * k_x(i); }); +} + +TEST(bfloat16_t, haxpy_explicit_1D_host) +{ + gt::gtensor 
x(gt::shape(3), 1.5); + gt::gtensor y(x.shape(), 2.5); + gt::bfloat16_t a{0.5}; + gt::gtensor ref(x.shape(), 3.25); + + generic_explicit_haxpy_1D(a, x, y); + + EXPECT_EQ(y, ref); +} + +TEST(bfloat16_t, haxpy_explicit_1D_device) +{ + gt::gtensor x(gt::shape(3), 1.5); + gt::gtensor y(x.shape(), 2.5); + gt::bfloat16_t a{0.5}; + gt::gtensor ref(y.shape(), 3.25); + + generic_explicit_haxpy_1D(a, x, y); + + EXPECT_EQ(y, ref); +} + +TEST(bfloat16_t, haxpy_implicit_1D_host) +{ + gt::gtensor x(gt::shape(3), 1.5); + gt::gtensor y(x.shape(), 2.5); + gt::bfloat16_t a{0.5}; + gt::gtensor ref(x.shape(), 3.25); + + y = a * x + y; + + EXPECT_EQ(y, ref); +} + +TEST(bfloat16_t, haxpy_implicit_1D_device) +{ + gt::gtensor x(gt::shape(3), 1.5); + gt::gtensor y(x.shape(), 2.5); + gt::bfloat16_t a{0.5}; + gt::gtensor ref(y.shape(), 3.25); + + y = a * x + y; + + EXPECT_EQ(y, ref); +} + +template +void generic_explicit_custom_kernel_1D( + const gt::bfloat16_t& s1, const gt::bfloat16_t& s2, + const gt::gtensor& a, + const gt::gtensor& b, + const gt::gtensor& c, + const gt::gtensor& d, + const gt::gtensor& e, + gt::gtensor& result) +{ + auto k_a = a.to_kernel(); + auto k_b = b.to_kernel(); + auto k_c = c.to_kernel(); + auto k_d = d.to_kernel(); + auto k_e = e.to_kernel(); + auto k_r = result.to_kernel(); + + gt::launch<1, S>( + result.shape(), GT_LAMBDA(int i) { + k_r(i) = s2 - k_e(i) * ((k_a(i) - s1 * k_b(i)) / k_c(i) + k_d(i)); + }); +} + +TEST(bfloat16_t, custom_kernel_explicit_implicit_host_device) +{ + gt::bfloat16_t a_val{12.34}, b_val{2.345}, c_val{0.987}, d_val{0.67}, + e_val{3.14}; + gt::bfloat16_t s1{0.1}, s2{4.56}; + + gt::bfloat16_t r = s2 - e_val * ((a_val - s1 * b_val) / c_val + d_val); + + auto shape = gt::shape(3); + + gt::gtensor h_a(shape, a_val); + gt::gtensor h_b(shape, b_val); + gt::gtensor h_c(shape, c_val); + gt::gtensor h_d(shape, d_val); + gt::gtensor h_e(shape, e_val); + gt::gtensor h_r_expl(shape); + gt::gtensor h_r_impl(shape); + + gt::gtensor d_a(shape, a_val); + 
gt::gtensor d_b(shape, b_val); + gt::gtensor d_c(shape, c_val); + gt::gtensor d_d(shape, d_val); + gt::gtensor d_e(shape, e_val); + gt::gtensor d_r_expl(shape); + gt::gtensor d_r_impl(shape); + + h_r_impl = s2 - h_e * ((h_a - s1 * h_b) / h_c + h_d); + d_r_impl = s2 - d_e * ((d_a - s1 * d_b) / d_c + d_d); + + generic_explicit_custom_kernel_1D(s1, s2, h_a, h_b, h_c, h_d, + h_e, h_r_expl); + + generic_explicit_custom_kernel_1D(s1, s2, d_a, d_b, d_c, + d_d, d_e, d_r_expl); + + EXPECT_EQ(h_r_impl(2), r); + EXPECT_EQ(h_r_impl, h_r_expl); + EXPECT_EQ(h_r_impl, d_r_expl); + EXPECT_EQ(h_r_impl, d_r_impl); +} + +TEST(bfloat16_t, mixed_precision_scalar) +{ + gt::bfloat16_t a_16{1.0}; + + gt::bfloat16_t b_16{2.0}; + float b_32{2.0}; + double b_64{2.0}; + + auto c_16 = a_16 + b_16; + auto c_32 = a_16 + b_32; + auto c_64 = a_16 + b_64; + + EXPECT_TRUE((std::is_same::value)); + EXPECT_TRUE((std::is_same::value)); + EXPECT_TRUE((std::is_same::value)); + + EXPECT_EQ(c_16, c_32); + EXPECT_EQ(c_16, c_64); +} + +template +void test_mixed_precision_helper() +{ + auto shape = gt::shape(3); + gt::gtensor vh(shape, 4.0); + gt::gtensor vf(shape, 3.0); + gt::gtensor vd(shape, 2.0); + + gt::gtensor rh(shape); + gt::gtensor rf(shape); + gt::gtensor rd(shape); + + gt::gtensor ref(shape, 10.0); + + rh = (vh * vf) - (vh / vd); + rf = (vh * vf) - (vh / vd); + rd = (vh * vf) - (vh / vd); + + EXPECT_EQ(ref, rh); + EXPECT_EQ(ref, rf); + EXPECT_EQ(ref, rd); +} + +TEST(bfloat16_t, mixed_precision_host) +{ + test_mixed_precision_helper(); +} + +TEST(bfloat16_t, mixed_precision_device) +{ + test_mixed_precision_helper(); +} diff --git a/tests/test_complex_bfloat16_t.cxx b/tests/test_complex_bfloat16_t.cxx new file mode 100644 index 00000000..d8f14af0 --- /dev/null +++ b/tests/test_complex_bfloat16_t.cxx @@ -0,0 +1,526 @@ +#include + +#include + +#include +#include + +#include + +TEST(complex_bfloat16_t, comparison_operators) +{ + gt::complex_bfloat16_t a{7.0, -2.0}; + gt::complex_bfloat16_t b{6.0, 
-3.0}; + gt::complex_bfloat16_t c{7.0, -3.0}; + gt::complex_bfloat16_t d{6.0, -2.0}; + + EXPECT_EQ(a, a); + EXPECT_NE(a, b); + EXPECT_NE(a, c); + EXPECT_NE(a, d); + + gt::complex_bfloat16_t e{3.0, 0.0}; + gt::complex_bfloat16_t f{3.0, 1.0}; + gt::bfloat16_t s{3.0}; + gt::bfloat16_t t{4.0}; + + EXPECT_EQ(e, s); + EXPECT_EQ(s, e); + EXPECT_NE(f, s); + EXPECT_NE(s, f); + EXPECT_NE(e, t); + EXPECT_NE(t, e); + EXPECT_NE(f, t); + EXPECT_NE(t, f); +} + +TEST(complex_bfloat16_t, constructors) +{ + gt::complex_bfloat16_t a{7.0, -2.0}; + gt::complex_bfloat16_t b{a}; + gt::complex c{7.0, -2.0}; + gt::complex_bfloat16_t d{c}; + + EXPECT_EQ(a, b); + EXPECT_EQ(a, d); +} + +TEST(complex_bfloat16_t, assignment) +{ + gt::complex_bfloat16_t a{7.0, -2.0}; + gt::complex_bfloat16_t b{0.0, 0.0}; + gt::complex_bfloat16_t c{3.0, 0.0}; + gt::bfloat16_t x{3.0}; + gt::complex_bfloat16_t e{2.0, 1.0}; + gt::complex f{2.0, 1.0}; + + b = a; + EXPECT_EQ(a, b); + + b = x; + EXPECT_EQ(c, b); + + b = f; + EXPECT_EQ(e, b); +} + +TEST(complex_bfloat16_t, getter_setter) +{ + const gt::complex_bfloat16_t a{7.0, -2.0}; + gt::complex_bfloat16_t b; + + b.real(a.real()); + b.imag(a.imag()); + + EXPECT_EQ(a, b); +} + +TEST(complex_bfloat16_t, update_operators) +{ + gt::complex_bfloat16_t a{5.0, -3.0}; + gt::complex_bfloat16_t b{-2.0, 2.0}; + gt::complex f{-2.0, 2.0}; + gt::complex_bfloat16_t ref; + gt::bfloat16_t x{3.0}; + + a += b; + ref = gt::complex_bfloat16_t{3.0, -1.0}; + EXPECT_EQ(a, ref); + + a += x; + ref = gt::complex_bfloat16_t{6.0, -1.0}; + EXPECT_EQ(a, ref); + + a -= b; + ref = gt::complex_bfloat16_t{8.0, -3.0}; + EXPECT_EQ(a, ref); + + a -= x; + ref = gt::complex_bfloat16_t{5.0, -3.0}; + EXPECT_EQ(a, ref); + + a *= b; + ref = gt::complex_bfloat16_t{-4.0, 16.0}; + EXPECT_EQ(a, ref); + + a *= x; + ref = gt::complex_bfloat16_t{-12.0, 48.0}; + EXPECT_EQ(a, ref); + + a /= x; + ref = gt::complex_bfloat16_t{-4.0, 16.0}; + EXPECT_EQ(a, ref); + + a /= b; + ref = gt::complex_bfloat16_t{5.0, -3.0}; + 
EXPECT_EQ(a, ref); // exact because b chosen s.t. norm(b) = 8 + + a += f; + ref = gt::complex_bfloat16_t{3.0, -1.0}; + EXPECT_EQ(a, ref); + + a += x; + a -= f; + ref = gt::complex_bfloat16_t{8.0, -3.0}; + EXPECT_EQ(a, ref); + + a -= x; + a *= f; + ref = gt::complex_bfloat16_t{-4.0, 16.0}; + EXPECT_EQ(a, ref); + + a /= f; + ref = gt::complex_bfloat16_t{5.0, -3.0}; + EXPECT_EQ(a, ref); // exact because f chosen s.t. norm(f) = 8 +} + +TEST(complex_bfloat16_t, values) +{ + gt::complex_bfloat16_t a{4.0, -3.0}; + + gt::bfloat16_t a_real{4.0}; + gt::bfloat16_t a_imag{-3.0}; + gt::bfloat16_t a_abs{5.0}; + gt::bfloat16_t a_norm{25.0}; + gt::complex_bfloat16_t a_conj{4.0, +3.0}; + + EXPECT_EQ(a_real, real(a)); + EXPECT_EQ(a_imag, imag(a)); + EXPECT_EQ(a_abs, abs(a)); + EXPECT_EQ(a_norm, norm(a)); + EXPECT_EQ(a_conj, conj(a)); +} + +TEST(complex_bfloat16_t, binary_arithmetic_operators) +{ + gt::complex_bfloat16_t a{4.0, -4.0}; + gt::complex_bfloat16_t b{-2.0, 2.0}; + gt::bfloat16_t x{8.0}; + gt::complex_bfloat16_t c; + gt::complex_bfloat16_t ref; + + c = a + b; + ref = gt::complex_bfloat16_t{2.0, -2.0}; + EXPECT_EQ(c, ref); + c = a + x; + ref = gt::complex_bfloat16_t{12.0, -4.0}; + EXPECT_EQ(c, ref); + c = x + a; + EXPECT_EQ(c, ref); + + c = a - b; + ref = gt::complex_bfloat16_t{6.0, -6.0}; + EXPECT_EQ(c, ref); + c = a - x; + ref = gt::complex_bfloat16_t{-4.0, -4.0}; + EXPECT_EQ(c, ref); + c = x - a; + ref = gt::complex_bfloat16_t{4.0, 4.0}; + EXPECT_EQ(c, ref); + + c = a * b; + ref = gt::complex_bfloat16_t{0.0, 16.0}; + EXPECT_EQ(c, ref); + c = a * x; + ref = gt::complex_bfloat16_t{32.0, -32.0}; + EXPECT_EQ(c, ref); + c = x * a; + EXPECT_EQ(c, ref); + + c = a / b; + ref = gt::complex_bfloat16_t{-2.0, 0.0}; + EXPECT_EQ(c, ref); // exact because b chosen s.t. norm(b) = 8 + c = a / x; + ref = gt::complex_bfloat16_t{0.5, -0.5}; + EXPECT_EQ(c, ref); + ref = gt::complex_bfloat16_t{1.0, 1.0}; + c = x / a; + EXPECT_EQ(c, ref); // exact because a chosen s.t.
norm(a) = 32 +} + +TEST(complex_bfloat16_t, unary_arithmetic_operators) +{ + gt::complex_bfloat16_t a{4.0, -5.0}; + gt::complex_bfloat16_t b{-4.0, 5.0}; + gt::complex_bfloat16_t c; + + c = +a; + EXPECT_EQ(c, a); + + c = -a; + EXPECT_EQ(c, b); +} + +TEST(complex_bfloat16_t, iostream) +{ + std::istringstream is("(1.125, -2.5)"); + std::ostringstream os; + gt::complex_bfloat16_t a{1.125, -2.5}; + gt::complex_bfloat16_t b; + + is >> b; + EXPECT_EQ(a, b); + + os << a; + EXPECT_EQ(is.str(), os.str()); +} + +#ifdef GTENSOR_HAVE_DEVICE + +TEST(complex_bfloat16_t, device_complex_ops) +{ + using T = gt::complex_bfloat16_t; + gt::gtensor h_a(2); + gt::gtensor h_b(h_a.shape()); + gt::gtensor h_c(h_a.shape()); + gt::gtensor c(h_a.shape()); + gt::gtensor_device d_a(h_a.shape()); + gt::gtensor_device d_b(h_b.shape()); + gt::gtensor_device d_c(h_c.shape()); + + h_a(0) = T{7., -2.}; + h_a(1) = T{1., 4.}; + h_b(0) = T{7., 2.}; + h_b(1) = T{1., -4.}; + + gt::copy(h_a, d_a); + gt::copy(h_b, d_b); + + d_c = d_a + d_b; + gt::copy(d_c, h_c); + c(0) = T{14., 0.}; + c(1) = T{2., 0.}; + EXPECT_EQ(h_c, c); + + d_c = d_a - d_b; + gt::copy(d_c, h_c); + c(0) = T{0., -4.}; + c(1) = T{0., 8.}; + EXPECT_EQ(h_c, c); + + d_c = d_a * d_b; + gt::copy(d_c, h_c); + c(0) = T{53., 0.}; + c(1) = T{17., 0.}; + EXPECT_EQ(h_c, c); +} + +// compare against device_complex_multiply test case with nvprof +TEST(complex_bfloat16_t, device_bfloat16_t_multiply) +{ + using T = gt::bfloat16_t; + gt::gtensor h_a(gt::shape(3, 2)); + gt::gtensor h_c(h_a.shape()); + gt::gtensor h_r(h_a.shape()); + + gt::gtensor_device a(h_a.shape()); + gt::gtensor_device c(h_a.shape()); + + // {{11., 12., 13.}, {21., 22., 23.}}; + h_a(0, 0) = T{11.}; + h_a(1, 0) = T{12.}; + h_a(2, 0) = T{13.}; + h_a(0, 1) = T{21.}; + h_a(1, 1) = T{22.}; + h_a(2, 1) = T{23.}; + + h_r(0, 0) = T{22.}; + h_r(1, 0) = T{24.}; + h_r(2, 0) = T{26.}; + h_r(0, 1) = T{42.}; + h_r(1, 1) = T{44.}; + h_r(2, 1) = T{46.}; + + gt::copy(h_a, a); + + auto Ifn =
gt::scalar(2.0); + + auto e = Ifn * a; + std::cout << "e type: " << typeid(e).name() << " [kernel " + << typeid(e.to_kernel()).name() << "]\n"; + c = e; + std::cout << "c type: " << typeid(c).name() << " [kernel " + << typeid(c.to_kernel()).name() << "]" << std::endl; + + gt::copy(c, h_c); + + EXPECT_EQ(h_c, h_r); +} + +// Note: can be run with nvprof / nsys profile to see if thrust kernels +// are called unnecessarily (other than __uninitialized_fill which is +// difficult to avoid without ugly hacks). +TEST(complex_bfloat16_t, device_complex_multiply) +{ + using T = gt::complex_bfloat16_t; + auto I = T{0., 1.0}; + gt::gtensor h_a(gt::shape(3, 2)); + gt::gtensor h_r(h_a.shape()); + gt::gtensor h_c(h_a.shape()); + + gt::gtensor_device a(h_a.shape()); + gt::gtensor_device c(h_a.shape()); + + // {{11., 12., 13.}, {21., 22., 23.}}; + h_a(0, 0) = T{11., 0}; + h_a(1, 0) = T{12., 0}; + h_a(2, 0) = T{13., 0}; + h_a(0, 1) = T{21., 0}; + h_a(1, 1) = T{22., 0}; + h_a(2, 1) = T{23., 0}; + + h_r(0, 0) = T{0., 11.}; + h_r(1, 0) = T{0., 12.}; + h_r(2, 0) = T{0., 13.}; + h_r(0, 1) = T{0., 21.}; + h_r(1, 1) = T{0., 22.}; + h_r(2, 1) = T{0., 23.}; + + gt::copy(h_a, a); + + auto e = I * a; + std::cout << "e type: " << typeid(e).name() << " [kernel " + << typeid(e.to_kernel()).name() << "]\n"; + c = e; + std::cout << "c type: " << typeid(c).name() << " [kernel " + << typeid(c.to_kernel()).name() << "]" << std::endl; + + gt::copy(c, h_c); + + EXPECT_EQ(h_c, h_r); +} + +// Note: can be run with nvprof / nsys profile to see if thrust kernels +// are called unnecessarily (other than __uninitialized_fill which is +// difficult to avoid without ugly hacks).
+TEST(complex_bfloat16_t, device_eval) +{ + using T = gt::complex_bfloat16_t; + auto I = T{0., 1.0}; + gt::gtensor h_a(gt::shape(3, 2)); + gt::gtensor h_b(h_a.shape()); + gt::gtensor h_c(h_a.shape()); + + gt::gtensor_device a(h_a.shape()); + gt::gtensor_device b(h_b.shape()); + + // {{11., 12., 13.}, {21., 22., 23.}}; + h_a(0, 0) = T{11., 0}; + h_a(1, 0) = T{12., 0}; + h_a(2, 0) = T{13., 0}; + h_a(0, 1) = T{21., 0}; + h_a(1, 1) = T{22., 0}; + h_a(2, 1) = T{23., 0}; + + h_b(0, 0) = T{-11., 0}; + h_b(1, 0) = T{-12., 0}; + h_b(2, 0) = T{-13., 0}; + h_b(0, 1) = T{-21., 0}; + h_b(1, 1) = T{-22., 0}; + h_b(2, 1) = T{-23., 0}; + + gt::copy(h_a, a); + gt::copy(h_b, b); + + auto e1 = a + I * b; + std::cout << "e1 type: " << typeid(e1).name() << "\n"; + auto e2 = a + I * a; + std::cout << "e2 type: " << typeid(e2).name() << "\n"; + auto e = T{1. / 2.} * (e1 + e2); + std::cout << "e type: " << typeid(e).name() << "\n"; + auto c = eval(e); + std::cout << "c type: " << typeid(c).name() << std::endl; + + gt::copy(c, h_c); + + EXPECT_EQ(h_c, h_a); +} + +#if defined(GTENSOR_DEVICE_CUDA) || defined(GTENSOR_DEVICE_HIP) + +__global__ void kernel_norm( + gt::gtensor_span_device d_in, + gt::gtensor_span_device d_out) +{ + int i = threadIdx.x; + if (i < d_in.shape(0)) { + d_out(i) = gt::norm(d_in(i)); + } +} + +__global__ void kernel_conj( + gt::gtensor_span_device d_in, + gt::gtensor_span_device d_out) +{ + int i = threadIdx.x; + if (i < d_in.shape(0)) { + d_out(i) = gt::conj(d_in(i)); + } +} + +TEST(complex_bfloat16_t, device_norm) +{ + const int N = 6; + using T = gt::complex_bfloat16_t; + auto I = T{0., 1.0}; + gt::gtensor h_a(gt::shape(N)); + gt::gtensor h_norm(h_a.shape()); + + gt::gtensor_device d_a(h_a.shape()); + gt::gtensor_device d_norm(d_a.shape()); + + for (int i = 0; i < N; i++) { + h_a(i) = T{1., static_cast(i)}; + } + + gt::copy(h_a, d_a); + + gtLaunchKernel(kernel_norm, 1, N, 0, 0, d_a.to_kernel(), d_norm.to_kernel()); + + gt::copy(d_norm, h_norm); + + for (int i = 0; i
< N; i++) { + EXPECT_EQ(h_norm(i), gt::norm(h_a(i))); + } +} + +TEST(complex_bfloat16_t, device_conj) +{ + const int N = 6; + using T = gt::complex_bfloat16_t; + auto I = T{0., 1.0}; + gt::gtensor h_a(gt::shape(N)); + gt::gtensor h_conj(h_a.shape()); + + gt::gtensor_device d_a(h_a.shape()); + gt::gtensor_device d_conj(d_a.shape()); + + for (int i = 0; i < N; i++) { + h_a(i) = T{1., static_cast(i)}; + } + + gt::copy(h_a, d_a); + + gtLaunchKernel(kernel_conj, 1, N, 0, 0, d_a.to_kernel(), d_conj.to_kernel()); + + gt::copy(d_conj, h_conj); + + for (int i = 0; i < N; i++) { + EXPECT_EQ(h_conj(i), gt::conj(h_a(i))); + } +} + +template +static void run_device_abs(gt::gtensor_device& res, + const gt::gtensor_device& x) +{ + auto k_res = res.to_kernel(); + auto k_x = x.to_kernel(); + + gt::launch<1>( + x.shape(), GT_LAMBDA(int i) { k_res(i) = gt::abs(k_x(i)); }); + gt::synchronize(); +} + +TEST(complex_bfloat16_t, device_abs_real) +{ + using T = gt::bfloat16_t; + + gt::gtensor h_x = {-1.75, -0.001}; + gt::gtensor_device x{h_x.shape()}; + + gt::copy(h_x, x); + + auto res = gt::empty_like(x); + run_device_abs(res, x); + + gt::gtensor h_res(res.shape()); + gt::copy(res, h_res); + gt::synchronize(); + + EXPECT_EQ(h_res(0), gt::abs(h_x(0))); + EXPECT_EQ(h_res(1), gt::abs(h_x(1))); +} + +TEST(complex_bfloat16_t, device_abs) +{ + using R = gt::bfloat16_t; + using T = gt::complex_bfloat16_t; + + gt::gtensor_device x(gt::shape(1)); + gt::gtensor h_x(x.shape()); + h_x(0) = T(sqrt(2.) / 2., sqrt(2.) / 2.); + gt::copy(h_x, x); + + gt::gtensor_device res(x.shape()); + run_device_abs(res, x); + + gt::gtensor h_res(res.shape()); + gt::copy(res, h_res); + // sqrt(2)/2 is not exactly representable, but here truncation and rounding errors cancel for bfloat16, so the result compares equal to exactly 1 + EXPECT_EQ(h_res(0), R(1)); +} + +#endif // CUDA or HIP + +#endif // GTENSOR_HAVE_DEVICE