Skip to content

Commit

Permalink
add_manual_seed_all_api (#7957)
Browse files Browse the repository at this point in the history
* add_manual_seed_all_api

* Update conf.py

* refine

* add test case

* auto format by CI

* Update random_generator.cpp

* auto format by CI

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 9, 2022
1 parent bc0a9b3 commit a012cb8
Show file tree
Hide file tree
Showing 15 changed files with 240 additions and 16 deletions.
2 changes: 2 additions & 0 deletions docs/source/cuda.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ ONEFLOW.CUDA
:members: is_available,
device_count,
current_device,
manual_seed_all,
manual_seed,
HalfTensor,
FloatTensor,
DoubleTensor,
Expand Down
5 changes: 5 additions & 0 deletions docs/source/oneflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ oneflow
CharTensor,
IntTensor,
LongTensor,
seed,
manual_seed,
initial_seed,
get_rng_state,
set_rng_state,
isnan,
isinf

Expand Down
23 changes: 21 additions & 2 deletions oneflow/api/python/framework/random_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/functional/common.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/random_generator.h"
#include "oneflow/core/framework/tensor.h"
Expand All @@ -34,21 +35,39 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
.def(py::init([](const std::string& device_tag) {
return CreateGenerator(device_tag).GetPtrOrThrow();
}))
.def("manual_seed", &one::Generator::set_current_seed)
.def("manual_seed",
[](const std::shared_ptr<one::Generator>& generator,
const py::object& seed) -> Maybe<void> {
int64_t seed_val = JUST(one::functional::PyUnpackLong(seed.ptr()));
generator->set_current_seed(seed_val);
return Maybe<void>::Ok();
})
.def("initial_seed", &one::Generator::current_seed)
.def("seed", &one::Generator::seed)
.def_property_readonly("device", &one::Generator::device)
.def("get_state", &one::Generator::GetState)
.def("set_state", &one::Generator::SetState);

m.def("manual_seed", [](uint64_t seed) { return one::ManualSeed(seed); });
m.def("manual_seed", [](const py::object& seed) -> Maybe<one::Generator> {
int64_t seed_val = JUST(one::functional::PyUnpackLong(seed.ptr()));
return one::ManualSeed(seed_val);
});
m.def("manual_seed",
[](const py::object& seed, const std::string& device, int device_index) -> Maybe<void> {
int64_t seed_val = JUST(one::functional::PyUnpackLong(seed.ptr()));
return one::ManualSeed(seed_val, device, device_index);
});
m.def("create_generator", &CreateGenerator);
m.def("default_generator", [](const std::string& device_tag) -> Maybe<one::Generator> {
std::string device_name = "";
int device_index = -1;
JUST(ParsingDeviceTag(device_tag, &device_name, &device_index));
return one::DefaultGenerator(device_name, device_index);
});
m.def("ManualSeedAllCudaGenerator", [](const py::object& seed) -> Maybe<void> {
int64_t seed_val = JUST(one::functional::PyUnpackLong(seed.ptr()));
return one::ManualSeedAllCudaGenerator(seed_val);
});
}

} // namespace oneflow
9 changes: 9 additions & 0 deletions oneflow/api/python/functional/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,15 @@ Maybe<OpExpr> PyUnpackOpExpr(PyObject* obj) {
return py::cast<std::shared_ptr<OpExpr>>(handle);
}

// int64_t
// Converts a Python integer object into an int64_t.
// Yields a RuntimeError when the C-API conversion leaves a Python error
// pending, or when the value does not fit into a `long long`.
Maybe<int64_t> PyUnpackLong(PyObject* py_obj) {
  int overflow_flag = -1;
  const long long unpacked = PyLong_AsLongLongAndOverflow(py_obj, &overflow_flag);
  // -1 is also a legal result, so it only signals failure together with a pending error.
  const bool conversion_failed = (unpacked == -1) && (PyErr_Occurred() != nullptr);
  if (conversion_failed) { return Error::RuntimeError() << "Python exception occurs"; }
  if (overflow_flag != 0) { return Error::RuntimeError() << "Overflow when unpacking long"; }
  return static_cast<int64_t>(unpacked);
}

} // namespace functional
} // namespace one
} // namespace oneflow
3 changes: 3 additions & 0 deletions oneflow/api/python/functional/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ Maybe<TensorIndex> PyUnpackTensorIndex(PyObject* obj);
bool PyOpExprCheck(PyObject* obj);
Maybe<OpExpr> PyUnpackOpExpr(PyObject* obj);

// int64_t
Maybe<int64_t> PyUnpackLong(PyObject* py_obj);

} // namespace functional
} // namespace one
} // namespace oneflow
Expand Down
39 changes: 37 additions & 2 deletions oneflow/core/framework/random_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,35 @@ limitations under the License.
namespace oneflow {
namespace one {

Maybe<void> ManualSeed(uint64_t seed) {
JUST(DefaultAutoGenerator())->set_current_seed(seed);
// Sets `seed` on the generator obtained from DefaultAutoGenerator() and
// returns that same generator to the caller.
Maybe<Generator> ManualSeed(uint64_t seed) {
  const auto& auto_gen = JUST(DefaultAutoGenerator());
  auto_gen->set_current_seed(seed);
  return auto_gen;
}

// Sets `seed` on the default generator selected by the `device` tag.
//   "cpu"  -> DefaultCPUGenerator()
//   "cuda" -> DefaultCUDAGenerator(device_index)   (only when built WITH_CUDA)
//   "auto" -> DefaultAutoGenerator()
// Any other tag (including "cuda" in a build without WITH_CUDA) produces a
// RuntimeError listing the available devices. `device_index` is only consulted
// on the "cuda" path.
Maybe<void> ManualSeed(uint64_t seed, const std::string& device, int device_index) {
  if (device == "cpu") {
    JUST(DefaultCPUGenerator())->set_current_seed(seed);
    return Maybe<void>::Ok();
  }
#ifdef WITH_CUDA
  if (device == "cuda") {
    JUST(DefaultCUDAGenerator(device_index))->set_current_seed(seed);
    return Maybe<void>::Ok();
  }
#endif  // WITH_CUDA
  if (device == "auto") {
    JUST(DefaultAutoGenerator())->set_current_seed(seed);
    return Maybe<void>::Ok();
  }
  // Nothing matched: report the offending tag together with the valid choices.
  return Error::RuntimeError() << "Invalid device " << device
                               << " for making generator, please make sure the device is one of "
                               << PrintAvailableDevices();
}

// Overload taking a DeviceType: looks up the device tag through
// DeviceTag4DeviceType and forwards to the string-based ManualSeed overload
// with the same `seed` and `device_index`.
Maybe<void> ManualSeed(uint64_t seed, DeviceType device, int device_index) {
// The lookup result is dereferenced inside the call expression itself.
return ManualSeed(seed, *JUST(DeviceTag4DeviceType(device)), device_index);
}

namespace detail {

uint64_t GetNonDeterministicRandom() {
Expand Down Expand Up @@ -100,6 +124,17 @@ Maybe<Generator> MakeCUDAGenerator(int device_index) {
}
#endif // WITH_CUDA

// Applies `seed` to the default CUDA generator of every device index in
// [0, GetCudaDeviceCount()). In a build without WITH_CUDA the body is empty
// and the function simply reports success.
Maybe<void> ManualSeedAllCudaGenerator(uint64_t seed) {
#ifdef WITH_CUDA
  // Function-local static: the device count is queried on first use and reused afterwards.
  static int num_devices = GetCudaDeviceCount();
  for (int idx = 0; idx < num_devices; ++idx) {
    const auto& device_gen = JUST(DefaultCUDAGenerator(idx));
    device_gen->set_current_seed(seed);
  }
#endif  // WITH_CUDA
  return Maybe<void>::Ok();
}

Maybe<Generator> MakeGenerator(const std::string& device, int device_index) {
if (device == "cpu") {
return MakeCPUGenerator();
Expand Down
6 changes: 5 additions & 1 deletion oneflow/core/framework/random_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ class Generator final {
std::shared_ptr<GeneratorImpl> impl_;
};

Maybe<void> ManualSeed(uint64_t seed);
Maybe<Generator> ManualSeed(uint64_t seed);

Maybe<void> ManualSeed(uint64_t seed, const std::string& device, int device_index = -1);
Maybe<void> ManualSeed(uint64_t seed, DeviceType device, int device_index = -1);

Maybe<Generator> DefaultGenerator(const std::string& device, int device_index = -1);
Maybe<Generator> DefaultGenerator(DeviceType device, int device_index = -1);
Expand All @@ -84,6 +87,7 @@ Maybe<Generator> MakeCPUGenerator();
Maybe<Generator> DefaultCUDAGenerator(int device_index = -1);
Maybe<Generator> MakeCUDAGenerator();
#endif // WITH_CUDA
Maybe<void> ManualSeedAllCudaGenerator(uint64_t seed);

} // namespace one
} // namespace oneflow
Expand Down
7 changes: 6 additions & 1 deletion oneflow/core/framework/random_generator_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ limitations under the License.
#include "oneflow/core/framework/random_generator_impl.h"

#include "oneflow/core/common/util.h"
#include "oneflow/core/common/cpp_attribute.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/tensor_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/vm/virtual_machine.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/vm/vm_util.h"
#include "oneflow/core/platform/include/pthread_fork.h"
#ifdef WITH_CUDA
#include "oneflow/core/device/cuda_util.h"
#include <cuda.h>
Expand Down Expand Up @@ -224,7 +226,10 @@ void AutoGeneratorImpl::set_current_seed(uint64_t seed) {
CHECK_JUST(CPUSynchronize());
std::lock_guard<std::mutex> lock(mutex_);
seed_ = seed;
for (const auto& it : generators_) { it.second->set_current_seed(seed); }
for (const auto& it : generators_) {
if (unlikely(pthread_fork::IsForkedSubProcess() && it.first.device_type == kCUDA)) { continue; }
it.second->set_current_seed(seed);
}
}

struct AutoGeneratorState {
Expand Down
2 changes: 2 additions & 0 deletions python/oneflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,9 @@ def atexit_hook(hook):
from oneflow.framework.generator import create_generator as Generator
from oneflow.framework.generator import (
default_generator,
seed,
manual_seed,
initial_seed,
get_rng_state,
set_rng_state,
)
Expand Down
35 changes: 35 additions & 0 deletions python/oneflow/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,38 @@ def device_count() -> int:
def current_device() -> int:
r"""Returns local rank as device index."""
return flow._oneflow_internal.GetCudaDeviceIndex()


def manual_seed_all(seed) -> None:
    r"""Sets the seed for generating random numbers on all GPUs.

    The documentation is referenced from:
    https://pytorch.org/docs/stable/generated/torch.cuda.manual_seed_all.html?highlight=manual_seed_all

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed. It is converted with ``int()`` before
            being handed to the backend.
    """
    flow._oneflow_internal.ManualSeedAllCudaGenerator(int(seed))


def manual_seed(seed: int) -> None:
    r"""Sets the seed for generating random numbers for the current GPU.

    The documentation is referenced from:
    https://pytorch.org/docs/stable/generated/torch.cuda.manual_seed.html?highlight=manual_seed

    The device index is taken from :func:`current_device`.

    Args:
        seed (int): The desired seed. It is converted with ``int()`` before
            being handed to the backend.

    .. warning::
        If you are working with a multi-GPU model, this function is insufficient
        to get determinism. To seed all GPUs, use :func:`manual_seed_all`.

    NOTE(review): the upstream PyTorch text states that the call is silently
    ignored when CUDA is not available. Here the ``"cuda"`` device tag is
    forwarded to the backend unchanged, so confirm that a CPU-only build
    really ignores it rather than raising.
    """
    seed_value = int(seed)
    flow._oneflow_internal.manual_seed(seed_value, "cuda", current_device())
2 changes: 1 addition & 1 deletion python/oneflow/framework/docstr/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
"""
new_ones(x, size=None, dtype=None, device=None, placement=None, sbp=None, requires_grad=False) -> Tensor
Returns a Tensor of size size filled with 1. By default, the returned Tensor has the same torch.dtype and torch.device as this tensor.
Returns a Tensor of size size filled with 1. By default, the returned Tensor has the same oneflow.dtype and oneflow.device as this tensor.
Args:
size (int...): a list, tuple, or flow.Size of integers defining the shape of the output tensor.
Expand Down
6 changes: 3 additions & 3 deletions python/oneflow/framework/docstr/math_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1460,9 +1460,9 @@
r"""
Splits input, a tensor with one or more dimensions, into multiple tensors horizontally according to indices_or_sections.
Each split is a view of input.
If input is one dimensional this is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=0)
If input is one dimensional this is equivalent to calling oneflow.tensor_split(input, indices_or_sections, dim=0)
(the split dimension is zero), and if input has two or more dimensions it’s equivalent to calling
torch.tensor_split(input, indices_or_sections, dim=1) (the split dimension is 1), except that if indices_or_sections
oneflow.tensor_split(input, indices_or_sections, dim=1) (the split dimension is 1), except that if indices_or_sections
is an integer it must evenly divide the split dimension or a runtime error will be thrown.
The documentation is referenced from:
https://pytorch.org/docs/stable/generated/torch.hsplit.html#torch.hsplit
Expand Down Expand Up @@ -1503,7 +1503,7 @@
r"""
Splits input, a tensor with two or more dimensions, into multiple tensors vertically according to indices_or_sections.
Each split is a view of input.
This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=0) (the split dimension is 0),
This is equivalent to calling oneflow.tensor_split(input, indices_or_sections, dim=0) (the split dimension is 0),
except that if indices_or_sections is an integer it must evenly divide the split dimension or a runtime error will be thrown.
The documentation is referenced from:
https://pytorch.org/docs/stable/generated/torch.vsplit.html#torch.vsplit
Expand Down
56 changes: 51 additions & 5 deletions python/oneflow/framework/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,43 @@ def create_generator(device=None):
return oneflow._oneflow_internal.create_generator(device)


def manual_seed(seed):
def seed() -> int:
    r"""Sets the seed for generating random numbers to a non-deterministic
    random number. Returns a 64 bit number used to seed the RNG.

    The documentation is referenced from:
    https://pytorch.org/docs/stable/generated/torch.seed.html

    The value is obtained from ``default_generator.seed()`` and then applied
    through the backend ``manual_seed`` binding before being returned.
    """
    # A distinct local name avoids shadowing this function inside its own body.
    new_seed = default_generator.seed()
    oneflow._oneflow_internal.manual_seed(new_seed)
    return new_seed


def manual_seed(seed):
    r"""Sets the seed for generating random numbers.

    The documentation is referenced from:
    https://pytorch.org/docs/stable/generated/torch.manual_seed.html

    Args:
        seed (int): The desired seed. It is converted with ``int()`` before
            being handed to the backend.

    Returns:
        The object returned by the backend ``manual_seed`` binding, documented
        upstream as a `oneflow.Generator`.

    NOTE(review): the text inherited from PyTorch advertises the inclusive range
    `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]` and a remapping of negative
    inputs via `0xffff_ffff_ffff_ffff + seed`. The binding appears to unpack the
    seed as a signed 64-bit integer, so values above `0x7fff_ffff_ffff_ffff` may
    be rejected with a RuntimeError instead -- confirm before relying on the
    upper half of that range.
    """
    seed_value = int(seed)
    return oneflow._oneflow_internal.manual_seed(seed_value)


def initial_seed() -> int:
    r"""Returns the initial seed for generating random numbers as a
    Python `long`.

    The documentation is referenced from:
    https://pytorch.org/docs/stable/_modules/torch/random.html

    The value is read from ``default_generator.initial_seed()``.
    """
    current_seed = default_generator.initial_seed()
    return current_seed


def _getstate(self):
Expand All @@ -37,15 +72,26 @@ def _setstate(self, state_dict):


def get_rng_state():
    r"""Returns the random number generator state as a `oneflow.ByteTensor`.

    The documentation is referenced from:
    https://pytorch.org/docs/stable/generated/torch.get_rng_state.html

    The state is read from ``oneflow.default_generator``. This function takes
    no arguments.
    """
    return oneflow.default_generator.get_state()


def set_rng_state(state):
    r"""Sets the random number generator state.

    The documentation is referenced from:
    https://pytorch.org/docs/stable/generated/torch.set_rng_state.html

    .. note::
        This function only works for CPU. For CUDA, please use
        oneflow.manual_seed(seed), which works for both CPU and CUDA.

    Args:
        state (oneflow.ByteTensor): The desired state, applied to
            ``oneflow.default_generator``.
    """
    return oneflow.default_generator.set_state(state)
Expand Down
2 changes: 1 addition & 1 deletion python/oneflow/nn/modules/empty.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def empty_op(
size (int... or oneflow.Size): Defining the shape of the output tensor.
Can be a variable number of arguments or a collection like a list or tuple or oneflow.Size.
dtype (flow.dtype, optional): The desired data type of returned tensor. Default: ``flow.float32``.
device (torch.device, optional): The desired device of returned local tensor. If None, uses the
device (oneflow.device, optional): The desired device of returned local tensor. If None, uses the
current device.
placement (flow.placement, optional): The desired device of returned global tensor. If None, will
construct local tensor.
Expand Down
Loading

0 comments on commit a012cb8

Please sign in to comment.