Commit
Merge pull request #276 from cmpfeil/gt_half_cuda
Half precision type gt::float16_t
germasch authored Sep 13, 2023
2 parents 5a1e34f + 88f100d commit 20d2ce5
Showing 5 changed files with 431 additions and 0 deletions.
8 changes: 8 additions & 0 deletions CMakeLists.txt
@@ -51,6 +51,7 @@ option(GTENSOR_ALLOCATOR_CACHING "Enable naive caching allocators" ON)
option(GTENSOR_BOUNDS_CHECK "Enable per access bounds checking" OFF)
option(GTENSOR_ADDRESS_CHECK "Enable address checking for device spans" OFF)
option(GTENSOR_SYNC_KERNELS "Enable host sync after assign and launch kernels" OFF)
option(GTENSOR_ENABLE_FP16 "Enable 16-bit floating point type gt::float16_t" OFF)

if (GTENSOR_ENABLE_FORTRAN)
# do this early (here) since later the `enable_language(Fortran)` gives me trouble
@@ -335,6 +336,13 @@ else()
message(STATUS "${PROJECT_NAME}: sync kernels is OFF")
endif()

if (GTENSOR_ENABLE_FP16)
message(STATUS "${PROJECT_NAME}: gt::float16_t is ENABLED")
target_compile_definitions(gtensor_${GTENSOR_DEVICE}
INTERFACE GTENSOR_ENABLE_FP16)
endif()


target_compile_definitions(gtensor_${GTENSOR_DEVICE} INTERFACE
GTENSOR_MANAGED_MEMORY_TYPE_DEFAULT=${GTENSOR_MANAGED_MEMORY_TYPE_DEFAULT})
message(STATUS "${PROJECT_NAME}: default managed memory type '${GTENSOR_MANAGED_MEMORY_TYPE_DEFAULT}'")
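The new option is OFF by default and is enabled at configure time (e.g. cmake -DGTENSOR_ENABLE_FP16=ON), which propagates GTENSOR_ENABLE_FP16 to consumers as an INTERFACE compile definition. A minimal consumer-side sketch of guarding on that definition (assuming the usual <gtensor/gtensor.h> include; half_t is an illustrative alias, not part of the library):

#include <gtensor/gtensor.h>

#if defined(GTENSOR_ENABLE_FP16)
using half_t = gt::float16_t; // available when the option is ON
#else
using half_t = float;         // fall back to single precision otherwise
#endif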
136 changes: 136 additions & 0 deletions include/gtensor/float16_t.h
@@ -0,0 +1,136 @@
#ifndef GTENSOR_FLOAT16T_H
#define GTENSOR_FLOAT16T_H

#include <iostream>

#if __has_include(<cuda_fp16.h>)
#include <cuda_fp16.h>
#define GTENSOR_FP16_CUDA_HEADER
#elif 0 // TODO check if other fp16 type available, e.g., _Float16
#else
#error "GTENSOR_ENABLE_FP16=ON, but no 16-bit FP type available!"
#endif

namespace gt
{

// ======================================================================
// float16_t

#if defined(GTENSOR_FP16_CUDA_HEADER)
using storage_type = __half;
#else
#error "GTENSOR_ENABLE_FP16=ON, but no 16-bit FP type available!"
#endif

#if defined(GTENSOR_FP16_CUDA_HEADER) && defined(__CUDA_ARCH__) && \
(__CUDA_ARCH__ >= 530)
using compute_type = __half;
#else
using compute_type = float;
#endif
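
// Note: with this split, values are always stored as storage_type (__half),
// while arithmetic runs in compute_type: native __half on CUDA devices with
// compute capability >= 5.3, and plain float on the host and older devices.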

class float16_t
{
public:
float16_t() = default;
GT_INLINE float16_t(float x) : x(x){};
GT_INLINE float16_t(storage_type x) : x(x){};

GT_INLINE const float16_t& operator=(const float f)
{
x = f;
return *this;
}
GT_INLINE compute_type Get() const { return static_cast<compute_type>(x); }

private:
storage_type x;
};

#define PROVIDE_FLOAT16T_BINARY_ARITHMETIC_OPERATOR(op) \
GT_INLINE float16_t operator op(const float16_t& lhs, const float16_t& rhs) \
{ \
return float16_t(lhs.Get() op rhs.Get()); \
}

#define PROVIDE_MIXED_FLOAT16T_BINARY_ARITHMETIC_OPERATOR(op, fp_type) \
\
GT_INLINE fp_type operator op(const float16_t& lhs, const fp_type& rhs) \
{ \
return static_cast<fp_type>(lhs.Get()) op rhs; \
} \
\
GT_INLINE fp_type operator op(const fp_type& lhs, const float16_t& rhs) \
{ \
return lhs op static_cast<fp_type>(rhs.Get()); \
}

PROVIDE_FLOAT16T_BINARY_ARITHMETIC_OPERATOR(+);
PROVIDE_FLOAT16T_BINARY_ARITHMETIC_OPERATOR(-);
PROVIDE_FLOAT16T_BINARY_ARITHMETIC_OPERATOR(*);
PROVIDE_FLOAT16T_BINARY_ARITHMETIC_OPERATOR(/);

PROVIDE_MIXED_FLOAT16T_BINARY_ARITHMETIC_OPERATOR(+, float);
PROVIDE_MIXED_FLOAT16T_BINARY_ARITHMETIC_OPERATOR(-, float);
PROVIDE_MIXED_FLOAT16T_BINARY_ARITHMETIC_OPERATOR(*, float);
PROVIDE_MIXED_FLOAT16T_BINARY_ARITHMETIC_OPERATOR(/, float);

PROVIDE_MIXED_FLOAT16T_BINARY_ARITHMETIC_OPERATOR(+, double);
PROVIDE_MIXED_FLOAT16T_BINARY_ARITHMETIC_OPERATOR(-, double);
PROVIDE_MIXED_FLOAT16T_BINARY_ARITHMETIC_OPERATOR(*, double);
PROVIDE_MIXED_FLOAT16T_BINARY_ARITHMETIC_OPERATOR(/, double);
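
// Note: the mixed operators above return the wider type (float or double)
// rather than float16_t, so mixed-precision expressions keep full precision
// until the result is explicitly converted back to float16_t.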

#define PROVIDE_FLOAT16T_COMPARISON_OPERATOR(op) \
GT_INLINE bool operator op(const float16_t& lhs, const float16_t& rhs) \
{ \
return lhs.Get() op rhs.Get(); \
}

#define PROVIDE_MIXED_FLOAT16T_COMPARISON_OPERATOR(op, fp_type) \
\
GT_INLINE bool operator op(const float16_t& lhs, const fp_type& rhs) \
{ \
return static_cast<fp_type>(lhs.Get()) op rhs; \
} \
\
GT_INLINE bool operator op(const fp_type& lhs, const float16_t& rhs) \
{ \
return lhs op static_cast<fp_type>(rhs.Get()); \
}

PROVIDE_FLOAT16T_COMPARISON_OPERATOR(==);
PROVIDE_FLOAT16T_COMPARISON_OPERATOR(!=);
PROVIDE_FLOAT16T_COMPARISON_OPERATOR(<);
PROVIDE_FLOAT16T_COMPARISON_OPERATOR(<=);
PROVIDE_FLOAT16T_COMPARISON_OPERATOR(>);
PROVIDE_FLOAT16T_COMPARISON_OPERATOR(>=);

PROVIDE_MIXED_FLOAT16T_COMPARISON_OPERATOR(==, float);
PROVIDE_MIXED_FLOAT16T_COMPARISON_OPERATOR(!=, float);
PROVIDE_MIXED_FLOAT16T_COMPARISON_OPERATOR(<, float);
PROVIDE_MIXED_FLOAT16T_COMPARISON_OPERATOR(<=, float);
PROVIDE_MIXED_FLOAT16T_COMPARISON_OPERATOR(>, float);
PROVIDE_MIXED_FLOAT16T_COMPARISON_OPERATOR(>=, float);

PROVIDE_MIXED_FLOAT16T_COMPARISON_OPERATOR(==, double);
PROVIDE_MIXED_FLOAT16T_COMPARISON_OPERATOR(!=, double);
PROVIDE_MIXED_FLOAT16T_COMPARISON_OPERATOR(<, double);
PROVIDE_MIXED_FLOAT16T_COMPARISON_OPERATOR(<=, double);
PROVIDE_MIXED_FLOAT16T_COMPARISON_OPERATOR(>, double);
PROVIDE_MIXED_FLOAT16T_COMPARISON_OPERATOR(>=, double);

std::ostream& operator<<(std::ostream& s, const float16_t& h)
{
s << static_cast<float>(h.Get());
return s;
}

} // namespace gt

#undef PROVIDE_FLOAT16T_BINARY_ARITHMETIC_OPERATOR
#undef PROVIDE_MIXED_FLOAT16T_BINARY_ARITHMETIC_OPERATOR
#undef PROVIDE_FLOAT16T_COMPARISON_OPERATOR
#undef PROVIDE_MIXED_FLOAT16T_COMPARISON_OPERATOR

#endif
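
Taken together, this gives gt::float16_t value semantics backed by __half storage, with arithmetic and comparisons against itself, float, and double. A host-side usage sketch (assuming a build configured with GTENSOR_ENABLE_FP16=ON and the CUDA toolkit's cuda_fp16.h on the include path; values are chosen to be exactly representable in half precision):

#include <iostream>
#include <gtensor/gtensor.h>

int main()
{
  gt::float16_t a(1.5f);
  gt::float16_t b(0.25f);

  gt::float16_t c = a * b + gt::float16_t(2.0f); // half-half arithmetic, result 2.375
  double d = 3.0 * c;                            // mixed operator returns double: 7.125

  std::cout << c << " " << d << "\n";            // stream output via operator<<

  return (b < a && c > 2.0f) ? 0 : 1;            // half-half and mixed comparisons
}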
4 changes: 4 additions & 0 deletions include/gtensor/gtensor.h
@@ -19,6 +19,10 @@
#include "operator.h"
#include "space.h"

#if defined(GTENSOR_ENABLE_FP16)
#include "float16_t.h"
#endif

namespace gt
{

4 changes: 4 additions & 0 deletions tests/CMakeLists.txt
@@ -80,3 +80,7 @@ if (GTENSOR_ENABLE_FFT)
add_gtensor_test(test_fft)
target_link_libraries(test_fft gtfft)
endif()

if (GTENSOR_ENABLE_FP16)
add_gtensor_test(test_float16_t)
endif()
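
The test source registered here (test_float16_t) is the remaining changed file of the five and is not reproduced above. A minimal GoogleTest-style sketch of the kind of checks it could contain (illustrative only, not the actual test source):

#include <gtest/gtest.h>
#include <gtensor/gtensor.h>

TEST(float16_t, binary_arithmetic)
{
  gt::float16_t a(4.0f), b(2.0f);
  EXPECT_EQ(a + b, gt::float16_t(6.0f)); // half-half operators
  EXPECT_EQ(a / b, 2.0f);                // mixed comparison against float
}

TEST(float16_t, mixed_comparison)
{
  gt::float16_t h(1.0f);
  EXPECT_TRUE(h < 1.5);  // float16_t vs double
  EXPECT_TRUE(2.0f > h); // float vs float16_t
}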