-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Types gt::bfloat16_t and gt::complex_bfloat16_t #283
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
2d6b8e3
[fix] Make type aliases local to float16_t
cmpfeil 7c951d3
[feat/test/build] Add gt::[complex_]bfloat16_t
cmpfeil a377008
[clean] Apply clang-format
cmpfeil 399f0af
[clean] Apply more clang-format
cmpfeil d0cfe6a
[clean] Apply more clang-format
cmpfeil File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,227 @@ | ||
#ifndef GTENSOR_BFLOAT16T_H
#define GTENSOR_BFLOAT16T_H

#include <cmath>
#include <iostream>

// Detect an available bfloat16 backing implementation. Currently only the
// CUDA one (<cuda_bf16.h>, providing __nv_bfloat16) is supported; the empty
// #elif branch is a placeholder for future non-CUDA backends.
#if __has_include(<cuda_bf16.h>)
#include <cuda_bf16.h>
#define GTENSOR_BF16_CUDA_HEADER
#elif 0 // TODO check if other bf16 type available
#else
#error "GTENSOR_ENABLE_BF16=ON, but no bfloat16 type available!"
#endif

#include "macros.h"

namespace gt
{
|
||
// ====================================================================== | ||
// bfloat16_t | ||
|
||
class bfloat16_t | ||
{ | ||
|
||
#if defined(GTENSOR_BF16_CUDA_HEADER) | ||
using storage_type = __nv_bfloat16; | ||
#else | ||
#error "GTENSOR_ENABLE_BF16=ON, but no bfloat16 type available!" | ||
#endif | ||
|
||
#if defined(GTENSOR_BF16_CUDA_HEADER) && defined(__CUDA_ARCH__) && \ | ||
(__CUDA_ARCH__ >= 800) | ||
using compute_type = __nv_bfloat16; | ||
#define BFLOAT16T_ON_CUDA_DEVICE | ||
#else | ||
using compute_type = float; | ||
#endif | ||
|
||
public: | ||
bfloat16_t() = default; | ||
GT_INLINE bfloat16_t(float x) : x(x){}; | ||
GT_INLINE bfloat16_t(storage_type x) : x(x){}; | ||
|
||
GT_INLINE const bfloat16_t& operator=(const float f) | ||
{ | ||
x = f; | ||
return *this; | ||
} | ||
GT_INLINE compute_type Get() const { return static_cast<compute_type>(x); } | ||
|
||
// update operators [+=, -=, *=, /=] | ||
GT_INLINE bfloat16_t operator+=(const bfloat16_t& y) | ||
{ | ||
#if defined(BFLOAT16T_ON_CUDA_DEVICE) | ||
x += y.Get(); | ||
#else | ||
x = this->Get() + y.Get(); | ||
#endif | ||
return *this; | ||
} | ||
GT_INLINE bfloat16_t operator-=(const bfloat16_t& y) | ||
{ | ||
#if defined(BFLOAT16T_ON_CUDA_DEVICE) | ||
x -= y.Get(); | ||
#else | ||
x = this->Get() - y.Get(); | ||
#endif | ||
return *this; | ||
} | ||
GT_INLINE bfloat16_t operator*=(const bfloat16_t& y) | ||
{ | ||
#if defined(BFLOAT16T_ON_CUDA_DEVICE) | ||
x *= y.Get(); | ||
#else | ||
x = this->Get() * y.Get(); | ||
#endif | ||
return *this; | ||
} | ||
GT_INLINE bfloat16_t operator/=(const bfloat16_t& y) | ||
{ | ||
#if defined(BFLOAT16T_ON_CUDA_DEVICE) | ||
x /= y.Get(); | ||
#else | ||
x = this->Get() / y.Get(); | ||
#endif | ||
return *this; | ||
} | ||
|
||
private: | ||
storage_type x; | ||
}; | ||
|
||
// op is unary [+, -]
// Defines the unary sign operators as free functions: evaluate in
// compute_type via Get(), then convert back to bfloat16_t.
#define PROVIDE_BFLOAT16T_UNARY_ARITHMETIC_OPERATOR(op)                        \
  GT_INLINE bfloat16_t operator op(const bfloat16_t& rhs)                      \
  {                                                                            \
    return bfloat16_t(op rhs.Get());                                           \
  }

PROVIDE_BFLOAT16T_UNARY_ARITHMETIC_OPERATOR(+);
PROVIDE_BFLOAT16T_UNARY_ARITHMETIC_OPERATOR(-);
|
||
// op is binary [+, -, *, /]
// Homogeneous bfloat16_t x bfloat16_t arithmetic: evaluate in compute_type,
// result converted back to bfloat16_t.
#define PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(op)                       \
  GT_INLINE bfloat16_t operator op(const bfloat16_t& lhs,                      \
                                   const bfloat16_t& rhs)                      \
  {                                                                            \
    return bfloat16_t(lhs.Get() op rhs.Get());                                 \
  }

PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(+);
PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(-);
PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(*);
PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(/);
|
||
// op is binary [+, -, *, /]
// fp_type is [float, double]
// Mixed arithmetic with built-in floating point types: the bfloat16_t
// operand is widened to fp_type, and the result stays in fp_type (no
// precision is silently dropped back to bf16).
#define PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(op, fp_type)        \
                                                                               \
  GT_INLINE fp_type operator op(const bfloat16_t& lhs, const fp_type& rhs)     \
  {                                                                            \
    return static_cast<fp_type>(lhs.Get()) op rhs;                             \
  }                                                                            \
                                                                               \
  GT_INLINE fp_type operator op(const fp_type& lhs, const bfloat16_t& rhs)     \
  {                                                                            \
    return lhs op static_cast<fp_type>(rhs.Get());                             \
  }

PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(+, float);
PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(-, float);
PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(*, float);
PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(/, float);

PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(+, double);
PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(-, double);
PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(*, double);
PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR(/, double);
|
||
// op is binary [==, !=, <, <=, >, >=]
// Homogeneous comparison: both operands compared in compute_type.
#define PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(op)                              \
  GT_INLINE bool operator op(const bfloat16_t& lhs, const bfloat16_t& rhs)     \
  {                                                                            \
    return lhs.Get() op rhs.Get();                                             \
  }

// op is binary [==, !=, <, <=, >, >=]
// fp_type is [float, double]
// Mixed comparison with built-in floating point types: the bfloat16_t
// operand is widened to fp_type before comparing.
#define PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(op, fp_type)               \
                                                                               \
  GT_INLINE bool operator op(const bfloat16_t& lhs, const fp_type& rhs)        \
  {                                                                            \
    return static_cast<fp_type>(lhs.Get()) op rhs;                             \
  }                                                                            \
                                                                               \
  GT_INLINE bool operator op(const fp_type& lhs, const bfloat16_t& rhs)        \
  {                                                                            \
    return lhs op static_cast<fp_type>(rhs.Get());                             \
  }

PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(==);
PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(!=);
PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(<);
PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(<=);
PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(>);
PROVIDE_BFLOAT16T_COMPARISON_OPERATOR(>=);

PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(==, float);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(!=, float);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(<, float);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(<=, float);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(>, float);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(>=, float);

PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(==, double);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(!=, double);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(<, double);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(<=, double);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(>, double);
PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR(>=, double);
|
||
// op is [==, !=]
// int_type is [int]
// Integral comparison: the integer operand is first converted to float and
// then compared through the mixed bfloat16_t/float operators above.
#define PROVIDE_MIXED_INTEGRAL_BFLOAT16T_COMPARISON_OPERATOR(op, int_type)     \
                                                                               \
  GT_INLINE bool operator op(const bfloat16_t& lhs, const int_type& rhs)       \
  {                                                                            \
    return lhs op static_cast<float>(rhs);                                     \
  }                                                                            \
                                                                               \
  GT_INLINE bool operator op(const int_type& lhs, const bfloat16_t& rhs)       \
  {                                                                            \
    return static_cast<float>(lhs) op rhs;                                     \
  }

PROVIDE_MIXED_INTEGRAL_BFLOAT16T_COMPARISON_OPERATOR(==, int);
PROVIDE_MIXED_INTEGRAL_BFLOAT16T_COMPARISON_OPERATOR(!=, int);
|
||
// function is sqrt
// Square root of a bfloat16_t; uses the hsqrt() device intrinsic from
// cuda_bf16.h when native bf16 math is available, otherwise computes
// std::sqrt in float and converts the result back to bfloat16_t.
GT_INLINE bfloat16_t sqrt(const bfloat16_t& x)
{
#if defined(BFLOAT16T_ON_CUDA_DEVICE)
  return hsqrt(x.Get());
#else
  return std::sqrt(x.Get());
#endif
}
|
||
std::ostream& operator<<(std::ostream& s, const bfloat16_t& h) | ||
{ | ||
s << static_cast<float>(h.Get()); | ||
return s; | ||
} | ||
|
||
} // namespace gt

// Macro hygiene: the detection flag and the PROVIDE_* helper macros are
// implementation details of this header, so undefine them to avoid leaking
// into including translation units.
#undef GTENSOR_BF16_CUDA_HEADER
#undef BFLOAT16T_ON_CUDA_DEVICE
#undef PROVIDE_BFLOAT16T_UNARY_ARITHMETIC_OPERATOR
#undef PROVIDE_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR
#undef PROVIDE_MIXED_BFLOAT16T_BINARY_ARITHMETIC_OPERATOR
#undef PROVIDE_BFLOAT16T_COMPARISON_OPERATOR
#undef PROVIDE_MIXED_BFLOAT16T_COMPARISON_OPERATOR
#undef PROVIDE_MIXED_INTEGRAL_BFLOAT16T_COMPARISON_OPERATOR

#endif // GTENSOR_BFLOAT16T_H
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We generally use an underscore suffix for private member variables, so `x_` here would be appreciated: it serves as a hint that something is a member variable when it is used within class member functions, and it arguably makes `x(x)` in the constructor a bit less confusing.