Fix LayerNorm op on ROCm (#36)
* fix warp size in WARP_SHFL* in layernorm

* enable fused_layer_norm tests on ROCm
ashishfarmer authored Nov 4, 2020
1 parent e9c43d6 commit 7eed38a
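
The root cause: the WARP_SHFL / WARP_SHFL_XOR helpers default their width argument to warpSize, which is 32 on NVIDIA GPUs but 64 on AMD GPUs under ROCm/HIP. The layer-norm kernels index shuffle lanes with (threadIdx.x+(1<<l))&31 and use 32-lane reduction patterns, so on ROCm the 64-lane default let values leak across what the code treats as separate 32-lane groups. Passing an explicit width of 32 at every call site keeps the shuffles confined to 32 lanes on both back ends. Below is a minimal sketch of such a helper, assuming a PyTorch-style wrapper with a defaulted width; the actual definition used by these kernels lives elsewhere in the sources and may differ.

    // Minimal sketch, not the repository's actual helper: a shuffle wrapper
    // whose width defaults to warpSize. On CUDA warpSize is 32; on ROCm/HIP
    // it is 64, so omitting the width spans a full 64-lane wavefront.
    template <typename T>
    __device__ __forceinline__ T WARP_SHFL(T value, int srcLane,
                                           int width = warpSize,
                                           unsigned int mask = 0xffffffffu)
    {
    #ifndef __HIP_PLATFORM_HCC__
        return __shfl_sync(mask, value, srcLane, width);  // CUDA: masked shuffle, width 32 by default
    #else
        return __shfl(value, srcLane, width);             // HIP: unmasked shuffle, width 64 by default
    #endif
    }

With the explicit third argument, WARP_SHFL(x, lane, 32) behaves identically on both back ends, which is what the call sites in the diff below now request.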
Showing 2 changed files with 12 additions and 13 deletions.
csrc/layer_norm_cuda_kernel.cu: 12 additions & 12 deletions
@@ -88,9 +88,9 @@ void cuWelfordMuSigma2(
       // intra-warp reductions
       for (int l = 0; l <= 4; ++l) {
         int srcLaneB = (threadIdx.x+(1<<l))&31;
-        U muB = WARP_SHFL(mu, srcLaneB);
-        U countB = WARP_SHFL(count, srcLaneB);
-        U sigma2B = WARP_SHFL(sigma2, srcLaneB);
+        U muB = WARP_SHFL(mu, srcLaneB, 32);
+        U countB = WARP_SHFL(count, srcLaneB, 32);
+        U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
         cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
       }
       // threadIdx.x == 0 has correct values for each warp
@@ -126,8 +126,8 @@ void cuWelfordMuSigma2(
         sigma2 = ubuf[1]/U(n2);
         // don't care about final value of count, we know count == n2
       } else {
-        mu = WARP_SHFL(mu, 0);
-        sigma2 = WARP_SHFL(sigma2/U(n2), 0);
+        mu = WARP_SHFL(mu, 0, 32);
+        sigma2 = WARP_SHFL(sigma2/U(n2), 0, 32);
       }
     }
   }
@@ -183,9 +183,9 @@ void cuWelfordMuSigma2(
       // intra-warp reductions
       for (int l = 0; l <= 4; ++l) {
         int srcLaneB = (threadIdx.x+(1<<l))&31;
-        float muB = WARP_SHFL(mu, srcLaneB);
-        float countB = WARP_SHFL(count, srcLaneB);
-        float sigma2B = WARP_SHFL(sigma2, srcLaneB);
+        float muB = WARP_SHFL(mu, srcLaneB, 32);
+        float countB = WARP_SHFL(count, srcLaneB, 32);
+        float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
         cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
       }
       // threadIdx.x == 0 has correct values for each warp
@@ -221,8 +221,8 @@ void cuWelfordMuSigma2(
         sigma2 = ubuf[1]/float(n2);
         // don't care about final value of count, we know count == n2
       } else {
-        mu = WARP_SHFL(mu, 0);
-        sigma2 = WARP_SHFL(sigma2/float(n2), 0);
+        mu = WARP_SHFL(mu, 0, 32);
+        sigma2 = WARP_SHFL(sigma2/float(n2), 0, 32);
       }
     }
   }
@@ -581,8 +581,8 @@ void cuComputeGradInput(
     }
     // intra-warp reductions
     for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
-      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
-      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
+      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32);
+      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32);
     }
     // inter-warp reductions
     if (blockDim.y > 1) {
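
The same default-width issue applies to the butterfly reduction in cuComputeGradInput shown in the last hunk. As a hypothetical illustration (not part of the commit), the kernel below runs the same 32-lane XOR-shuffle reduction in isolation, assuming blockDim.x == 32 as in the layer-norm kernels; with the explicit width of 32, each group of 32 threads sums only its own values on both CUDA and ROCm.

    // Assumed helper (CUDA form shown); the repository defines its own WARP_SHFL_XOR.
    #define WARP_SHFL_XOR(var, laneMask, width) __shfl_xor_sync(0xffffffffu, (var), (laneMask), (width))

    // Hypothetical standalone kernel for illustration only.
    // Each block of 32 threads reduces 32 floats with the same XOR-shuffle
    // pattern used in cuComputeGradInput; the explicit width of 32 keeps the
    // shuffle inside a 32-lane group even on a 64-lane ROCm wavefront.
    __global__ void warp_sum_sketch(const float* in, float* out)
    {
        float v = in[blockIdx.x * blockDim.x + threadIdx.x];
        for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
            v += WARP_SHFL_XOR(v, mask, 32);   // explicit width, as in the fix above
        }
        if (threadIdx.x == 0) {
            out[blockIdx.x] = v;               // every lane now holds the 32-element sum
        }
    }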
tests/L0/run_test.py: 0 additions & 1 deletion
@@ -6,7 +6,6 @@
 test_dirs = ["run_amp", "run_fp16util", "run_optimizers", "run_fused_layer_norm", "run_pyprof_nvtx", "run_pyprof_data", "run_mlp"]
 
 ROCM_BLACKLIST = [
-    'run_fused_layer_norm',
     'run_pyprof_nvtx',
     'run_pyprof_data',
 ]
