
Commit d94bc0d
fix bug in randn caused by getting the wrong device id (#7896)
* fix bug in randn caused by getting the wrong device id

* revert cudaGetDevice in MakeDeviceKey

* create generator based on device index

* create generator using stream device id

* add more test cases

* refine

* fix wrong device id in random-series ops

* fix more

* add test case

* add test case

* add test with generator

* refine

* auto format by CI

* auto format by CI

* fix conflict

* fix conflict

* fix test case error

* fix test case error

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
3 people authored Apr 6, 2022
1 parent 7fe29cb commit d94bc0d
Showing 13 changed files with 81 additions and 17 deletions.
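The root cause: random kernels fetched their CUDA generator without passing the device index of the stream they were launched on, so an op placed on a non-default device could end up with generator state belonging to another device. A minimal repro sketch of the user-visible symptom this commit targets (hypothetical usage, assuming a machine with at least two CUDA devices):

import oneflow as flow

# Before this fix, random-series ops placed on a non-default device could
# resolve their generator against the wrong device id; with it, sampling
# uses that device's generator state and the result stays on the device.
x = flow.randn(2, 3, device="cuda:1")
print(x.device)  # expected: cuda:1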
4 changes: 3 additions & 1 deletion oneflow/user/kernels/distributions/normal_distribution.cu
@@ -16,6 +16,7 @@ limitations under the License.
 
 #include "oneflow/user/kernels/distributions/normal_distribution.h"
 #include "oneflow/core/common/data_type.h"
+#include "oneflow/core/ep/include/device.h"
 #include "oneflow/core/ep/cuda/cuda_stream.h"
 
 namespace oneflow {
@@ -51,7 +52,8 @@ void NormalDistribution<DeviceType::kCUDA, T>::operator()(
     ep::Stream* stream, const int64_t elem_cnt, T* dptr,
     const std::shared_ptr<one::Generator>& generator) const {
   CHECK_GE(elem_cnt, 0);
-  auto gen = CHECK_JUST(generator->Get<one::CUDAGeneratorImpl>());
+  const auto device_index = stream->device()->device_index();
+  auto gen = CHECK_JUST(generator->Get<one::CUDAGeneratorImpl>(device_index));
   int32_t block_num = gen->max_block_num();
   int32_t thread_num = gen->max_thread_num();
   auto* curand_states = gen->curand_states();
4 changes: 3 additions & 1 deletion oneflow/user/kernels/distributions/uniform_distribution.cu
@@ -15,6 +15,7 @@ limitations under the License.
 */
 #include "oneflow/core/common/data_type.h"
 #include "oneflow/user/kernels/distributions/uniform_distribution.h"
+#include "oneflow/core/ep/include/device.h"
 #include "oneflow/core/ep/cuda/cuda_stream.h"
 
 namespace oneflow {
@@ -56,7 +57,8 @@ void UniformDistribution<DeviceType::kCUDA, T>::operator()(
     ep::Stream* stream, const int64_t elem_cnt, T* dptr,
     const std::shared_ptr<one::Generator>& generator) const {
   CHECK_GE(elem_cnt, 0);
-  auto gen = CHECK_JUST(generator->Get<one::CUDAGeneratorImpl>());
+  const auto device_index = stream->device()->device_index();
+  auto gen = CHECK_JUST(generator->Get<one::CUDAGeneratorImpl>(device_index));
   int32_t block_num = gen->max_block_num();
   int32_t thread_num = gen->max_thread_num();
   auto* curand_states = gen->curand_states();
4 changes: 3 additions & 1 deletion oneflow/user/kernels/distributions/uniform_int_distribution.cu
@@ -18,6 +18,7 @@ limitations under the License.
 #include "oneflow/core/framework/framework.h"
 #include "oneflow/core/framework/dtype.h"
 #include "oneflow/user/kernels/distributions/uniform_int_distribution.h"
+#include "oneflow/core/ep/include/device.h"
 #include "oneflow/core/ep/cuda/cuda_stream.h"
 
 namespace oneflow {
@@ -49,7 +50,8 @@ void UniformIntDistribution<DeviceType::kCUDA, T>::operator()(
     ep::Stream* stream, const int64_t elem_cnt, T* dptr,
     const std::shared_ptr<one::Generator>& generator) const {
   CHECK_GE(elem_cnt, 0);
-  auto gen = CHECK_JUST(generator->Get<one::CUDAGeneratorImpl>());
+  const auto device_index = stream->device()->device_index();
+  auto gen = CHECK_JUST(generator->Get<one::CUDAGeneratorImpl>(device_index));
   int32_t block_num = gen->max_block_num();
   int32_t thread_num = gen->max_thread_num();
   auto* curand_states = gen->curand_states();
11 changes: 7 additions & 4 deletions oneflow/user/kernels/dropout_kernel.cu
@@ -19,6 +19,7 @@ limitations under the License.
 #include "oneflow/core/cuda/atomic.cuh"
 #include "oneflow/user/kernels/dropout_kernel.h"
 #include "oneflow/core/kernel/cuda_graph_support.h"
+#include "oneflow/core/ep/include/device.h"
 #include "oneflow/core/ep/cuda/cuda_stream.h"
 #include "oneflow/core/device/cuda_pseudo_bfloat16.h"
 namespace oneflow {
@@ -421,8 +422,10 @@ class DropoutKernelGPU final : public user_op::OpKernel, public user_op::CudaGraphSupport {
     CHECK_NOTNULL(fused_dropout_kernel_state);
     const auto& generator = fused_dropout_kernel_state->generator();
     CHECK_NOTNULL(generator);
+    auto* stream = ctx->stream();
+    const auto device_index = stream->device()->device_index();
     std::shared_ptr<one::CUDAGeneratorImpl> cuda_generator =
-        CHECK_JUST(generator->Get<one::CUDAGeneratorImpl>());
+        CHECK_JUST(generator->Get<one::CUDAGeneratorImpl>(device_index));
     uint64_t seed = cuda_generator->current_seed();
     const float rate = ctx->Attr<float>("rate");
@@ -433,12 +436,12 @@
     if (ctx->has_input("_add_to_output", 0)) {
       const user_op::Tensor* addend = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0);
       DispatchTail<T, true>(
-          ctx->stream(), seed, cuda_gen_state, in->shape().elem_cnt(), rate, scale,
+          stream, seed, cuda_gen_state, in->shape().elem_cnt(), rate, scale,
           reinterpret_cast<const T*>(in->dptr()), reinterpret_cast<bool*>(mask->mut_dptr()),
           reinterpret_cast<const T*>(addend->dptr()), reinterpret_cast<T*>(out->mut_dptr()));
     } else {
-      DispatchTail<T, false>(ctx->stream(), seed, cuda_gen_state, in->shape().elem_cnt(), rate,
-                             scale, reinterpret_cast<const T*>(in->dptr()),
+      DispatchTail<T, false>(stream, seed, cuda_gen_state, in->shape().elem_cnt(), rate, scale,
+                             reinterpret_cast<const T*>(in->dptr()),
                              reinterpret_cast<bool*>(mask->mut_dptr()), nullptr,
                              reinterpret_cast<T*>(out->mut_dptr()));
     }
3 changes: 2 additions & 1 deletion oneflow/user/kernels/one_embedding_kernels.cu
@@ -23,6 +23,7 @@ limitations under the License.
 #include "oneflow/core/cuda/atomic.cuh"
 #include "oneflow/core/ep/include/primitive/copy_nd.h"
 #include "oneflow/core/ep/include/primitive/cast.h"
+#include "oneflow/core/ep/include/device.h"
 
 namespace oneflow {
 
@@ -248,7 +249,7 @@ void LookupAndInitMissing(ep::Stream* stream, EmbeddingKernelState<IDX>* embedding_state,
   const auto& generator = embedding_state->generator();
   CHECK_NOTNULL(generator);
   std::shared_ptr<one::CUDAGeneratorImpl> cuda_generator =
-      CHECK_JUST(generator->template Get<one::CUDAGeneratorImpl>());
+      CHECK_JUST(generator->template Get<one::CUDAGeneratorImpl>(stream->device()->device_index()));
   uint64_t seed = cuda_generator->current_seed();
   one::CUDAGeneratorState* cuda_gen_state = cuda_generator->cuda_gen_state();
   embedding::KeyValueStore* store = embedding_state->KeyValueStore();
10 changes: 6 additions & 4 deletions oneflow/user/kernels/random_mask_generator.h
@@ -33,8 +33,9 @@ template<>
 class RandomMaskGenerator<DeviceType::kCPU> final {
  public:
   OF_DISALLOW_COPY_AND_MOVE(RandomMaskGenerator);
-  RandomMaskGenerator(const std::shared_ptr<one::Generator>& generator) {
-    generator_ = CHECK_JUST(generator->Get<one::CPUGeneratorImpl>());
+  RandomMaskGenerator(const std::shared_ptr<one::Generator>& generator,
+                      const int device_index = -1) {
+    generator_ = CHECK_JUST(generator->Get<one::CPUGeneratorImpl>(device_index));
   }
   ~RandomMaskGenerator() = default;
 
@@ -49,8 +50,9 @@ template<>
 class RandomMaskGenerator<DeviceType::kCUDA> final {
  public:
   OF_DISALLOW_COPY_AND_MOVE(RandomMaskGenerator);
-  RandomMaskGenerator(const std::shared_ptr<one::Generator>& generator) {
-    generator_ = CHECK_JUST(generator->Get<one::CUDAGeneratorImpl>());
+  RandomMaskGenerator(const std::shared_ptr<one::Generator>& generator,
+                      const int device_index = -1) {
+    generator_ = CHECK_JUST(generator->Get<one::CUDAGeneratorImpl>(device_index));
   }
   ~RandomMaskGenerator() = default;
 
8 changes: 6 additions & 2 deletions oneflow/user/kernels/random_mask_like_kernel.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include "oneflow/user/kernels/random_mask_generator.h"
 #include "oneflow/core/framework/framework.h"
 #include "oneflow/core/kernel/cuda_graph_support.h"
+#include "oneflow/core/ep/include/device.h"
 
 namespace oneflow {
 
@@ -59,8 +60,11 @@ class RandomMaskLikeKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
     CHECK_NOTNULL(random_mask_like_state);
     const auto& generator = random_mask_like_state->generator();
     CHECK_NOTNULL(generator);
-    auto random_mask_like_gen = std::make_shared<RandomMaskGenerator<device_type>>(generator);
-    random_mask_like_gen->Generate(ctx->stream(), elem_cnt, ctx->Attr<float>("rate"), mask);
+    auto* stream = ctx->stream();
+    const auto device_index = stream->device()->device_index();
+    auto random_mask_like_gen =
+        std::make_shared<RandomMaskGenerator<device_type>>(generator, device_index);
+    random_mask_like_gen->Generate(stream, elem_cnt, ctx->Attr<float>("rate"), mask);
   }
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
 };
8 changes: 5 additions & 3 deletions oneflow/user/kernels/randperm_kernel.cu
@@ -24,6 +24,7 @@ limitations under the License.
 #include "oneflow/user/kernels/arange_kernel_util.h"
 #include "oneflow/user/kernels/radix_sort.cuh"
 #include "oneflow/user/kernels/distributions/common.h"
+#include "oneflow/core/ep/include/device.h"
 #include "oneflow/core/ep/cuda/cuda_stream.h"
 
 namespace oneflow {
@@ -78,7 +79,9 @@ class GpuRandPermKernel final : public user_op::OpKernel {
     auto* distribution_state = dynamic_cast<DistributionKernelState*>(state);
     CHECK_NOTNULL(distribution_state);
     const auto& generator = distribution_state->generator();
-    const auto& gpu_generator = CHECK_JUST(generator->Get<one::CUDAGeneratorImpl>());
+    auto* stream = ctx->stream();
+    const auto device_index = stream->device()->device_index();
+    const auto& gpu_generator = CHECK_JUST(generator->Get<one::CUDAGeneratorImpl>(device_index));
     CHECK_NOTNULL(generator);
 
     int32_t block_num = gpu_generator->max_block_num();
@@ -99,8 +102,7 @@ class GpuRandPermKernel final : public user_op::OpKernel {
         reinterpret_cast<void*>(reinterpret_cast<char*>(value_base) + indices_aligned_bytes);
     size_t temp_storage_bytes = GetCubSortPairsTempStorageSize<int32_t>(n);
 
-    GeneKeysAndValues<<<block_num, thread_num, 0,
-                        ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(
+    GeneKeysAndValues<<<block_num, thread_num, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
         n, value_base, key_base, curand_states);
     auto err = cub::DeviceRadixSort::SortPairs(
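For intuition: the kernel above builds a random permutation by drawing one random key per index (GeneKeysAndValues, using the per-device curand states) and then sorting the indices by key with cub::DeviceRadixSort::SortPairs. A pure-Python sketch of the same key-sort idea (illustrative only, not the CUDA implementation):

import random

def randperm_by_key_sort(n):
    # Pair each index with a random key and sort by key; the indices read
    # off in key order form a uniformly random permutation.
    keys = [random.random() for _ in range(n)]
    return [index for _, index in sorted(zip(keys, range(n)))]

print(randperm_by_key_sort(5))  # e.g. [3, 0, 4, 1, 2]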
9 changes: 9 additions & 0 deletions python/oneflow/test/modules/test_dropout.py
@@ -329,5 +329,14 @@ def autotest_0dim_dropout_eval(test_case):
         return m(x)
 
 
+@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
+@flow.unittest.skip_unless_1n2d()
+class TestDropoutOnNonDefaultDevice(flow.unittest.TestCase):
+    def test_non_default_device(test_case):
+        x = flow.tensor([2, 3], dtype=flow.float, device="cuda:1")
+        y = flow._C.dropout(x)
+        test_case.assertEqual(y.device, flow.device("cuda:1"))
+
+
 if __name__ == "__main__":
     unittest.main()
8 changes: 8 additions & 0 deletions python/oneflow/test/modules/test_rand.py
@@ -114,5 +114,13 @@ def test_cases(test_case):
         arg[0](test_case, *arg[1:])
 
 
+@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
+@flow.unittest.skip_unless_1n2d()
+class TestRandOnNonDefaultDevice(flow.unittest.TestCase):
+    def test_non_default_device(test_case):
+        x = flow.rand(2, 3, device="cuda:1")
+        test_case.assertEqual(x.device, flow.device("cuda:1"))
+
+
 if __name__ == "__main__":
     unittest.main()
8 changes: 8 additions & 0 deletions python/oneflow/test/modules/test_randint.py
@@ -141,5 +141,13 @@ def test_0rank_randint(test_case):
         arg[0](test_case, *arg[1:])
 
 
+@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
+@flow.unittest.skip_unless_1n2d()
+class TestRandintOnNonDefaultDevice(flow.unittest.TestCase):
+    def test_non_default_device(test_case):
+        x = flow.randint(low=1, high=2, size=flow.Size((2, 3)), device="cuda:1")
+        test_case.assertEqual(x.device, flow.device("cuda:1"))
+
+
 if __name__ == "__main__":
     unittest.main()
13 changes: 13 additions & 0 deletions python/oneflow/test/modules/test_randn.py
@@ -125,5 +125,18 @@ def test_0d_randn(test_case):
         arg[0](test_case, *arg[1:])
 
 
+@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
+@flow.unittest.skip_unless_1n2d()
+class TestRandnOnNonDefaultDevice(flow.unittest.TestCase):
+    def test_non_default_device(test_case):
+        x = flow.randn(2, 3, device="cuda:1")
+        test_case.assertEqual(x.device, flow.device("cuda:1"))
+
+    def test_with_generator(test_case):
+        gen = flow.Generator("cuda")
+        x = flow.randn(2, 3, device="cuda", generator=gen)
+        test_case.assertEqual(x.device, flow.device(f"cuda:{flow.env.get_rank()}"))
+
+
 if __name__ == "__main__":
     unittest.main()
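Note on test_with_generator above: a generator created for "cuda" without an explicit index is expected to bind to the current rank's device, which is what the assertion checks. A short usage sketch of the seeded path this exercises (a hedged example; it assumes flow.Generator exposes manual_seed, mirroring its torch counterpart):

import oneflow as flow

gen = flow.Generator("cuda")
gen.manual_seed(0)
a = flow.randn(2, 3, device="cuda", generator=gen)
gen.manual_seed(0)  # reset the generator state to replay the sequence
b = flow.randn(2, 3, device="cuda", generator=gen)
# a and b should be elementwise equal, both on this rank's device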
8 changes: 8 additions & 0 deletions python/oneflow/test/modules/test_randperm.py
@@ -127,5 +127,13 @@ def test_auto_0(test_case):
         return y
 
 
+@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
+@flow.unittest.skip_unless_1n2d()
+class TestRandpermOnNonDefaultDevice(flow.unittest.TestCase):
+    def test_non_default_device(test_case):
+        x = flow.randperm(3, device="cuda:1")
+        test_case.assertEqual(x.device, flow.device("cuda:1"))
+
+
 if __name__ == "__main__":
     unittest.main()
