Fix rand ops that have the same seed in graph and have different seeds in checkpointing (#9941)

fixed: Oneflow-Inc/OneTeam#1926

Fix bugs in the random op family:

- Some random op kernels called MakeGenerator directly in CreateOpKernelState with default_rng_seed_val, so the generators of all random ops shared the same seed and produced identical results. For example, the dropout in every transformer layer generated the same random mask, which is very likely to hurt final accuracy.
- Some random op kernels did not account for split placement in CreateOpKernelState: under split, different ranks should produce different results (while broadcast requires identical results across ranks). GetOpKernelRandomSeedInCurrentRank should be used to set the appropriate per-rank seed on each generator.
- With checkpoint activation enabled, corresponding forward and backward random ops (e.g. dropout) must share the same random state so that the same dropout produces consistent results in the forward and the recomputed backward pass. The forward and backward dropout therefore need the same seed, which is achieved by adding a seed attr to dropout (see the sketch after this list).
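
A condensed sketch of the functor-side pattern this change applies (it mirrors the DropoutFunctor diff further down; `DispatchDropoutSketch` is a made-up helper, not part of OneFlow): resolve the generator with GetGeneratorForLazyOrGlobal so lazy graphs get an op- and rank-aware seed, record that seed in the op's `seed` attr, and keep the generator in the kernel state for eager execution.

```cpp
// Hypothetical condensed sketch of the pattern used in this PR; DispatchDropoutSketch
// itself does not exist in OneFlow.
Maybe<Tensor> DispatchDropoutSketch(const std::shared_ptr<one::Tensor>& x, float p,
                                    const Optional<one::Generator>& generator,
                                    const std::shared_ptr<OpExpr>& dropout_op) {
  // Fall back to the default generator when the caller did not pass one.
  auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));
  // In lazy/global mode this returns a generator whose seed can differ per op
  // and per rank, as required by the graph placement.
  gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x));

  // Record the seed as an op attr so the graph (and checkpoint-activation
  // recomputation) sees a fixed seed for this dropout.
  auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("rate", "seed");
  attrs.SetAllAttrs(p, static_cast<int64_t>(gen->current_seed()));

  // The kernel state still carries the generator for eager execution.
  const auto& state = std::make_shared<FusedDropoutKernelState>(gen);
  std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(2);
  JUST(OpInterpUtil::Dispatch(*dropout_op, {x}, outputs.get(), OpExprInterpContext(attrs, state)));
  return (*outputs)[0];
}
```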

In eager mode all random ops use the same generator (some ops allow passing in a custom one). Lazy mode is different: every random op kernel creates its own generator, so random results can only be controlled through the seed. Different random ops within the same graph need different seeds, different ranks under split need different seeds, the forward and backward random ops under checkpoint activation need the same seed, and different ranks under broadcast need the same seed. The hypothetical sketch below illustrates this seed policy.
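
A hypothetical, self-contained illustration of that policy (this is not OneFlow's GetOpKernelRandomSeedInCurrentRank; `OpSeed`, `RankSeed`, and the hash mixing are invented for the example): each random op in a graph gets its own seed via a per-op offset, the rank is mixed in only under split so broadcast ranks stay in sync, and the backward dropout recomputed under checkpoint activation simply reuses the forward op's seed.

```cpp
// Hypothetical illustration only -- not OneFlow's implementation.
#include <cstdint>
#include <functional>
#include <iostream>

enum class Sbp { kSplit, kBroadcast };

// Per-op seed: every random op in the graph gets its own random stream.
uint64_t OpSeed(uint64_t graph_seed, uint64_t op_index) {
  return std::hash<uint64_t>{}(graph_seed ^ (op_index * 0x9E3779B97F4A7C15ULL));
}

// Split: each rank must draw different numbers; broadcast: all ranks must agree.
uint64_t RankSeed(uint64_t op_seed, Sbp sbp, uint64_t rank) {
  return sbp == Sbp::kSplit ? std::hash<uint64_t>{}(op_seed ^ rank) : op_seed;
}

int main() {
  const uint64_t graph_seed = 42;
  const uint64_t fwd_dropout_seed = OpSeed(graph_seed, /*op_index=*/3);
  // Checkpoint activation: the recomputed (backward) dropout reuses the forward
  // seed, so the regenerated mask matches the forward pass on every rank.
  const uint64_t bwd_dropout_seed = fwd_dropout_seed;

  for (uint64_t rank = 0; rank < 2; ++rank) {
    std::cout << "rank " << rank
              << "  split fwd=" << RankSeed(fwd_dropout_seed, Sbp::kSplit, rank)
              << "  split bwd=" << RankSeed(bwd_dropout_seed, Sbp::kSplit, rank)
              << "  broadcast=" << RankSeed(fwd_dropout_seed, Sbp::kBroadcast, rank) << "\n";
  }
}
```

On each rank the split forward and backward seeds coincide, while the broadcast seed is identical across ranks.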

---------

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
leaves-zwx and mergify[bot] authored Mar 11, 2023
1 parent 63060cc commit d3712f1
Showing 19 changed files with 1,026 additions and 241 deletions.
23 changes: 14 additions & 9 deletions oneflow/core/functional/impl/activation_functor.cpp
@@ -18,8 +18,10 @@ limitations under the License.
#include "oneflow/core/common/scalar.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/function_library.h"
#include "oneflow/core/functional/impl/unary_functor.h"
#include "oneflow/core/functional/impl/binary_functor.h"
#include "oneflow/core/functional/sequence_function.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/mutable_attr_map.h"
#include "oneflow/core/framework/op_builder.h"
@@ -28,10 +30,10 @@ limitations under the License.
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_util.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/functional/function_library.h"
#include "oneflow/core/autograd/autograd_mode.h"
#include "oneflow/core/functional/sequence_function.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/user/kernels/distributions/common.h"
#include "oneflow/user/kernels/random_seed_util.h"

namespace oneflow {
namespace one {
@@ -435,9 +437,6 @@ class GumbelSoftmaxFunctor {
const int64_t num_axes = in_shape->NumAxes();

const auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("tau", "seed", "hard");
attrs.SetAllAttrs(tau, static_cast<int64_t>(gen->current_seed()), hard);

auto random_tensor =
JUST(functional::Rand(*in_shape.get(), dtype, device, gen, /*requires_grad=*/false));
auto gumbel_noise_tensor = JUST(functional::ScalarSub(
@@ -542,15 +541,21 @@ class RReluFunctor {
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const float& lower,
const float& upper, bool training, bool inplace) const {
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("lower", "upper", "training");
attrs.SetAllAttrs(lower, upper, training);
std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(2);
if (!training) { return JUST(functional::LeakyRelu(x, ((lower + upper) / 2), inplace)); }

auto gen = JUST(
GetGeneratorForLazyOrGlobal(JUST(one::DefaultAutoGenerator()), LazyMode::is_enabled(), x));
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("seed", "lower", "upper", "training");
attrs.SetAllAttrs(static_cast<int64_t>(gen->current_seed()), lower, upper, training);
const auto& state = std::make_shared<DistributionKernelState>(gen);

OpExprInterpContext ctx(attrs, state);
std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(2);
if (inplace) {
JUST(CheckInplaceValid(x));
outputs->at(0) = x;
}
JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), attrs));
JUST(OpInterpUtil::Dispatch(*op_, {x}, outputs.get(), ctx));
return outputs->at(0);
}

51 changes: 30 additions & 21 deletions oneflow/core/functional/impl/nn_functor.cpp
@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

#include "fmt/core.h"
#include "oneflow/core/framework/mutable_attr_map.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/tensor_util.h"
@@ -25,8 +24,11 @@ limitations under the License.
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/user/kernels/random_mask_like_kernel.h"
#include "oneflow/user/kernels/dropout_kernel.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/user/kernels/distributions/common.h"
#include "oneflow/user/kernels/random_seed_util.h"

#include "oneflow/core/common/container_util.h"
#include "fmt/core.h"

namespace oneflow {
namespace one {
@@ -795,8 +797,6 @@ class FusedMatmulBiasAddReluDropoutFunctor {
*/
const auto& x_shape = x->shape();
k = x_shape->At(1);
const auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));
const auto& dropout_state = std::make_shared<FusedDropoutKernelState>(gen);
for (int64_t i = 0; i < weight_size; i++) {
CHECK_GE_OR_RETURN(dropout_rate_list[i], 0.0f)
<< Error::RuntimeError() << "Dropout rate should be >= 0.0";
@@ -815,22 +815,28 @@
k = n;
}

auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));

#if CUDA_VERSION >= 11060
DeviceType device_type{};
if (x->is_global()) {
device_type = JUST(x->parallel_desc())->device_type();
} else {
device_type = JUST(x->device())->enum_type();
}

if ((device_type == DeviceType::kCUDA) && (weight_size <= kMaxInputCount)
&& (!ParseBooleanFromEnv("ONEFLOW_FUNCTOR_DISABLE_FUSED_MLP", false))) {
TensorTuple input(2 * weight_size + 1);
input[0] = x;
std::copy(weights.begin(), weights.end(), input.begin() + 1);
std::copy(biases.begin(), biases.end(), input.begin() + 1 + weight_size);
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("skip_final_activation", "dropout_rate_list");
attrs.SetAllAttrs(skip_final_activation, dropout_rate_list);

gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x));
auto& attrs =
THREAD_CACHED_MUTABLE_ATTR_MAP("skip_final_activation", "seed", "dropout_rate_list");
attrs.SetAllAttrs(skip_final_activation, static_cast<int64_t>(gen->current_seed()),
dropout_rate_list);
const auto& dropout_state = std::make_shared<FusedDropoutKernelState>(gen);
return OpInterpUtil::Dispatch<Tensor>(*fused_op_[weight_size], input,
OpExprInterpContext(attrs, dropout_state));
}
@@ -2816,25 +2822,27 @@ class DropoutFunctor {
JUST(CheckInplaceValid(x));
(*outputs)[0] = x;
}
const auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));

auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));
gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x));
auto& dropout_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("rate", "seed");
dropout_attrs.SetAllAttrs(p, static_cast<int64_t>(gen->current_seed()));

const auto& dropout_state = std::make_shared<FusedDropoutKernelState>(gen);
auto& dropout_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("rate");
dropout_attrs.SetAllAttrs(p);
OpExprInterpContext ctx(dropout_attrs, dropout_state);
if (addend) {
if ((!training) || p == 0.0) {
JUST(OpInterpUtil::Dispatch(*add_op_, {x, JUST(addend)}, outputs.get()));
} else {
outputs->resize(2);
JUST(OpInterpUtil::Dispatch(*dropout_addend_op_, {x, JUST(addend)}, outputs.get(),
OpExprInterpContext(dropout_attrs, dropout_state)));
JUST(OpInterpUtil::Dispatch(*dropout_addend_op_, {x, JUST(addend)}, outputs.get(), ctx));
}
} else {
if (!training || p == 0.0) {
return x;
} else {
outputs->resize(2);
JUST(OpInterpUtil::Dispatch(*dropout_op_, {x}, outputs.get(),
OpExprInterpContext(dropout_attrs, dropout_state)));
JUST(OpInterpUtil::Dispatch(*dropout_op_, {x}, outputs.get(), ctx));
}
}
return (*outputs)[0];
@@ -3282,11 +3290,11 @@ class FusedScaleTrilSoftmaxMaskScaleFunctor {
const int64_t diagonal, const float tril_scale_value,
const float tril_fill_value,
const Optional<one::Generator>& generator) const {
const auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));
auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));
gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x));
auto& random_mask_like_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("rate", "seed");
random_mask_like_attrs.SetAllAttrs(p, static_cast<int64_t>(gen->current_seed()));
const auto& random_mask_like_state = std::make_shared<RandomMaskLikeKernelState>(gen);

const auto& mask = JUST(OpInterpUtil::Dispatch<Tensor>(
*random_mask_like_op_, {x},
OpExprInterpContext(random_mask_like_attrs, random_mask_like_state)));
@@ -3387,7 +3395,8 @@ class FusedBiasAddDropoutFunctor {
axis_val += num_axes;
}
if (p > 0.0) {
const auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));
auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));
gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), a));
auto& random_mask_like_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("rate", "seed");
random_mask_like_attrs.SetAllAttrs(p, static_cast<int64_t>(gen->current_seed()));
const auto& random_mask_like_state = std::make_shared<RandomMaskLikeKernelState>(gen);
@@ -3628,11 +3637,11 @@ class FusedScaleMaskSoftmaxDropoutFunctor {
const Optional<one::Generator>& generator) const {
float rate = p;
if (!training) rate = 0.0;
const auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));
auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));
gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x));
auto& random_mask_like_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("rate", "seed");
random_mask_like_attrs.SetAllAttrs(rate, static_cast<int64_t>(gen->current_seed()));
const auto& random_mask_like_state = std::make_shared<RandomMaskLikeKernelState>(gen);

const auto& dropout_mask = JUST(OpInterpUtil::Dispatch<Tensor>(
*random_mask_like_op_, {x},
OpExprInterpContext(random_mask_like_attrs, random_mask_like_state)));
@@ -3679,11 +3688,11 @@ class FusedBiasAddScaleMaskSoftmaxDropoutFunctor {
const Optional<one::Generator>& generator) const {
float rate = p;
if (!training) rate = 0.0;
const auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));
auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));
gen = JUST(GetGeneratorForLazyOrGlobal(gen, LazyMode::is_enabled(), x));
auto& random_mask_like_attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("rate", "seed");
random_mask_like_attrs.SetAllAttrs(rate, static_cast<int64_t>(gen->current_seed()));
const auto& random_mask_like_state = std::make_shared<RandomMaskLikeKernelState>(gen);

const auto& dropout_mask = JUST(OpInterpUtil::Dispatch<Tensor>(
*random_mask_op_, {x},
OpExprInterpContext(random_mask_like_attrs, random_mask_like_state)));