From d94bc0d879b933721c777feb96142ec44c0dfdec Mon Sep 17 00:00:00 2001 From: liufengwei0103 <2472937968@qq.com> Date: Wed, 6 Apr 2022 14:35:19 +0800 Subject: [PATCH] fix bug about randn because get wrong device id (#7896) * fix bug about randn because get wrong device id * revert cudaGetDevice in MakeDeviceKey * create generator based on device index * create generator using stream device id * add more test case * refine * fix device id wrong in random series op * fix more * add test case * add test case * add test with generator * refine * auto format by CI * auto format by CI * fix conflict * fix conflict * fix test case error * fix test case error Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> Co-authored-by: oneflow-ci-bot --- .../kernels/distributions/normal_distribution.cu | 4 +++- .../kernels/distributions/uniform_distribution.cu | 4 +++- .../distributions/uniform_int_distribution.cu | 4 +++- oneflow/user/kernels/dropout_kernel.cu | 11 +++++++---- oneflow/user/kernels/one_embedding_kernels.cu | 3 ++- oneflow/user/kernels/random_mask_generator.h | 10 ++++++---- oneflow/user/kernels/random_mask_like_kernel.h | 8 ++++++-- oneflow/user/kernels/randperm_kernel.cu | 8 +++++--- python/oneflow/test/modules/test_dropout.py | 9 +++++++++ python/oneflow/test/modules/test_rand.py | 8 ++++++++ python/oneflow/test/modules/test_randint.py | 8 ++++++++ python/oneflow/test/modules/test_randn.py | 13 +++++++++++++ python/oneflow/test/modules/test_randperm.py | 8 ++++++++ 13 files changed, 81 insertions(+), 17 deletions(-) diff --git a/oneflow/user/kernels/distributions/normal_distribution.cu b/oneflow/user/kernels/distributions/normal_distribution.cu index 6f2609b0228..b9909d32849 100644 --- a/oneflow/user/kernels/distributions/normal_distribution.cu +++ b/oneflow/user/kernels/distributions/normal_distribution.cu @@ -16,6 +16,7 @@ limitations under the License. #include "oneflow/user/kernels/distributions/normal_distribution.h" #include "oneflow/core/common/data_type.h" +#include "oneflow/core/ep/include/device.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { @@ -51,7 +52,8 @@ void NormalDistribution::operator()( ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const { CHECK_GE(elem_cnt, 0); - auto gen = CHECK_JUST(generator->Get()); + const auto device_index = stream->device()->device_index(); + auto gen = CHECK_JUST(generator->Get(device_index)); int32_t block_num = gen->max_block_num(); int32_t thread_num = gen->max_thread_num(); auto* curand_states = gen->curand_states(); diff --git a/oneflow/user/kernels/distributions/uniform_distribution.cu b/oneflow/user/kernels/distributions/uniform_distribution.cu index 353651d8da6..e2ed2059de8 100644 --- a/oneflow/user/kernels/distributions/uniform_distribution.cu +++ b/oneflow/user/kernels/distributions/uniform_distribution.cu @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/common/data_type.h" #include "oneflow/user/kernels/distributions/uniform_distribution.h" +#include "oneflow/core/ep/include/device.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { @@ -56,7 +57,8 @@ void UniformDistribution::operator()( ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const { CHECK_GE(elem_cnt, 0); - auto gen = CHECK_JUST(generator->Get()); + const auto device_index = stream->device()->device_index(); + auto gen = CHECK_JUST(generator->Get(device_index)); int32_t block_num = gen->max_block_num(); int32_t thread_num = gen->max_thread_num(); auto* curand_states = gen->curand_states(); diff --git a/oneflow/user/kernels/distributions/uniform_int_distribution.cu b/oneflow/user/kernels/distributions/uniform_int_distribution.cu index 4947183116f..d19940d4bc8 100644 --- a/oneflow/user/kernels/distributions/uniform_int_distribution.cu +++ b/oneflow/user/kernels/distributions/uniform_int_distribution.cu @@ -18,6 +18,7 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/user/kernels/distributions/uniform_int_distribution.h" +#include "oneflow/core/ep/include/device.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { @@ -49,7 +50,8 @@ void UniformIntDistribution::operator()( ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const { CHECK_GE(elem_cnt, 0); - auto gen = CHECK_JUST(generator->Get()); + const auto device_index = stream->device()->device_index(); + auto gen = CHECK_JUST(generator->Get(device_index)); int32_t block_num = gen->max_block_num(); int32_t thread_num = gen->max_thread_num(); auto* curand_states = gen->curand_states(); diff --git a/oneflow/user/kernels/dropout_kernel.cu b/oneflow/user/kernels/dropout_kernel.cu index da0d3661bc1..b1ac8e577e6 100644 --- a/oneflow/user/kernels/dropout_kernel.cu +++ b/oneflow/user/kernels/dropout_kernel.cu @@ -19,6 +19,7 @@ limitations under the License. #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/user/kernels/dropout_kernel.h" #include "oneflow/core/kernel/cuda_graph_support.h" +#include "oneflow/core/ep/include/device.h" #include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/device/cuda_pseudo_bfloat16.h" namespace oneflow { @@ -421,8 +422,10 @@ class DropoutKernelGPU final : public user_op::OpKernel, public user_op::CudaGra CHECK_NOTNULL(fused_dropout_kernel_state); const auto& generator = fused_dropout_kernel_state->generator(); CHECK_NOTNULL(generator); + auto* stream = ctx->stream(); + const auto device_index = stream->device()->device_index(); std::shared_ptr cuda_generator = - CHECK_JUST(generator->Get()); + CHECK_JUST(generator->Get(device_index)); uint64_t seed = cuda_generator->current_seed(); const float rate = ctx->Attr("rate"); @@ -433,12 +436,12 @@ class DropoutKernelGPU final : public user_op::OpKernel, public user_op::CudaGra if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* addend = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); DispatchTail( - ctx->stream(), seed, cuda_gen_state, in->shape().elem_cnt(), rate, scale, + stream, seed, cuda_gen_state, in->shape().elem_cnt(), rate, scale, reinterpret_cast(in->dptr()), reinterpret_cast(mask->mut_dptr()), reinterpret_cast(addend->dptr()), reinterpret_cast(out->mut_dptr())); } else { - DispatchTail(ctx->stream(), seed, cuda_gen_state, in->shape().elem_cnt(), rate, - scale, reinterpret_cast(in->dptr()), + DispatchTail(stream, seed, cuda_gen_state, in->shape().elem_cnt(), rate, scale, + reinterpret_cast(in->dptr()), reinterpret_cast(mask->mut_dptr()), nullptr, reinterpret_cast(out->mut_dptr())); } diff --git a/oneflow/user/kernels/one_embedding_kernels.cu b/oneflow/user/kernels/one_embedding_kernels.cu index ea8d7118d87..b8091e99c01 100644 --- a/oneflow/user/kernels/one_embedding_kernels.cu +++ b/oneflow/user/kernels/one_embedding_kernels.cu @@ -23,6 +23,7 @@ limitations under the License. #include "oneflow/core/cuda/atomic.cuh" #include "oneflow/core/ep/include/primitive/copy_nd.h" #include "oneflow/core/ep/include/primitive/cast.h" +#include "oneflow/core/ep/include/device.h" namespace oneflow { @@ -248,7 +249,7 @@ void LookupAndInitMissing(ep::Stream* stream, EmbeddingKernelState* embeddi const auto& generator = embedding_state->generator(); CHECK_NOTNULL(generator); std::shared_ptr cuda_generator = - CHECK_JUST(generator->template Get()); + CHECK_JUST(generator->template Get(stream->device()->device_index())); uint64_t seed = cuda_generator->current_seed(); one::CUDAGeneratorState* cuda_gen_state = cuda_generator->cuda_gen_state(); embedding::KeyValueStore* store = embedding_state->KeyValueStore(); diff --git a/oneflow/user/kernels/random_mask_generator.h b/oneflow/user/kernels/random_mask_generator.h index 391feab6468..01f65d8aaa6 100644 --- a/oneflow/user/kernels/random_mask_generator.h +++ b/oneflow/user/kernels/random_mask_generator.h @@ -33,8 +33,9 @@ template<> class RandomMaskGenerator final { public: OF_DISALLOW_COPY_AND_MOVE(RandomMaskGenerator); - RandomMaskGenerator(const std::shared_ptr& generator) { - generator_ = CHECK_JUST(generator->Get()); + RandomMaskGenerator(const std::shared_ptr& generator, + const int device_index = -1) { + generator_ = CHECK_JUST(generator->Get(device_index)); } ~RandomMaskGenerator() = default; @@ -49,8 +50,9 @@ template<> class RandomMaskGenerator final { public: OF_DISALLOW_COPY_AND_MOVE(RandomMaskGenerator); - RandomMaskGenerator(const std::shared_ptr& generator) { - generator_ = CHECK_JUST(generator->Get()); + RandomMaskGenerator(const std::shared_ptr& generator, + const int device_index = -1) { + generator_ = CHECK_JUST(generator->Get(device_index)); } ~RandomMaskGenerator() = default; diff --git a/oneflow/user/kernels/random_mask_like_kernel.h b/oneflow/user/kernels/random_mask_like_kernel.h index caeadbac5fb..352fa43f3ff 100644 --- a/oneflow/user/kernels/random_mask_like_kernel.h +++ b/oneflow/user/kernels/random_mask_like_kernel.h @@ -19,6 +19,7 @@ limitations under the License. #include "oneflow/user/kernels/random_mask_generator.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" +#include "oneflow/core/ep/include/device.h" namespace oneflow { @@ -59,8 +60,11 @@ class RandomMaskLikeKernel final : public user_op::OpKernel, public user_op::Cud CHECK_NOTNULL(random_mask_like_state); const auto& generator = random_mask_like_state->generator(); CHECK_NOTNULL(generator); - auto random_mask_like_gen = std::make_shared>(generator); - random_mask_like_gen->Generate(ctx->stream(), elem_cnt, ctx->Attr("rate"), mask); + auto* stream = ctx->stream(); + const auto device_index = stream->device()->device_index(); + auto random_mask_like_gen = + std::make_shared>(generator, device_index); + random_mask_like_gen->Generate(stream, elem_cnt, ctx->Attr("rate"), mask); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; diff --git a/oneflow/user/kernels/randperm_kernel.cu b/oneflow/user/kernels/randperm_kernel.cu index ab634ef62dd..e98e6f6d60b 100644 --- a/oneflow/user/kernels/randperm_kernel.cu +++ b/oneflow/user/kernels/randperm_kernel.cu @@ -24,6 +24,7 @@ limitations under the License. #include "oneflow/user/kernels/arange_kernel_util.h" #include "oneflow/user/kernels/radix_sort.cuh" #include "oneflow/user/kernels/distributions/common.h" +#include "oneflow/core/ep/include/device.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { @@ -78,7 +79,9 @@ class GpuRandPermKernel final : public user_op::OpKernel { auto* distribution_state = dynamic_cast(state); CHECK_NOTNULL(distribution_state); const auto& generator = distribution_state->generator(); - const auto& gpu_generator = CHECK_JUST(generator->Get()); + auto* stream = ctx->stream(); + const auto device_index = stream->device()->device_index(); + const auto& gpu_generator = CHECK_JUST(generator->Get(device_index)); CHECK_NOTNULL(generator); int32_t block_num = gpu_generator->max_block_num(); @@ -99,8 +102,7 @@ class GpuRandPermKernel final : public user_op::OpKernel { reinterpret_cast(reinterpret_cast(value_base) + indices_aligned_bytes); size_t temp_storage_bytes = GetCubSortPairsTempStorageSize(n); - GeneKeysAndValues<<stream()->As()->cuda_stream()>>>( + GeneKeysAndValues<<As()->cuda_stream()>>>( n, value_base, key_base, curand_states); auto err = cub::DeviceRadixSort::SortPairs( diff --git a/python/oneflow/test/modules/test_dropout.py b/python/oneflow/test/modules/test_dropout.py index 1658023a4dc..8badc4b42b4 100644 --- a/python/oneflow/test/modules/test_dropout.py +++ b/python/oneflow/test/modules/test_dropout.py @@ -329,5 +329,14 @@ def autotest_0dim_dropout_eval(test_case): return m(x) +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@flow.unittest.skip_unless_1n2d() +class TestDropoutOnNonDefaultDevice(flow.unittest.TestCase): + def test_non_default_device(test_case): + x = flow.tensor([2, 3], dtype=flow.float, device="cuda:1") + y = flow._C.dropout(x) + test_case.assertEqual(y.device, flow.device("cuda:1")) + + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_rand.py b/python/oneflow/test/modules/test_rand.py index 924509dd7f2..b822d5b3ec4 100644 --- a/python/oneflow/test/modules/test_rand.py +++ b/python/oneflow/test/modules/test_rand.py @@ -114,5 +114,13 @@ def test_cases(test_case): arg[0](test_case, *arg[1:]) +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@flow.unittest.skip_unless_1n2d() +class TestRandOnNonDefaultDevice(flow.unittest.TestCase): + def test_non_default_device(test_case): + x = flow.rand(2, 3, device="cuda:1") + test_case.assertEqual(x.device, flow.device("cuda:1")) + + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_randint.py b/python/oneflow/test/modules/test_randint.py index f6102dd3e08..e9c1ab4e6c7 100644 --- a/python/oneflow/test/modules/test_randint.py +++ b/python/oneflow/test/modules/test_randint.py @@ -141,5 +141,13 @@ def test_0rank_randint(test_case): arg[0](test_case, *arg[1:]) +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@flow.unittest.skip_unless_1n2d() +class TestRandintOnNonDefaultDevice(flow.unittest.TestCase): + def test_non_default_device(test_case): + x = flow.randint(low=1, high=2, size=flow.Size((2, 3)), device="cuda:1") + test_case.assertEqual(x.device, flow.device("cuda:1")) + + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_randn.py b/python/oneflow/test/modules/test_randn.py index bd63cb2b360..9c9967ae17a 100644 --- a/python/oneflow/test/modules/test_randn.py +++ b/python/oneflow/test/modules/test_randn.py @@ -125,5 +125,18 @@ def test_0d_randn(test_case): arg[0](test_case, *arg[1:]) +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@flow.unittest.skip_unless_1n2d() +class TestRandnOnNonDefaultDevice(flow.unittest.TestCase): + def test_non_default_device(test_case): + x = flow.randn(2, 3, device="cuda:1") + test_case.assertEqual(x.device, flow.device("cuda:1")) + + def test_with_generator(test_case): + gen = flow.Generator("cuda") + x = flow.randn(2, 3, device="cuda", generator=gen) + test_case.assertEqual(x.device, flow.device(f"cuda:{flow.env.get_rank()}")) + + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_randperm.py b/python/oneflow/test/modules/test_randperm.py index 644978ff375..7a60d1e4bd1 100644 --- a/python/oneflow/test/modules/test_randperm.py +++ b/python/oneflow/test/modules/test_randperm.py @@ -127,5 +127,13 @@ def test_auto_0(test_case): return y +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@flow.unittest.skip_unless_1n2d() +class TestRandpermOnNonDefaultDevice(flow.unittest.TestCase): + def test_non_default_device(test_case): + x = flow.randperm(3, device="cuda:1") + test_case.assertEqual(x.device, flow.device("cuda:1")) + + if __name__ == "__main__": unittest.main()