Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Types gt::bfloat16_t and gt::complex_bfloat16_t #283

Merged
merged 5 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ option(GTENSOR_BOUNDS_CHECK "Enable per access bounds checking" OFF)
option(GTENSOR_ADDRESS_CHECK "Enable address checking for device spans" OFF)
option(GTENSOR_SYNC_KERNELS "Enable host sync after assign and launch kernels" OFF)
option(GTENSOR_ENABLE_FP16 "Enable 16-bit floating point type gt::float16_t" OFF)
option(GTENSOR_ENABLE_BF16 "Enable 16-bit floating point type gt::bfloat16_t" OFF)

if (GTENSOR_ENABLE_FORTRAN)
# do this early (here) since later the `enable_language(Fortran)` gives me trouble
Expand Down Expand Up @@ -343,6 +344,13 @@ if (GTENSOR_ENABLE_FP16)
INTERFACE GTENSOR_ENABLE_FP16)
endif()

# When BF16 support is requested, propagate the GTENSOR_ENABLE_BF16 macro to
# all consumers of the gtensor target (INTERFACE scope), so headers such as
# bfloat16_t.h compile their bf16 code paths in downstream projects as well.
if (GTENSOR_ENABLE_BF16)
message(STATUS "${PROJECT_NAME}: gt::bfloat16_t is ENABLED")
message(STATUS "${PROJECT_NAME}: gt::complex_bfloat16_t is ENABLED")
target_compile_definitions(gtensor_${GTENSOR_DEVICE}
INTERFACE GTENSOR_ENABLE_BF16)
endif()


target_compile_definitions(gtensor_${GTENSOR_DEVICE} INTERFACE
GTENSOR_MANAGED_MEMORY_TYPE_DEFAULT=${GTENSOR_MANAGED_MEMORY_TYPE_DEFAULT})
Expand Down
227 changes: 227 additions & 0 deletions include/gtensor/bfloat16_t.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
#ifndef GTENSOR_BFLOAT16T_H
#define GTENSOR_BFLOAT16T_H

#include <cmath>
#include <iostream>

#if __has_include(<cuda_bf16.h>)
#include <cuda_bf16.h>
#define GTENSOR_BF16_CUDA_HEADER
#elif 0 // TODO check if other bf16 type available
#else
#error "GTENSOR_ENABLE_BF16=ON, but no bfloat16 type available!"
#endif

#include "macros.h"

namespace gt
{

// ======================================================================
// bfloat16_t

// 16-bit "brain float" scalar wrapper.
//
// Storage is always the CUDA bfloat16 type (__nv_bfloat16). Arithmetic is
// performed natively in bfloat16 on CUDA devices with compute capability
// >= 8.0 (Ampere); on the host and on older devices, values are promoted
// to float for computation (see compute_type).
class bfloat16_t
{

#if defined(GTENSOR_BF16_CUDA_HEADER)
  using storage_type = __nv_bfloat16;
#else
#error "GTENSOR_ENABLE_BF16=ON, but no bfloat16 type available!"
#endif

#if defined(GTENSOR_BF16_CUDA_HEADER) && defined(__CUDA_ARCH__) && \
  (__CUDA_ARCH__ >= 800)
  // sm_80+ provides native bfloat16 arithmetic instructions.
  using compute_type = __nv_bfloat16;
#define BFLOAT16T_ON_CUDA_DEVICE
#else
  // Fall back to float arithmetic on host / pre-Ampere devices.
  using compute_type = float;
#endif

public:
  bfloat16_t() = default;
  // Implicit conversions from float and from the raw storage type are
  // intentional; they make mixed expressions with plain literals work.
  GT_INLINE bfloat16_t(float x) : x_(x) {}
  GT_INLINE bfloat16_t(storage_type x) : x_(x) {}

  GT_INLINE const bfloat16_t& operator=(const float f)
  {
    x_ = f;
    return *this;
  }

  // Value converted to the type used for arithmetic (bfloat16 on sm_80+
  // device code, float otherwise).
  GT_INLINE compute_type Get() const { return static_cast<compute_type>(x_); }

  // update operators [+=, -=, *=, /=]
  // Return *this by reference, per the usual compound-assignment idiom
  // (avoids copying a temporary on every use).
  GT_INLINE bfloat16_t& operator+=(const bfloat16_t& y)
  {
#if defined(BFLOAT16T_ON_CUDA_DEVICE)
    x_ += y.Get();
#else
    x_ = this->Get() + y.Get();
#endif
    return *this;
  }
  GT_INLINE bfloat16_t& operator-=(const bfloat16_t& y)
  {
#if defined(BFLOAT16T_ON_CUDA_DEVICE)
    x_ -= y.Get();
#else
    x_ = this->Get() - y.Get();
#endif
    return *this;
  }
  GT_INLINE bfloat16_t& operator*=(const bfloat16_t& y)
  {
#if defined(BFLOAT16T_ON_CUDA_DEVICE)
    x_ *= y.Get();
#else
    x_ = this->Get() * y.Get();
#endif
    return *this;
  }
  GT_INLINE bfloat16_t& operator/=(const bfloat16_t& y)
  {
#if defined(BFLOAT16T_ON_CUDA_DEVICE)
    x_ /= y.Get();
#else
    x_ = this->Get() / y.Get();
#endif
    return *this;
  }

private:
  // Underscore suffix marks a private member variable (project convention).
  storage_type x_;
};

// Free arithmetic operators on bfloat16_t, generated via macros so the same
// body serves every operator. All of them compute in compute_type (native
// bfloat16 on sm_80+ device code, float elsewhere) via Get(), then convert
// the result back to bfloat16_t.

// op is unary [+, -]
#define PROVIDE_BFLOAT16T_UNARY_ARITHMETIC_OPERATOR(op) \
GT_INLINE bfloat16_t operator op(const bfloat16_t& rhs) \
{ \
return bfloat16_t(op rhs.Get()); \
}

PROVIDE_BFLOAT16T_UNARY_ARITHMETIC_OPERATOR(+);
PROVIDE_BFLOAT16T_UNARY_ARITHMETIC_OPERATOR(-);

// op is binary [+, -, *, /]
#define PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(op) \
GT_INLINE bfloat16_t operator op(const bfloat16_t& lhs, \
const bfloat16_t& rhs) \
{ \
return bfloat16_t(lhs.Get() op rhs.Get()); \
}

PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(+);
PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(-);
PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(*);
PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(/);

// Mixed-precision arithmetic: bfloat16_t combined with float or double.
// The bfloat16 operand is widened to fp_type and the result stays in
// fp_type (no narrowing back to bfloat16), matching the usual C++
// promotion behavior for mixed-width arithmetic.

// op is binary [+, -, *, /]
// fp_type is [float, double]
#define PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(op, fp_type) \
\
GT_INLINE fp_type operator op(const bfloat16_t& lhs, const fp_type& rhs) \
{ \
return static_cast<fp_type>(lhs.Get()) op rhs; \
} \
\
GT_INLINE fp_type operator op(const fp_type& lhs, const bfloat16_t& rhs) \
{ \
return lhs op static_cast<fp_type>(rhs.Get()); \
}

PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(+, float);
PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(-, float);
PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(*, float);
PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(/, float);

PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(+, double);
PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(-, double);
PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(*, double);
PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(/, double);

// Comparison operators. Comparisons between two bfloat16_t values compare
// their compute_type values; mixed comparisons with float/double widen the
// bfloat16 operand to the wider type first.

// op is binary [==, !=, <, <=, >, >=]
#define PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(op) \
GT_INLINE bool operator op(const bfloat16_t& lhs, const bfloat16_t& rhs) \
{ \
return lhs.Get() op rhs.Get(); \
}

// op is binary [==, !=, <, <=, >, >=]
// fp_type is [float, double]
#define PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(op, fp_type) \
\
GT_INLINE bool operator op(const bfloat16_t& lhs, const fp_type& rhs) \
{ \
return static_cast<fp_type>(lhs.Get()) op rhs; \
} \
\
GT_INLINE bool operator op(const fp_type& lhs, const bfloat16_t& rhs) \
{ \
return lhs op static_cast<fp_type>(rhs.Get()); \
}

PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(==);
PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(!=);
PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(<);
PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(<=);
PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(>);
PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(>=);

PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(==, float);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(!=, float);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(<, float);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(<=, float);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(>, float);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(>=, float);

PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(==, double);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(!=, double);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(<, double);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(<=, double);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(>, double);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(>=, double);

// Integral comparisons: only equality/inequality are provided. The int
// operand is converted to float; the comparison then resolves through the
// mixed bfloat16/float operators defined above.

// op is [==, !=]
// int_type is [int]
#define PROVIDE_MIXED_INTEGRAL_BFLOAT16T_COMPARISON_OPERATOR(op, int_type) \
\
GT_INLINE bool operator op(const bfloat16_t& lhs, const int_type& rhs) \
{ \
return lhs op static_cast<float>(rhs); \
} \
\
GT_INLINE bool operator op(const int_type& lhs, const bfloat16_t& rhs) \
{ \
return static_cast<float>(lhs) op rhs; \
}

PROVIDE_MIXED_INTEGRAL_BFLOAT16T_COMPARISON_OPERATOR(==, int);
PROVIDE_MIXED_INTEGRAL_BFLOAT16T_COMPARISON_OPERATOR(!=, int);

// function is sqrt
// Square root of a bfloat16 value. On sm_80+ device code this uses the
// CUDA bfloat16 intrinsic hsqrt; elsewhere it computes std::sqrt in float
// and the result converts back to bfloat16_t via its float constructor.
GT_INLINE bfloat16_t sqrt(const bfloat16_t& x)
{
#if defined(BFLOAT16T_ON_CUDA_DEVICE)
return hsqrt(x.Get());
#else
return std::sqrt(x.Get());
#endif
}

// Stream insertion (host only): prints the value converted to float.
// `inline` is required here: this is a non-template function defined in a
// header, so without it every translation unit that includes this header
// would emit its own definition, violating the ODR and causing
// duplicate-symbol link errors.
inline std::ostream& operator<<(std::ostream& s, const bfloat16_t& h)
{
  s << static_cast<float>(h.Get());
  return s;
}

} // namespace gt

#undef GTENSOR_BF16_CUDA_HEADER
#undef BFLOAT16T_ON_CUDA_DEVICE
#undef PROVIDE_BFLOAT16T_UNARY_ARITHMETIC_OPERATOR
#undef PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR
#undef PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR
#undef PROVIDE_BFLOAT16T_COMPARISON_OPERATOR
#undef PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR
#undef PROVIDE_MIXED_INTEGRAL_BFLOAT16T_COMPARISON_OPERATOR

#endif // GTENSOR_BFLOAT16T_H
Loading