From a012cb847b750c3ca5577b4e5a243b489b1ab0e8 Mon Sep 17 00:00:00 2001 From: binbinHan Date: Sat, 9 Apr 2022 11:07:47 +0800 Subject: [PATCH] add_manual_seed_all_api (#7957) * add_manual_seed_all_api * Update conf.py * refine * add test case * auto format by CI * Update random_generator.cpp * auto format by CI Co-authored-by: oneflow-ci-bot Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- docs/source/cuda.rst | 2 + docs/source/oneflow.rst | 5 ++ .../api/python/framework/random_generator.cpp | 23 +++++++- oneflow/api/python/functional/common.cpp | 9 +++ oneflow/api/python/functional/common.h | 3 + oneflow/core/framework/random_generator.cpp | 39 +++++++++++- oneflow/core/framework/random_generator.h | 6 +- .../core/framework/random_generator_impl.cpp | 7 ++- python/oneflow/__init__.py | 2 + python/oneflow/cuda/__init__.py | 35 +++++++++++ python/oneflow/framework/docstr/constant.py | 2 +- python/oneflow/framework/docstr/math_ops.py | 6 +- python/oneflow/framework/generator.py | 56 ++++++++++++++++-- python/oneflow/nn/modules/empty.py | 2 +- .../oneflow/test/misc/test_manual_seed_api.py | 59 +++++++++++++++++++ 15 files changed, 240 insertions(+), 16 deletions(-) create mode 100644 python/oneflow/test/misc/test_manual_seed_api.py diff --git a/docs/source/cuda.rst b/docs/source/cuda.rst index c5aea292247..4c31404134e 100644 --- a/docs/source/cuda.rst +++ b/docs/source/cuda.rst @@ -7,6 +7,8 @@ ONEFLOW.CUDA :members: is_available, device_count, current_device, + manual_seed_all, + manual_seed, HalfTensor, FloatTensor, DoubleTensor, diff --git a/docs/source/oneflow.rst b/docs/source/oneflow.rst index b56973057ad..518f7092214 100644 --- a/docs/source/oneflow.rst +++ b/docs/source/oneflow.rst @@ -179,6 +179,11 @@ oneflow CharTensor, IntTensor, LongTensor, + seed, + manual_seed, + initial_seed, + get_rng_state, + set_rng_state, isnan, isinf diff --git a/oneflow/api/python/framework/random_generator.cpp b/oneflow/api/python/framework/random_generator.cpp index 5eb2c069153..6c622d89398 100644 --- a/oneflow/api/python/framework/random_generator.cpp +++ b/oneflow/api/python/framework/random_generator.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include "oneflow/api/python/functional/common.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/framework/random_generator.h" #include "oneflow/core/framework/tensor.h" @@ -34,14 +35,28 @@ ONEFLOW_API_PYBIND11_MODULE("", m) { .def(py::init([](const std::string& device_tag) { return CreateGenerator(device_tag).GetPtrOrThrow(); })) - .def("manual_seed", &one::Generator::set_current_seed) + .def("manual_seed", + [](const std::shared_ptr& generator, + const py::object& seed) -> Maybe { + int64_t seed_val = JUST(one::functional::PyUnpackLong(seed.ptr())); + generator->set_current_seed(seed_val); + return Maybe::Ok(); + }) .def("initial_seed", &one::Generator::current_seed) .def("seed", &one::Generator::seed) .def_property_readonly("device", &one::Generator::device) .def("get_state", &one::Generator::GetState) .def("set_state", &one::Generator::SetState); - m.def("manual_seed", [](uint64_t seed) { return one::ManualSeed(seed); }); + m.def("manual_seed", [](const py::object& seed) -> Maybe { + int64_t seed_val = JUST(one::functional::PyUnpackLong(seed.ptr())); + return one::ManualSeed(seed_val); + }); + m.def("manual_seed", + [](const py::object& seed, const std::string& device, int device_index) -> Maybe { + int64_t seed_val = JUST(one::functional::PyUnpackLong(seed.ptr())); + return one::ManualSeed(seed_val, device, device_index); + }); m.def("create_generator", &CreateGenerator); m.def("default_generator", [](const std::string& device_tag) -> Maybe { std::string device_name = ""; @@ -49,6 +64,10 @@ ONEFLOW_API_PYBIND11_MODULE("", m) { JUST(ParsingDeviceTag(device_tag, &device_name, &device_index)); return one::DefaultGenerator(device_name, device_index); }); + m.def("ManualSeedAllCudaGenerator", [](const py::object& seed) -> Maybe { + int64_t seed_val = JUST(one::functional::PyUnpackLong(seed.ptr())); + return one::ManualSeedAllCudaGenerator(seed_val); + }); } } // namespace oneflow diff --git a/oneflow/api/python/functional/common.cpp b/oneflow/api/python/functional/common.cpp index b96fd93a4d0..1a5b39088f9 100644 --- a/oneflow/api/python/functional/common.cpp +++ b/oneflow/api/python/functional/common.cpp @@ -296,6 +296,15 @@ Maybe PyUnpackOpExpr(PyObject* obj) { return py::cast>(handle); } +// int64_t +Maybe PyUnpackLong(PyObject* py_obj) { + int overflow = -1; + long long val = PyLong_AsLongLongAndOverflow(py_obj, &overflow); + if (val == -1 && PyErr_Occurred()) { return Error::RuntimeError() << "Python exception occurs"; } + if (overflow != 0) { return Error::RuntimeError() << "Overflow when unpacking long"; } + return (int64_t)val; +} + } // namespace functional } // namespace one } // namespace oneflow diff --git a/oneflow/api/python/functional/common.h b/oneflow/api/python/functional/common.h index 8caefc54743..8749f5ab228 100644 --- a/oneflow/api/python/functional/common.h +++ b/oneflow/api/python/functional/common.h @@ -169,6 +169,9 @@ Maybe PyUnpackTensorIndex(PyObject* obj); bool PyOpExprCheck(PyObject* obj); Maybe PyUnpackOpExpr(PyObject* obj); +// int64_t +Maybe PyUnpackLong(PyObject* py_obj); + } // namespace functional } // namespace one } // namespace oneflow diff --git a/oneflow/core/framework/random_generator.cpp b/oneflow/core/framework/random_generator.cpp index ab2d84e1d87..72099ff8248 100644 --- a/oneflow/core/framework/random_generator.cpp +++ b/oneflow/core/framework/random_generator.cpp @@ -25,11 +25,35 @@ limitations under the License. namespace oneflow { namespace one { -Maybe ManualSeed(uint64_t seed) { - JUST(DefaultAutoGenerator())->set_current_seed(seed); +Maybe ManualSeed(uint64_t seed) { + const auto& default_auto_generator = JUST(DefaultAutoGenerator()); + default_auto_generator->set_current_seed(seed); + return default_auto_generator; +} + +Maybe ManualSeed(uint64_t seed, const std::string& device, int device_index) { + if (device == "cpu") { + JUST(DefaultCPUGenerator())->set_current_seed(seed); + } +#ifdef WITH_CUDA + else if (device == "cuda") { + JUST(DefaultCUDAGenerator(device_index))->set_current_seed(seed); + } +#endif // WITH_CUDA + else if (device == "auto") { + JUST(DefaultAutoGenerator())->set_current_seed(seed); + } else { + return Error::RuntimeError() << "Invalid device " << device + << " for making generator, please make sure the device is one of " + << PrintAvailableDevices(); + } return Maybe::Ok(); } +Maybe ManualSeed(uint64_t seed, DeviceType device, int device_index) { + return ManualSeed(seed, *JUST(DeviceTag4DeviceType(device)), device_index); +} + namespace detail { uint64_t GetNonDeterministicRandom() { @@ -100,6 +124,17 @@ Maybe MakeCUDAGenerator(int device_index) { } #endif // WITH_CUDA +Maybe ManualSeedAllCudaGenerator(uint64_t seed) { +#ifdef WITH_CUDA + static int device_count = GetCudaDeviceCount(); + FOR_RANGE(int, device_id, 0, device_count) { + const auto& cuda_gen = JUST(DefaultCUDAGenerator(device_id)); + cuda_gen->set_current_seed(seed); + } +#endif // WITH_CUDA + return Maybe::Ok(); +} + Maybe MakeGenerator(const std::string& device, int device_index) { if (device == "cpu") { return MakeCPUGenerator(); diff --git a/oneflow/core/framework/random_generator.h b/oneflow/core/framework/random_generator.h index f1129157011..f71e1f38e47 100644 --- a/oneflow/core/framework/random_generator.h +++ b/oneflow/core/framework/random_generator.h @@ -66,7 +66,10 @@ class Generator final { std::shared_ptr impl_; }; -Maybe ManualSeed(uint64_t seed); +Maybe ManualSeed(uint64_t seed); + +Maybe ManualSeed(uint64_t seed, const std::string& device, int device_index = -1); +Maybe ManualSeed(uint64_t seed, DeviceType device, int device_index = -1); Maybe DefaultGenerator(const std::string& device, int device_index = -1); Maybe DefaultGenerator(DeviceType device, int device_index = -1); @@ -84,6 +87,7 @@ Maybe MakeCPUGenerator(); Maybe DefaultCUDAGenerator(int device_index = -1); Maybe MakeCUDAGenerator(); #endif // WITH_CUDA +Maybe ManualSeedAllCudaGenerator(uint64_t seed); } // namespace one } // namespace oneflow diff --git a/oneflow/core/framework/random_generator_impl.cpp b/oneflow/core/framework/random_generator_impl.cpp index 385b0aca90b..1e16e5e052d 100644 --- a/oneflow/core/framework/random_generator_impl.cpp +++ b/oneflow/core/framework/random_generator_impl.cpp @@ -16,6 +16,7 @@ limitations under the License. #include "oneflow/core/framework/random_generator_impl.h" #include "oneflow/core/common/util.h" +#include "oneflow/core/common/cpp_attribute.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/tensor_util.h" @@ -23,6 +24,7 @@ limitations under the License. #include "oneflow/core/vm/virtual_machine.h" #include "oneflow/core/register/ofblob.h" #include "oneflow/core/vm/vm_util.h" +#include "oneflow/core/platform/include/pthread_fork.h" #ifdef WITH_CUDA #include "oneflow/core/device/cuda_util.h" #include @@ -224,7 +226,10 @@ void AutoGeneratorImpl::set_current_seed(uint64_t seed) { CHECK_JUST(CPUSynchronize()); std::lock_guard lock(mutex_); seed_ = seed; - for (const auto& it : generators_) { it.second->set_current_seed(seed); } + for (const auto& it : generators_) { + if (unlikely(pthread_fork::IsForkedSubProcess() && it.first.device_type == kCUDA)) { continue; } + it.second->set_current_seed(seed); + } } struct AutoGeneratorState { diff --git a/python/oneflow/__init__.py b/python/oneflow/__init__.py index 5cd8bacc570..f9eb2a766f4 100755 --- a/python/oneflow/__init__.py +++ b/python/oneflow/__init__.py @@ -286,7 +286,9 @@ def atexit_hook(hook): from oneflow.framework.generator import create_generator as Generator from oneflow.framework.generator import ( default_generator, + seed, manual_seed, + initial_seed, get_rng_state, set_rng_state, ) diff --git a/python/oneflow/cuda/__init__.py b/python/oneflow/cuda/__init__.py index 02fff3a814a..69c32a113d5 100644 --- a/python/oneflow/cuda/__init__.py +++ b/python/oneflow/cuda/__init__.py @@ -33,3 +33,38 @@ def device_count() -> int: def current_device() -> int: r"""Returns local rank as device index.""" return flow._oneflow_internal.GetCudaDeviceIndex() + + +def manual_seed_all(seed) -> None: + r"""The documentation is referenced from: + https://pytorch.org/docs/stable/generated/torch.cuda.manual_seed_all.html?highlight=manual_seed_all + + Sets the seed for generating random numbers on all GPUs. + It's safe to call this function if CUDA is not available; in that + case, it is silently ignored. + + Args: + seed (int): The desired seed. + """ + seed = int(seed) + flow._oneflow_internal.ManualSeedAllCudaGenerator(seed) + + +def manual_seed(seed: int) -> None: + r"""The documentation is referenced from: + https://pytorch.org/docs/stable/generated/torch.cuda.manual_seed.html?highlight=manual_seed + + Sets the seed for generating random numbers for the current GPU. + It's safe to call this function if CUDA is not available; in that + case, it is silently ignored. + + Args: + seed (int): The desired seed. + + .. warning:: + If you are working with a multi-GPU model, this function is insufficient + to get determinism. To seed all GPUs, use :func:`manual_seed_all`. + """ + seed = int(seed) + idx = current_device() + flow._oneflow_internal.manual_seed(seed, "cuda", idx) diff --git a/python/oneflow/framework/docstr/constant.py b/python/oneflow/framework/docstr/constant.py index 9e81ab0c15b..6d84b58cc00 100644 --- a/python/oneflow/framework/docstr/constant.py +++ b/python/oneflow/framework/docstr/constant.py @@ -67,7 +67,7 @@ """ new_ones(x, size=None, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor - Returns a Tensor of size size filled with 1. By default, the returned Tensor has the same torch.dtype and torch.device as this tensor. + Returns a Tensor of size size filled with 1. By default, the returned Tensor has the same oneflow.dtype and oneflow.device as this tensor. Args: size (int...): a list, tuple, or flow.Size of integers defining the shape of the output tensor. diff --git a/python/oneflow/framework/docstr/math_ops.py b/python/oneflow/framework/docstr/math_ops.py index 2be3808110f..1958e6cdd0b 100644 --- a/python/oneflow/framework/docstr/math_ops.py +++ b/python/oneflow/framework/docstr/math_ops.py @@ -1460,9 +1460,9 @@ r""" Splits input, a tensor with one or more dimensions, into multiple tensors horizontally according to indices_or_sections. Each split is a view of input. - If input is one dimensional this is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=0) + If input is one dimensional this is equivalent to calling oneflow.tensor_split(input, indices_or_sections, dim=0) (the split dimension is zero), and if input has two or more dimensions it’s equivalent to calling - torch.tensor_split(input, indices_or_sections, dim=1) (the split dimension is 1), except that if indices_or_sections + oneflow.tensor_split(input, indices_or_sections, dim=1) (the split dimension is 1), except that if indices_or_sections is an integer it must evenly divide the split dimension or a runtime error will be thrown. The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.hsplit.html#torch.hsplit @@ -1503,7 +1503,7 @@ r""" Splits input, a tensor with two or more dimensions, into multiple tensors vertically according to indices_or_sections. Each split is a view of input. - This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=0) (the split dimension is 0), + This is equivalent to calling oneflow.tensor_split(input, indices_or_sections, dim=0) (the split dimension is 0), except that if indices_or_sections is an integer it must evenly divide the split dimension or a runtime error will be thrown. The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.vsplit.html#torch.vsplit diff --git a/python/oneflow/framework/generator.py b/python/oneflow/framework/generator.py index edf37756426..8dea64e71d6 100644 --- a/python/oneflow/framework/generator.py +++ b/python/oneflow/framework/generator.py @@ -23,8 +23,43 @@ def create_generator(device=None): return oneflow._oneflow_internal.create_generator(device) -def manual_seed(seed): +def seed() -> int: + r"""The documentation is referenced from: + https://pytorch.org/docs/stable/generated/torch.seed.html + + Sets the seed for generating random numbers to a non-deterministic + random number. Returns a 64 bit number used to seed the RNG. + """ + seed = default_generator.seed() oneflow._oneflow_internal.manual_seed(seed) + return seed + + +def manual_seed(seed): + r"""The documentation is referenced from: + https://pytorch.org/docs/stable/generated/torch.manual_seed.html + + Sets the seed for generating random numbers. Returns a + `oneflow.Generator` object. + + Args: + seed (int): The desired seed. Value must be within the inclusive range + `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError + is raised. Negative inputs are remapped to positive values with the formula + `0xffff_ffff_ffff_ffff + seed`. + """ + seed = int(seed) + return oneflow._oneflow_internal.manual_seed(seed) + + +def initial_seed() -> int: + r"""The documentation is referenced from: + https://pytorch.org/docs/stable/_modules/torch/random.html + + Returns the initial seed for generating random numbers as a + Python `long`. + """ + return default_generator.initial_seed() def _getstate(self): @@ -37,15 +72,26 @@ def _setstate(self, state_dict): def get_rng_state(): - """ - returns the state of the default random number generator + r"""The documentation is referenced from: + https://pytorch.org/docs/stable/generated/torch.get_rng_state.html + + Sets the random number generator state. + + .. note: This function only works for CPU. For CUDA, please use + oneflow.manual_seed(seed), which works for both CPU and CUDA. + + Args: + new_state (oneflow.ByteTensor): The desired state """ return oneflow.default_generator.get_state() def set_rng_state(state): - """ - sets the state of the default random number generator to the given state + """The documentation is referenced from: + https://pytorch.org/docs/stable/generated/torch.set_rng_state.html + + + Returns the random number generator state as a `oneflow.ByteTensor`. """ return oneflow.default_generator.set_state(state) diff --git a/python/oneflow/nn/modules/empty.py b/python/oneflow/nn/modules/empty.py index 39b3d9d2a70..51ae2145737 100644 --- a/python/oneflow/nn/modules/empty.py +++ b/python/oneflow/nn/modules/empty.py @@ -39,7 +39,7 @@ def empty_op( size (int... or oneflow.Size): Defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple or oneflow.Size. dtype (flow.dtype, optional): The desired data type of returned tensor. Default: ``flow.float32``. - device (torch.device, optional): The desired device of returned local tensor. If None, uses the + device (oneflow.device, optional): The desired device of returned local tensor. If None, uses the current device. placement (flow.placement, optional): The desired device of returned global tensor. If None, will construct local tensor. diff --git a/python/oneflow/test/misc/test_manual_seed_api.py b/python/oneflow/test/misc/test_manual_seed_api.py new file mode 100644 index 00000000000..642c68e9043 --- /dev/null +++ b/python/oneflow/test/misc/test_manual_seed_api.py @@ -0,0 +1,59 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +import unittest + +import numpy as np +import oneflow as flow + +import oneflow.unittest + + +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@flow.unittest.skip_unless_1n2d() +class TestManualSeedApi(flow.unittest.TestCase): + def test_cuda_manual_seed_all(test_case): + flow.cuda.manual_seed_all(20) + x = flow.randn(2, 4, device="cuda:0") + y = flow.randn(2, 4, device="cuda:1") + test_case.assertTrue(np.allclose(x.numpy(), y.numpy())) + + def test_cuda_manual_seed(test_case): + flow.cuda.manual_seed(30) + device = flow.device("cuda", flow.cuda.current_device()) + x = flow.randn(2, 4, device=device) + tensor_list = [flow.zeros((2, 4), dtype=flow.int32) for _ in range(2)] + flow.comm.all_gather(tensor_list, x) + test_case.assertTrue( + np.allclose(tensor_list[0].numpy(), tensor_list[1].numpy()) + ) + + def test_manual_seed(test_case): + flow.manual_seed(40) + x = flow.randn(2, 4, device="cuda:0") + y = flow.randn(2, 4, device="cuda:1") + test_case.assertTrue(np.allclose(x.numpy(), y.numpy())) + + def test_set_get_rng_state(test_case): + x = flow.ByteTensor(5000) + flow.set_rng_state(x) + y = flow.get_rng_state() + test_case.assertTrue(np.allclose(x.numpy(), y.numpy())) + + +if __name__ == "__main__": + unittest.main()