Support half precision sigmoid activation #378

Merged 5 commits on Dec 22, 2021
34 changes: 25 additions & 9 deletions include/cutlass/epilogue/thread/activation.h
@@ -98,15 +98,7 @@ template <typename T>
 struct Sigmoid {
   CUTLASS_HOST_DEVICE
   T operator()(T const &scalar) const {
-    return T(1) / (T(1) + exp(-scalar));
-  }
-};
-
-template <>
-struct Sigmoid<float> {
-  CUTLASS_HOST_DEVICE
-  float operator()(float const &scalar) const {
-    return 1.0f / (1.0f + expf(-scalar));
+    return T(1) / (T(1) + fast_exp(-scalar));
   }
 };

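The dedicated Sigmoid<float> specialization can be deleted rather than updated because fast_exp is an overload set that dispatches on the scalar type, so the generic Sigmoid<T> does the right thing for every element type. A declaration-only sketch of that overload set as it stands after this PR (the comments are our summary, not source text):

// Overload set in include/cutlass/fast_math.h; Sigmoid<T> picks the
// right one by ordinary overload resolution:
float  fast_exp(float x);   // ::expf on device, std::exp on host
double fast_exp(double x);  // ::expf on device after this PR, std::exp on host
float  fast_exp(half_t x);  // ::hexp on SM75+, float fallback otherwise (added here)
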
@@ -126,6 +118,30 @@ struct Sigmoid<Array<T, N> > {
   }
 };
 
+template <int N>
+struct Sigmoid<Array<half_t, N>> {
+  using T = half_t;
+
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const& z) const {
+    plus<Array<T, N>> add;
+
+#if defined(CUTLASS_USE_TANH_FOR_SIGMOID)
+    multiplies<Array<T, N>> mul;
+    fast_tanh_op<Array<T, N>> tanh;
+    return mul(add(tanh(mul(z, cutlass::constants::half<T>())), cutlass::constants::one<T>()),
+               cutlass::constants::half<T>());
+#else
+    divides<Array<T, N>> div;
+    negate<Array<T, N>> neg;
+    fast_exp_op<Array<T, N>> fast_exp;
+    return div(cutlass::constants::one<T>(),
+               add(cutlass::constants::one<T>(),
+                   fast_exp(neg(z))));
+#endif
+  }
+};
+
 // SiLu (swish) operator introduced by Elfwing et al. in the following paper
 // "Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning" (2017)
 // https://arxiv.org/pdf/1702.03118.pdf
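For reference, both branches of the new Array<half_t, N> specialization compute the same function. The CUTLASS_USE_TANH_FOR_SIGMOID path rewrites the sigmoid through a tanh identity that is exact in real arithmetic (only fast_tanh_op introduces approximation):

\sigma(z) \;=\; \frac{1}{1 + e^{-z}}
         \;=\; \frac{e^{z}}{e^{z} + 1}
         \;=\; \tfrac{1}{2}\left(1 + \tanh\tfrac{z}{2}\right),
\qquad\text{since}\qquad
\tanh\tfrac{z}{2} \;=\; \frac{e^{z} - 1}{e^{z} + 1}.
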
66 changes: 65 additions & 1 deletion include/cutlass/fast_math.h
@@ -705,12 +705,21 @@ float fast_exp(float x) {
 CUTLASS_HOST_DEVICE
 double fast_exp(double x) {
 #if defined(__CUDA_ARCH__)
-  return ::exp(x);
+  return ::expf(x);
 #else
   return std::exp(x);
 #endif
 }
 
+CUTLASS_HOST_DEVICE
+float fast_exp(half_t x) {
+#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDA_ARCH__ >= 750)
+  return ::hexp(x.to_half());
+#else
+  return fast_exp(float(x));
+#endif
+}
+
 CUTLASS_HOST_DEVICE
 float fast_log(float x) {
 #if defined(__CUDA_ARCH__)
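A usage sketch of the new scalar overload (hypothetical kernel, illustrative only; assumes a CUTLASS include path and an sm_75+ build for the ::hexp path):

#include <cutlass/half.h>
#include <cutlass/fast_math.h>

// Hypothetical element-wise kernel: on __CUDA_ARCH__ >= 750 the call lowers
// to the native half-precision ::hexp; elsewhere it falls back through float.
__global__ void exp_kernel(cutlass::half_t const *in, cutlass::half_t *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = cutlass::half_t(cutlass::fast_exp(in[i]));  // fast_exp(half_t) returns float
  }
}
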
@@ -767,6 +776,61 @@ half_t fast_tanh(half_t x) {
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
+template <typename T>
+struct fast_exp_op {
+  CUTLASS_HOST_DEVICE
+  T operator()(T const &rhs) const {
+    return fast_exp(rhs);
+  }
+};
+
+#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDA_ARCH__ >= 750)
+template <int N>
+struct fast_exp_op<Array<half_t, N>> {
+  CUTLASS_DEVICE
+  Array<half_t, N> operator()(Array<half_t, N> const &rhs) const {
+
+    Array<half_t, N> result;
+
+    // use x2 specialization
+    __half2 const *in = reinterpret_cast<__half2 const *>(&rhs);
+    __half2 *out = reinterpret_cast<__half2 *>(&result);
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N / 2; ++i) {
+      out[i] = ::h2exp(in[i]);
+    }
+
+    // residual
+    if (N % 2) {
+      half_t last = rhs[N - 1];
+      result[N - 1] = half_t(::hexp(last.to_half()));
+    }
+
+    return result;
+  }
+};
+#endif // #if defined(__CUDA_ARCH__)
+
+template <typename T, int N>
+struct fast_exp_op<Array<T, N>> {
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const &rhs) const {
+
+    fast_exp_op<T> fast_op;
+    Array<T, N> y;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+      y[i] = fast_op(rhs[i]);
+    }
+
+    return y;
+  }
+};
+
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 template <typename T>
 struct fast_tanh_op {
   CUTLASS_HOST_DEVICE
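Finally, a sketch of the functor form that the epilogue consumes (hypothetical helper function, illustrative only). For an Array<cutlass::half_t, 8> on SM75+, the specialization above issues four ::h2exp instructions, exponentiating two elements per instruction:

#include <cutlass/array.h>
#include <cutlass/fast_math.h>
#include <cutlass/half.h>

// Hypothetical device helper: exponentiate an 8-wide half fragment. The
// half_t specialization pairs adjacent elements into __half2 lanes rather
// than making eight scalar calls.
__device__ cutlass::Array<cutlass::half_t, 8>
exp_fragment(cutlass::Array<cutlass::half_t, 8> const &frag) {
  cutlass::fast_exp_op<cutlass::Array<cutlass::half_t, 8>> op;
  return op(frag);
}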