fix (#7978)
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
guo-ran and mergify[bot] authored Apr 7, 2022
1 parent 85964ca commit 58d91ed
Showing 2 changed files with 5 additions and 8 deletions.
8 changes: 4 additions & 4 deletions oneflow/core/cuda/layer_norm.cuh
@@ -42,7 +42,7 @@ struct MaxOp {
 template<template<typename> class ReductionOp, typename T, int thread_group_width = kWarpSize>
 __inline__ __device__ T WarpAllReduce(T val) {
   for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
-    val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
+    val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask, thread_group_width));
   }
   return val;
 }
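The fix adds the fourth `width` argument to `__shfl_xor_sync`, explicitly confining the butterfly reduction to each `thread_group_width`-lane segment of the warp instead of relying on the 32-lane default; the skip messages removed further down attribute the original failures to `__shfl_sync` behaving differently on sm_61. Below is a minimal, self-contained sketch of the same sub-warp pattern; the kernel and names (`GroupSumKernel`, `kGroupWidth`) are illustrative and not part of the commit.

// Illustrative sketch, not from the commit: a sum all-reduce over 8-lane
// sub-warp groups, mirroring how WarpAllReduce uses the width argument.
constexpr int kGroupWidth = 8;  // a thread_group_width smaller than the 32-lane warp

__global__ void GroupSumKernel(const int* in, int* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int val = i < n ? in[i] : 0;
  for (int mask = kGroupWidth / 2; mask > 0; mask /= 2) {
    // Butterfly step confined to the 8-lane segment containing this thread;
    // with the explicit width, lanes never exchange across segment boundaries.
    val += __shfl_xor_sync(0xffffffff, val, mask, kGroupWidth);
  }
  if (i < n) { out[i] = val; }  // every lane now holds its segment's sum
}

After the loop, each lane holds the sum of its 8-lane group, matching the WarpAllReduce contract that every thread in the group receives the reduced value.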
@@ -210,9 +210,9 @@ __inline__ __device__ void WelfordWarpReduce(T thread_mean, T thread_m2, T thread_count,
   *m2 = thread_m2;
   *count = thread_count;
   for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
-    T b_mean = __shfl_down_sync(0xffffffff, *mean, mask);
-    T b_m2 = __shfl_down_sync(0xffffffff, *m2, mask);
-    T b_count = __shfl_down_sync(0xffffffff, *count, mask);
+    T b_mean = __shfl_down_sync(0xffffffff, *mean, mask, thread_group_width);
+    T b_m2 = __shfl_down_sync(0xffffffff, *m2, mask, thread_group_width);
+    T b_count = __shfl_down_sync(0xffffffff, *count, mask, thread_group_width);
     WelfordCombine(b_mean, b_m2, b_count, mean, m2, count);
   }
 }
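For `__shfl_down_sync` the `width` argument matters even more directly: with the default width of 32, a lane near the end of a sub-warp group reads `mean`, `m2`, and `count` from the first lanes of the next group, silently mixing statistics across rows. Passing `thread_group_width` makes out-of-segment reads return the calling lane's own value instead. The loop feeds into the standard Chan et al. pairwise Welford merge; here is a hedged sketch of that combine step (illustrative only; the actual WelfordCombine in layer_norm.cuh may differ in detail).

// Sketch of the Chan et al. pairwise Welford merge the loop above relies on;
// the name WelfordCombineSketch is hypothetical.
template<typename T>
__inline__ __device__ void WelfordCombineSketch(T b_mean, T b_m2, T b_count,
                                                T* mean, T* m2, T* count) {
  if (b_count == 0) { return; }       // empty partition: nothing to merge
  T new_count = *count + b_count;
  T nb_over_n = b_count / new_count;  // weight of the incoming partition
  T delta = b_mean - *mean;
  *mean += delta * nb_over_n;                          // weighted mean update
  *m2 += b_m2 + delta * delta * (*count) * nb_over_n;  // Chan's M2 merge
  *count = new_count;
}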
5 changes: 1 addition & 4 deletions python/oneflow/test/expensive/test_layernorm.py
@@ -141,8 +141,7 @@ def test_layernorm(test_case):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])

-    @unittest.skip("TODO: guoran, different behavior of __shfl_sync in sm_61")
-    @autotest(n=20, auto_backward=True, rtol=1.0, atol=1.0)
+    @autotest(n=20, auto_backward=True, rtol=1e-3, atol=1e-3)
     def test_layernorm_with_random_data_warp(test_case):
         device = "cuda"
         channel = random(1, 32).to(int)
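The warp-path tolerances were presumably loosened to 1.0 while the sm_61 shuffle issue was unresolved; with the kernel fixed, the skip is dropped and rtol/atol return to 1e-3, putting this test on the same footing as the shared-memory and uncached variants below.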
@@ -161,7 +160,6 @@ def get_random_norm_shape():
         y = m(x)
         return y

-    @unittest.skip("TODO: guoran, different behavior of __shfl_sync in sm_61")
     @autotest(n=10, auto_backward=True, rtol=1e-3, atol=1e-3)
     def test_layernorm_with_random_data_shared_mem(test_case):
         device = "cuda"
@@ -181,7 +179,6 @@ def get_random_norm_shape():
         y = m(x)
         return y

-    @unittest.skip("TODO: guoran, different behavior of __shfl_sync in sm_61")
     @autotest(n=5, auto_backward=True, rtol=1e-3, atol=1e-3)
     def test_layernorm_with_random_data_uncached(test_case):
         device = "cuda"