Enable fusedlayernorm extension (#3)
lcskrishna authored May 7, 2020
1 parent 3ccdd63 commit 2d0f9cf
Showing 2 changed files with 17 additions and 5 deletions.
14 changes: 10 additions & 4 deletions csrc/layer_norm_cuda_kernel.cu
@@ -172,8 +172,8 @@ void cuWelfordMuSigma2(
       for (;  l+7 < n2;  l+=8*numx) {
         for (int k = 0;  k < 8;  k+=2) {
           float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
-          cuWelfordOnlineSum(curr.x,mu,sigma2,count);
-          cuWelfordOnlineSum(curr.y,mu,sigma2,count);
+          cuWelfordOnlineSum<float>(curr.x,mu,sigma2,count);
+          cuWelfordOnlineSum<float>(curr.y,mu,sigma2,count);
         }
       }
       for (;  l < n2;  ++l) {
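
The only functional change in this hunk is the explicit <float> template argument, which pins the accumulator type instead of leaving it to template-argument deduction when the __half2 lanes are unpacked. Below is a minimal, self-contained sketch of the pattern; the names (welfordOnlineSum, welfordDemo) are invented for illustration and the body assumes the usual Welford running-mean update, not the exact kernel code.

#include <cuda_fp16.h>

// Hypothetical stand-in for cuWelfordOnlineSum: standard Welford update of a
// running mean (mu) and an unnormalized variance accumulator (sigma2).
template<typename U>
__device__ void welfordOnlineSum(const U curr, U& mu, U& sigma2, U& count)
{
  count = count + U(1);
  U delta = curr - mu;
  mu = mu + delta / count;
  sigma2 = sigma2 + delta * (curr - mu);
}

__global__ void welfordDemo(const __half2* vals, float* out, int n)
{
  float mu = 0.f, sigma2 = 0.f, count = 0.f;
  for (int i = 0; i < n; ++i) {
    float2 curr = __half22float2(vals[i]);
    // Explicit <float> mirrors the commit: the accumulator type is stated
    // rather than deduced from the unpacked half2 lanes.
    welfordOnlineSum<float>(curr.x, mu, sigma2, count);
    welfordOnlineSum<float>(curr.y, mu, sigma2, count);
  }
  out[0] = mu;                                   // mean
  out[1] = count > 0.f ? sigma2 / count : 0.f;   // variance
}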
@@ -230,9 +230,15 @@ void cuWelfordMuSigma2(
 template<typename U> U rsqrt(U v) {
   return U(1) / sqrt(v);
 }
+#if defined __HIP_PLATFORM_HCC__
+__device__ float rsqrt(float v) {
+  return rsqrtf(v);
+}
+#else
 template<> float rsqrt(float v) {
   return rsqrtf(v);
 }
+#endif
 template<> double rsqrt(double v) {
   return rsqrt(v);
 }
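
In this hunk the ROCm branch replaces the explicit float specialization with a plain __device__ overload guarded by __HIP_PLATFORM_HCC__. The host-side sketch below, with an invented name (my_rsqrt), shows why that is behavior-preserving: a non-template overload whose parameter type matches exactly is preferred over the primary template during overload resolution, so float callers still reach the fast rsqrtf-style path while other types keep using the template.

#include <cmath>
#include <cstdio>

// Invented names for illustration; not part of the diff.
template<typename U> U my_rsqrt(U v) {
  return U(1) / std::sqrt(v);
}

// Plain overload, analogous to the __device__ float rsqrt(float) added under
// __HIP_PLATFORM_HCC__. Overload resolution prefers it for float arguments.
float my_rsqrt(float v) {
  return 1.0f / std::sqrt(v);
}

int main() {
  std::printf("%f\n", my_rsqrt(4.0f)); // float overload        -> 0.500000
  std::printf("%f\n", my_rsqrt(4.0));  // primary template (double) -> 0.500000
  return 0;
}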
@@ -293,7 +299,7 @@ void cuApplyLayerNorm(
   // 1) blockDim.x == warpSize
   // 2) Tensors are contiguous
   //
-  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
+  for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
     SharedMemory<U> shared;
     U* buf = shared.getPointer();
     U mu,sigma2;
@@ -531,7 +537,7 @@ void cuComputeGradInput(
     const T* gamma,
     T* grad_input)
 {
-  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
+  for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
     U sum_loss1 = U(0);
     U sum_loss2 = U(0);
     const U c_mean = mean[i1];
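
Both loop-header hunks (cuApplyLayerNorm and cuComputeGradInput) make the same change. blockIdx.y and gridDim.y are unsigned, so auto deduces an unsigned index and i1 < n1 becomes a mixed signed/unsigned comparison against the signed row count n1; declaring int i1 keeps the whole loop in signed arithmetic. A stripped-down illustration under that assumption, with an invented kernel name and an empty body:

__global__ void rowLoopDemo(int n1)
{
  // auto i1 = blockIdx.y;  would deduce unsigned int, so `i1 < n1`
  // would compare unsigned against the signed int n1.
  for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
    // one row of the layer-norm problem handled per iteration
  }
}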
8 changes: 7 additions & 1 deletion setup.py
@@ -177,7 +177,13 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
                                                   '-O3',
                                                   '--use_fast_math'] + version_dependent_macros}))
     else:
-        print ("INFO: Skipping FusedLayerNorm extension.")
+        print ("INFO: Building FusedLayerNorm extension.")
+        ext_modules.append(
+            CUDAExtension(name='fused_layer_norm_cuda',
+                          sources=['csrc/layer_norm_cuda.cpp',
+                                   'csrc/hip/layer_norm_hip_kernel.hip'],
+                          extra_compile_args={'cxx' : ['-O3'] + version_dependent_macros,
+                                              'nvcc' : []}))
 
     if not is_rocm_pytorch:
         ext_modules.append(
