From b6becb9c88e7bd1627460f44a918fb2b39bdc534 Mon Sep 17 00:00:00 2001
From: guo ran <360112263@qq.com>
Date: Wed, 20 Apr 2022 11:41:40 +0800
Subject: [PATCH] OneEmbedding update kernel add adagrad (#7999)

* column to table
* multi column
* fix typo and use model update util
* fix import
* add doc
* refine
* one_embedding update use model_update kernel util
* add doc
* add adagrad
* refine
* revert update
* refine
* refine
* support for adagrad
* fix tidy
* set state_initializer in pass
* state_initializer in pass
* fix tidy
* skip predict_job_has_optimizer_state prefetch, address review
* add test sgd
* add adam test
* add adagrad test
* refine
* add l1 l2
* address review
* skip if cpu
* Update python/oneflow/test/modules/test_one_embedding_adam.py
* Update python/oneflow/test/modules/test_one_embedding_sgd.py
* Update python/oneflow/test/modules/test_one_embedding_adagrad.py

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
---
 oneflow/core/functional/functional_api.yaml  |  12 ++
 oneflow/core/functional/impl/nn_functor.cpp  | 161 ++++++++++++++
 .../replace_embedding_ops_pass.cpp           |  27 ++-
 oneflow/ir/include/OneFlow/OneFlowUserOps.td |  34 +++
 .../user/kernels/model_update_kernel_util.h  |   3 +-
 .../kernels/one_embedding_update_kernels.cu  | 119 ++++++++++-
 oneflow/user/ops/one_embedding_ops.cpp       |  32 +++
 .../modules/test_one_embedding_adagrad.py    | 155 ++++++++++++++
 .../test/modules/test_one_embedding_adam.py  | 200 ++++++++++++++++++
 .../test/modules/test_one_embedding_sgd.py   | 146 +++++++++++++
 10 files changed, 878 insertions(+), 11 deletions(-)
 create mode 100644 python/oneflow/test/modules/test_one_embedding_adagrad.py
 create mode 100644 python/oneflow/test/modules/test_one_embedding_adam.py
 create mode 100644 python/oneflow/test/modules/test_one_embedding_sgd.py

diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index cfb9a11a009..1bdee9b24f0 100755
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -2086,6 +2086,18 @@
   signature: "TensorTuple (Tensor keys, Tensor values=None, Int32 num_tables) => OneEmbeddingUniqueKeyValuePair"
   bind_python: True
 
+- name: "one_embedding_sgd_update"
+  signature: "Tensor (Tensor num_unique_ids, Tensor unique_embeddings, Tensor embedding_grad, Tensor learning_rate, Tensor down_scale_by_tensor, Tensor skip_if, Double scale, Float weight_decay, Float momentum) => OneEmbeddingSgdUpdate"
+  bind_python: True
+
+- name: "one_embedding_adam_update"
+  signature: "Tensor (Tensor num_unique_ids, Tensor unique_embeddings, Tensor embedding_grad, Tensor learning_rate, Tensor down_scale_by_tensor, Tensor skip_if, Tensor bias_correction1=None, Tensor bias_correction2=None, Double scale=1.0, Float weight_decay=0.0, Float beta1=0.9, Float beta2=0.999, Float epsilon=0, Bool do_bias_correction=True) => OneEmbeddingAdamUpdate"
+  bind_python: True
+
+- name: "one_embedding_adagrad_update"
+  signature: "Tensor (Tensor num_unique_ids, Tensor unique_embeddings, Tensor embedding_grad, Tensor learning_rate, Tensor down_scale_by_tensor, Tensor skip_if, Tensor train_step, Double scale=1.0, Float weight_decay=0.0, Float lr_decay=0.0, Float epsilon=0) => OneEmbeddingAdagradUpdate"
+  bind_python: True
+
 - name: "einsum"
   signature: "Tensor (String equation, TensorTuple operands) => EinSum"
   bind_python: True
diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp
index 937d68969cf..de74cc60d0c 100644
---
a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -2535,6 +2535,164 @@ class OneEmbeddingUniqueKeyValuePairFunctor { std::shared_ptr op_no_input_value_; }; +class OneEmbeddingSgdUpdateFunctor { + public: + OneEmbeddingSgdUpdateFunctor() { + // This functor is just for unittest + sgd_op_ = CHECK_JUST(one::OpBuilder("sgd_embedding_update") + .Input("num_unique_ids") + .Input("unique_embeddings") + .Input("embedding_grad") + .Input("learning_rate") + .Input("down_scale_by_tensor") + .Input("skip_if") + .Output("updated_unique_embeddings") + .Build()); + momentum_op_ = CHECK_JUST(one::OpBuilder("momentum_embedding_update") + .Input("num_unique_ids") + .Input("unique_embeddings") + .Input("embedding_grad") + .Input("learning_rate") + .Input("down_scale_by_tensor") + .Input("skip_if") + .Output("updated_unique_embeddings") + .Build()); + } + + Maybe operator()(const std::shared_ptr& num_unique_ids, + const std::shared_ptr& unique_embeddings, + const std::shared_ptr& embedding_grad, + const std::shared_ptr& learning_rate, + const std::shared_ptr& down_scale_by_tensor, + const std::shared_ptr& skip_if, const double scale, + const float weight_decay, const float momentum) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("scale", scale)); + JUST(attrs.SetAttr("weight_decay", weight_decay)); + if (momentum == 0) { + return OpInterpUtil::Dispatch(*sgd_op_, + {num_unique_ids, unique_embeddings, embedding_grad, + learning_rate, down_scale_by_tensor, skip_if}, + attrs); + } else { + JUST(attrs.SetAttr("beta", momentum)); + return OpInterpUtil::Dispatch(*momentum_op_, + {num_unique_ids, unique_embeddings, embedding_grad, + learning_rate, down_scale_by_tensor, skip_if}, + attrs); + } + } + + private: + std::shared_ptr sgd_op_; + std::shared_ptr momentum_op_; +}; + +class OneEmbeddingAdamUpdateFunctor { + public: + OneEmbeddingAdamUpdateFunctor() { + // This functor is just for unittest + no_bias_correction_op_ = CHECK_JUST(one::OpBuilder("adam_embedding_update") + .Input("num_unique_ids") + .Input("unique_embeddings") + .Input("embedding_grad") + .Input("learning_rate") + .Input("down_scale_by_tensor") + .Input("skip_if") + .Output("updated_unique_embeddings") + .Build()); + do_bias_correction_op_ = CHECK_JUST(one::OpBuilder("adam_embedding_update") + .Input("num_unique_ids") + .Input("unique_embeddings") + .Input("embedding_grad") + .Input("learning_rate") + .Input("down_scale_by_tensor") + .Input("skip_if") + .Input("bias_correction1") + .Input("bias_correction2") + .Output("updated_unique_embeddings") + .Build()); + } + + Maybe operator()(const std::shared_ptr& num_unique_ids, + const std::shared_ptr& unique_embeddings, + const std::shared_ptr& embedding_grad, + const std::shared_ptr& learning_rate, + const std::shared_ptr& down_scale_by_tensor, + const std::shared_ptr& skip_if, + const Optional& bias_correction1, + const Optional& bias_correction2, const double scale, + const float weight_decay, const float beta1, const float beta2, + const float epsilon, const bool do_bias_correction) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("scale", scale)); + JUST(attrs.SetAttr("weight_decay", weight_decay)); + JUST(attrs.SetAttr("beta1", beta1)); + JUST(attrs.SetAttr("beta2", beta2)); + JUST(attrs.SetAttr("epsilon", epsilon)); + JUST(attrs.SetAttr("do_bias_correction", do_bias_correction)); + if (do_bias_correction) { + CHECK(bias_correction1); + CHECK(bias_correction2); + return OpInterpUtil::Dispatch( + *do_bias_correction_op_, + {num_unique_ids, 
unique_embeddings, embedding_grad, learning_rate, down_scale_by_tensor, + skip_if, JUST(bias_correction1), JUST(bias_correction2)}, + attrs); + } else { + return OpInterpUtil::Dispatch(*no_bias_correction_op_, + {num_unique_ids, unique_embeddings, embedding_grad, + learning_rate, down_scale_by_tensor, skip_if}, + attrs); + } + } + + private: + std::shared_ptr no_bias_correction_op_; + std::shared_ptr do_bias_correction_op_; +}; + +class OneEmbeddingAdagradUpdateFunctor { + public: + OneEmbeddingAdagradUpdateFunctor() { + // This functor is just for unittest + op_ = CHECK_JUST(one::OpBuilder("adagrad_embedding_update") + .Input("num_unique_ids") + .Input("unique_embeddings") + .Input("embedding_grad") + .Input("learning_rate") + .Input("down_scale_by_tensor") + .Input("skip_if") + .Input("train_step") + .Output("updated_unique_embeddings") + .Build()); + } + + Maybe operator()(const std::shared_ptr& num_unique_ids, + const std::shared_ptr& unique_embeddings, + const std::shared_ptr& embedding_grad, + const std::shared_ptr& learning_rate, + const std::shared_ptr& down_scale_by_tensor, + const std::shared_ptr& skip_if, + const std::shared_ptr& train_step, const double scale, + const float weight_decay, const float lr_decay, + const float epsilon) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("scale", scale)); + JUST(attrs.SetAttr("weight_decay", weight_decay)); + JUST(attrs.SetAttr("lr_decay", lr_decay)); + JUST(attrs.SetAttr("epsilon", epsilon)); + return OpInterpUtil::Dispatch( + *op_, + {num_unique_ids, unique_embeddings, embedding_grad, learning_rate, down_scale_by_tensor, + skip_if, train_step}, + attrs); + } + + private: + std::shared_ptr op_; +}; + class RocAucScoreFunctor { public: RocAucScoreFunctor() { @@ -2628,6 +2786,9 @@ ONEFLOW_FUNCTION_LIBRARY(m) { "OneEmbeddingEmbeddingGradientShuffle"); m.add_functor("OneEmbeddingLookup"); m.add_functor("OneEmbeddingUniqueKeyValuePair"); + m.add_functor("OneEmbeddingSgdUpdate"); + m.add_functor("OneEmbeddingAdamUpdate"); + m.add_functor("OneEmbeddingAdagradUpdate"); m.add_functor("RocAucScore"); }; diff --git a/oneflow/core/job_rewriter/replace_embedding_ops_pass.cpp b/oneflow/core/job_rewriter/replace_embedding_ops_pass.cpp index 202521320d8..076019fa7d8 100644 --- a/oneflow/core/job_rewriter/replace_embedding_ops_pass.cpp +++ b/oneflow/core/job_rewriter/replace_embedding_ops_pass.cpp @@ -434,8 +434,8 @@ void MakeConstantInitializerAttr(const int64_t embedding_size, const int64_t lin void BuildEmbeddingUpdate(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder, const ParallelConf& parallel_conf, const int64_t embedding_size, - const int64_t line_size, const std::string& embedding_name, - const OptimizerConf& optimizer_conf, + const int64_t line_size, const float l1, const float l2, + const std::string& embedding_name, const OptimizerConf& optimizer_conf, const user_op::UserOpConfWrapper& embedding_op, const std::string& num_unique_ids_lbn, const std::string& unique_ids_lbn, const std::string& unique_values_lbn, @@ -482,7 +482,12 @@ void BuildEmbeddingUpdate(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* .Input("bias_correction2", bias_correction2_lbn); } } else if (optimizer_conf.has_adagrad_conf()) { - UNIMPLEMENTED(); + const AdagradModelUpdateConf& adagrad_conf = optimizer_conf.adagrad_conf(); + state_constant_init_values.push_back(adagrad_conf.initial_accumulator_value()); + embedding_update_op_builder.OpTypeName("adagrad_embedding_update") + .Input("train_step", train_conf.train_step_lbn()) + 
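+        // The adagrad kernel consumes train_step to decay the learning rate as
+        // lr / (1 + train_step * lr_decay) before applying the update.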
.Attr("lr_decay", adagrad_conf.lr_decay()) + .Attr("epsilon", adagrad_conf.epsilon()); } else { UNIMPLEMENTED(); } @@ -493,6 +498,8 @@ void BuildEmbeddingUpdate(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* .Input("embedding_grad", embedding_grad_lbn) .Input("learning_rate", learning_rate_lbn) .Attr("weight_decay", optimizer_conf.weight_decay_conf().weight_decay_rate()) + .Attr("l1", l1) + .Attr("l2", l2) .Output("updated_unique_embeddings"); double scale = GetLossInstanceNumScaleFactor(op_graph, job_builder); if (train_conf.has_dynamic_loss_scale_policy()) { @@ -690,12 +697,24 @@ Maybe ReplaceEmbeddingOps::Apply(const OpGraph& op_graph, JobBuilder* job_ if (found_embedding_optimizer == true) { break; } } CHECK_EQ(found_embedding_optimizer, true); + + const OpNode* shadow_node = op_graph.OpNode4OpName(shadow_op_name); + const VariableOpConf& shadow_variable_conf = shadow_node->op().op_conf().variable_conf(); + float l1 = 0.0; + float l2 = 0.0; + if (shadow_variable_conf.has_regularizer()) { + const RegularizerConf& regularizer_conf = shadow_variable_conf.regularizer(); + if (regularizer_conf.has_l1_l2_conf()) { + l1 = regularizer_conf.l1_l2_conf().l1(); + l2 = regularizer_conf.l1_l2_conf().l2(); + } + } const std::string& learning_rate_lbn = AddScheduleOp(op_graph, job_builder, embedding_optimizer_conf, "System-Train-LearningRate-Scheduler_" + NewUniqueId()); BuildEmbeddingUpdate(ctx, op_graph, job_builder, op_node->parallel_desc().parallel_conf(), - embedding_size, options.LineSize(), options.Name(), + embedding_size, options.LineSize(), l1, l2, options.Name(), embedding_optimizer_conf, embedding_op, num_unique_ids_lbn, unique_ids_lbn, unique_values_lbn, embedding_grad_lbn, learning_rate_lbn, &state_initializer); diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index 1fa78e50682..b0bc60af1a6 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -9548,6 +9548,8 @@ def OneFlow_SgdEmbeddingUpdateOp : OneFlow_BaseOp<"sgd_embedding_update", [AttrS ); let attrs = (ins DefaultValuedAttr:$scale, + DefaultValuedAttr:$l1, + DefaultValuedAttr:$l2, DefaultValuedAttr:$weight_decay ); let same_output_regst_num = 1; @@ -9571,6 +9573,8 @@ def OneFlow_MomentumEmbeddingUpdateOp : OneFlow_BaseOp<"momentum_embedding_updat ); let attrs = (ins DefaultValuedAttr:$scale, + DefaultValuedAttr:$l1, + DefaultValuedAttr:$l2, DefaultValuedAttr:$weight_decay, DefaultValuedAttr:$beta ); @@ -9597,6 +9601,8 @@ def OneFlow_AdamEmbeddingUpdateOp : OneFlow_BaseOp<"adam_embedding_update", [Att ); let attrs = (ins DefaultValuedAttr:$scale, + DefaultValuedAttr:$l1, + DefaultValuedAttr:$l2, DefaultValuedAttr:$weight_decay, DefaultValuedAttr:$beta1, DefaultValuedAttr:$beta2, @@ -9610,6 +9616,34 @@ def OneFlow_AdamEmbeddingUpdateOp : OneFlow_BaseOp<"adam_embedding_update", [Att let has_data_type_infer_fn = 1; } +def OneFlow_AdagradEmbeddingUpdateOp : OneFlow_BaseOp<"adagrad_embedding_update", [AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$num_unique_ids, + OneFlow_Tensor:$unique_embeddings, + OneFlow_Tensor:$embedding_grad, + OneFlow_Tensor:$learning_rate, + OneFlow_Tensor:$train_step, + Optional:$down_scale_by_tensor, + Optional:$skip_if + ); + let output = (outs + OneFlow_Tensor:$updated_unique_embeddings + ); + let attrs = (ins + DefaultValuedAttr:$scale, + DefaultValuedAttr:$l1, + DefaultValuedAttr:$l2, + DefaultValuedAttr:$weight_decay, + 
DefaultValuedAttr:$lr_decay, + DefaultValuedAttr:$epsilon + ); + let same_output_regst_num = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + def OneFlow_EmbeddingPutOp : OneFlow_BaseOp<"embedding_put", [DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$num_unique_ids, diff --git a/oneflow/user/kernels/model_update_kernel_util.h b/oneflow/user/kernels/model_update_kernel_util.h index 03ae9b819c0..8c98f16e482 100644 --- a/oneflow/user/kernels/model_update_kernel_util.h +++ b/oneflow/user/kernels/model_update_kernel_util.h @@ -113,7 +113,8 @@ struct AdagradUpdateFunctor { CastScaleRegularizeGradientFunctor()(*model_diff, model_val, scale, l1, l2); const T next_sum = *sum + model_diff_t * model_diff_t; *sum = next_sum; - *model = model_val - learning_rate / (sqrt(next_sum) + epsilon) * model_diff_t; + *model = model_val - learning_rate / (sqrt(next_sum) + epsilon) * model_diff_t + - learning_rate * weight_decay * model_val; } }; diff --git a/oneflow/user/kernels/one_embedding_update_kernels.cu b/oneflow/user/kernels/one_embedding_update_kernels.cu index 0cea0e9e47d..8535ea442bb 100644 --- a/oneflow/user/kernels/one_embedding_update_kernels.cu +++ b/oneflow/user/kernels/one_embedding_update_kernels.cu @@ -137,6 +137,42 @@ __global__ void AdamUpdateKernel(const int32_t line_size, const int32_t embeddin } } +template +__global__ void AdagradUpdateKernel(const int64_t line_size, const int64_t embedding_size, T scale, + float l1, float l2, float weight_decay, float lr_decay, + float epsilon, const IDX* num_unique_ids, + const float* learning_rate, const int64_t* train_step_ptr, + const T* down_scale_by_ptr, const int64_t* skip_if, + const G* model_diff, const T* unique_values, + T* updated_unique_values) { + if (skip_if != nullptr && *skip_if != 0) { + const int64_t n = *num_unique_ids * line_size; + CUDA_1D_KERNEL_LOOP(i, n) { + int64_t model_offset; + int64_t sum_offset; + GetMomentumOffset(line_size, embedding_size, i, &model_offset, &sum_offset); + updated_unique_values[model_offset] = unique_values[model_offset]; + updated_unique_values[sum_offset] = unique_values[sum_offset]; + } + } else { + int64_t train_step = *train_step_ptr + 1; + if (down_scale_by_ptr != nullptr) { scale /= *down_scale_by_ptr; } + float learning_rate_val = *learning_rate; + learning_rate_val = learning_rate_val / (1 + (train_step - 1) * lr_decay); + const int64_t n = *num_unique_ids * embedding_size; + CUDA_1D_KERNEL_LOOP(i, n) { + int64_t model_offset; + int64_t sum_offset; + GetMomentumOffset(line_size, embedding_size, i, &model_offset, &sum_offset); + updated_unique_values[model_offset] = unique_values[model_offset]; + updated_unique_values[sum_offset] = unique_values[sum_offset]; + AdagradUpdateFunctor()(model_diff + i, updated_unique_values + model_offset, + updated_unique_values + sum_offset, scale, l1, l2, epsilon, + weight_decay, learning_rate_val); + } + } +} + } // namespace template @@ -159,8 +195,8 @@ class SgdEmbeddingUpdateKernel final : public user_op::OpKernel { const int64_t embedding_size = embedding_grad->shape().At(1); CHECK_EQ(line_size, embedding_size); const auto scale = ctx->Attr("scale"); - const float l1 = 0.0; - const float l2 = 0.0; + const float l1 = ctx->Attr("l1"); + const float l2 = ctx->Attr("l2"); const auto weight_decay = ctx->Attr("weight_decay"); const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); const float* learning_rate_ptr 
= learning_rate->dptr(); @@ -228,8 +264,8 @@ class MomentumEmbeddingUpdateKernel final : public user_op::OpKernel { const int64_t line_size = unique_embeddings->shape().At(1); const int64_t embedding_size = embedding_grad->shape().At(1); CHECK_EQ(line_size, embedding_size * 2); - const float l1 = 0.0; - const float l2 = 0.0; + const float l1 = ctx->Attr("l1"); + const float l2 = ctx->Attr("l2"); const auto weight_decay = ctx->Attr("weight_decay"); const auto beta = ctx->Attr("beta"); const auto scale = ctx->Attr("scale"); @@ -297,8 +333,8 @@ class AdamEmbeddingUpdateKernel final : public user_op::OpKernel { const int64_t embedding_size = embedding_grad->shape().At(1); CHECK_EQ(line_size, embedding_size * 3); - const float l1 = 0.0; - const float l2 = 0.0; + const float l1 = ctx->Attr("l1"); + const float l2 = ctx->Attr("l2"); const auto weight_decay = ctx->Attr("weight_decay"); const auto beta1 = ctx->Attr("beta1"); const auto beta2 = ctx->Attr("beta2"); @@ -356,4 +392,75 @@ class AdamEmbeddingUpdateKernel final : public user_op::OpKernel { OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ADAM_EMBEDDING_UPDATE_KERNEL, FLOATING_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ) +template +class AdagradEmbeddingUpdateKernel final : public user_op::OpKernel { + public: + AdagradEmbeddingUpdateKernel() = default; + ~AdagradEmbeddingUpdateKernel() override = default; + + private: + using user_op::OpKernel::Compute; + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* num_unique_ids = ctx->Tensor4ArgNameAndIndex("num_unique_ids", 0); + const user_op::Tensor* unique_embeddings = ctx->Tensor4ArgNameAndIndex("unique_embeddings", 0); + const user_op::Tensor* embedding_grad = ctx->Tensor4ArgNameAndIndex("embedding_grad", 0); + user_op::Tensor* updated_unique_embeddings = + ctx->Tensor4ArgNameAndIndex("updated_unique_embeddings", 0); + CHECK_EQ(unique_embeddings->shape().NumAxes(), 2); + CHECK_EQ(embedding_grad->shape().NumAxes(), 2); + const int64_t num_keys = unique_embeddings->shape().At(0); + const int64_t line_size = unique_embeddings->shape().At(1); + const int64_t embedding_size = embedding_grad->shape().At(1); + CHECK_EQ(line_size, embedding_size * 2); + + const float l1 = ctx->Attr("l1"); + const float l2 = ctx->Attr("l2"); + const auto weight_decay = ctx->Attr("weight_decay"); + const auto lr_decay = ctx->Attr("lr_decay"); + const auto epsilon = ctx->Attr("epsilon"); + const auto scale = ctx->Attr("scale"); + const T* down_scale_by_ptr = nullptr; + if (ctx->has_input("down_scale_by_tensor", 0)) { + const user_op::Tensor* down_scale_by_tensor = + ctx->Tensor4ArgNameAndIndex("down_scale_by_tensor", 0); + CHECK_EQ(down_scale_by_tensor->data_type(), unique_embeddings->data_type()); + CHECK_EQ(down_scale_by_tensor->shape().elem_cnt(), 1); + down_scale_by_ptr = down_scale_by_tensor->dptr(); + } + const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); + const float* learning_rate_ptr = learning_rate->dptr(); + const int64_t* train_step_ptr = ctx->Tensor4ArgNameAndIndex("train_step", 0)->dptr(); + const int64_t* skip_if_ptr = nullptr; + if (ctx->has_input("skip_if", 0)) { + const user_op::Tensor* skip_if = ctx->Tensor4ArgNameAndIndex("skip_if", 0); + CHECK_EQ(skip_if->shape().elem_cnt(), 1); + skip_if_ptr = skip_if->dptr(); + } + // update kernel + AdagradUpdateKernel + <<shape().elem_cnt()), kCudaThreadsNumPerBlock, 0, + ctx->stream()->As()->cuda_stream()>>>( + line_size, embedding_size, 
static_cast(scale), l1, l2, weight_decay, lr_decay, + epsilon, reinterpret_cast(num_unique_ids->dptr()), learning_rate_ptr, + train_step_ptr, down_scale_by_ptr, skip_if_ptr, embedding_grad->dptr(), + unique_embeddings->dptr(), updated_unique_embeddings->mut_dptr()); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_CUDA_ADAGRAD_EMBEDDING_UPDATE_KERNEL(t_dtype_pair, g_type_pair, idx_dtype_pair) \ + REGISTER_USER_KERNEL("adagrad_embedding_update") \ + .SetCreateFn>() \ + .SetIsMatchedHob( \ + (user_op::HobDeviceType() == DeviceType::kCUDA) \ + && (user_op::HobDataType("num_unique_ids", 0) == OF_PP_PAIR_SECOND(idx_dtype_pair)) \ + && (user_op::HobDataType("embedding_grad", 0) == OF_PP_PAIR_SECOND(g_type_pair)) \ + && (user_op::HobDataType("unique_embeddings", 0) == OF_PP_PAIR_SECOND(t_dtype_pair))); + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_ADAGRAD_EMBEDDING_UPDATE_KERNEL, + FLOATING_DATA_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, + IDX_DATA_TYPE_SEQ) + } // namespace oneflow diff --git a/oneflow/user/ops/one_embedding_ops.cpp b/oneflow/user/ops/one_embedding_ops.cpp index ac102e8bf73..bf5511f467a 100644 --- a/oneflow/user/ops/one_embedding_ops.cpp +++ b/oneflow/user/ops/one_embedding_ops.cpp @@ -323,4 +323,36 @@ Maybe CheckDataType(user_op::InferContext* ctx) { return Maybe::Ok(); } +/* static */ Maybe AdagradEmbeddingUpdateOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + JUST(CheckDataShape(ctx)); + const Shape& unique_embeddings_shape = ctx->InputShape("unique_embeddings", 0); + CHECK_EQ_OR_RETURN(unique_embeddings_shape.At(1), 2 * ctx->InputShape("embedding_grad", 0).At(1)) + << "please adjust size_factor of MultiTableEmbedding's store_options to 2"; + *ctx->OutputShape("updated_unique_embeddings", 0) = unique_embeddings_shape; + return Maybe::Ok(); +} + +/*static*/ Maybe AdagradEmbeddingUpdateOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe AdagradEmbeddingUpdateOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Broadcast(ctx->inputs()) + .Broadcast(user_op::OpArg("num_unique_ids", 0)) + .Split(user_op::OpArg("unique_embeddings", 0), 0) + .Split(user_op::OpArg("embedding_grad", 0), 0) + .Split(user_op::OpArg("updated_unique_embeddings", 0), 0) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe AdagradEmbeddingUpdateOp::InferDataType(user_op::InferContext* ctx) { + JUST(CheckDataType(ctx)); + *ctx->OutputDType("updated_unique_embeddings", 0) = ctx->InputDType("unique_embeddings", 0); + return Maybe::Ok(); +} + } // namespace oneflow diff --git a/python/oneflow/test/modules/test_one_embedding_adagrad.py b/python/oneflow/test/modules/test_one_embedding_adagrad.py new file mode 100644 index 00000000000..516c27cef99 --- /dev/null +++ b/python/oneflow/test/modules/test_one_embedding_adagrad.py @@ -0,0 +1,155 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import unittest +from collections import OrderedDict +import tempfile +import os +import numpy as np +from oneflow.test_utils.test_util import GenArgDict +from optimizer_test_util import clip_grad_norm_np + +import oneflow as flow +from oneflow.nn.parameter import Parameter + + +def compare_with_numpy_adagrad( + test_case, weight_decay, lr_decay, scale, learning_rate, train_iters, +): + + num_rows = 500 + embedding_size = 128 + model_shape = (num_rows, embedding_size) + line_size = embedding_size * 2 + + num_valid_seq = np.random.randint(1, num_rows, (train_iters)) + skip_if_seq = [np.random.randint(2) for i in range(train_iters)] + + random_grad_seq = [] + for _ in range(train_iters): + random_grad_seq.append(np.random.uniform(size=model_shape).astype(np.float32)) + + init_value = np.random.uniform(size=(num_rows, line_size)).astype(np.float32) + + down_scale_by = 10 + epsilon = 1e-5 + + def adagrad_by_oneflow(): + unique_embeddings_tensor = flow.tensor(init_value, requires_grad=False).to( + "cuda" + ) + lr_tensor = flow.tensor( + np.array(learning_rate).reshape(1,).astype(np.float32) + ).to("cuda") + down_scale_by_tensor = flow.tensor( + np.array(down_scale_by).astype(np.float32) + ).to("cuda") + + def train_one_iter( + num_valid, unique_embeddings, embedding_grad, skip_if, train_step + ): + return flow._C.one_embedding_adagrad_update( + num_valid, + unique_embeddings, + embedding_grad, + lr_tensor, + down_scale_by_tensor, + skip_if, + train_step, + scale, + weight_decay, + lr_decay, + epsilon, + ) + + for i in range(1, train_iters): + num_valid_tensor = flow.tensor( + np.array(num_valid_seq[i]).reshape(1,).astype(np.int32) + ).to("cuda") + grad_tensor = flow.tensor(random_grad_seq[i]).to("cuda") + skip_if_tensor = flow.tensor( + np.array(skip_if_seq[i]).reshape(1,).astype(np.int64) + ).to("cuda") + step_tensor = flow.tensor(np.array(i).reshape(1,).astype(np.int64)).to( + "cuda" + ) + updated_tensor = train_one_iter( + num_valid_tensor, + unique_embeddings_tensor, + grad_tensor, + skip_if_tensor, + step_tensor, + ) + unique_embeddings_tensor[0 : num_valid_seq[i]] = updated_tensor[ + 0 : num_valid_seq[i] + ] + return unique_embeddings_tensor + + def adagrad_by_numpy(): + x = init_value[:, 0:embedding_size] + st = init_value[:, embedding_size:] + + def train_one_iter(iter, num_valid, grad, model, state): + grad[0:num_valid] = grad[0:num_valid] * (scale / down_scale_by) + lr = learning_rate / (1 + iter * lr_decay) + state[0:num_valid] = ( + state[0:num_valid] + grad[0:num_valid] * grad[0:num_valid] + ) + model[0:num_valid] = ( + model[0:num_valid] + - lr / (np.sqrt(state[0:num_valid]) + epsilon) * grad[0:num_valid] + - lr * weight_decay * model[0:num_valid] + ) + return (model, state) + + for i in range(1, train_iters): + if skip_if_seq[i] > 0: + pass + else: + (x, st) = train_one_iter( + i, int(num_valid_seq[i]), random_grad_seq[i], x, st + ) + + return x, st + + oneflow_res = adagrad_by_oneflow().numpy() + of_model = oneflow_res[:, 0:embedding_size] + of_sum = oneflow_res[:, embedding_size:] + np_model, np_sum = adagrad_by_numpy() + test_case.assertTrue( + np.allclose(of_model.flatten(), np_model.flatten(), rtol=0.001, atol=0.001) + ) + test_case.assertTrue( + np.allclose(of_sum.flatten(), np_sum.flatten(), rtol=0.001, atol=0.001) + ) + + +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@flow.unittest.skip_unless_1n1d() +class TestOptimizers(flow.unittest.TestCase): + def test_one_embedding_adagrad(test_case): + arg_dict = OrderedDict() + 
arg_dict["weight_decay"] = [0, 0.1] + arg_dict["lr_decay"] = [0, 0.1] + arg_dict["scale"] = [1, 0.1] + arg_dict["learning_rate"] = [0.3, 1.5] + arg_dict["train_iters"] = [10] + for arg in GenArgDict(arg_dict): + compare_with_numpy_adagrad(test_case, **arg) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/modules/test_one_embedding_adam.py b/python/oneflow/test/modules/test_one_embedding_adam.py new file mode 100644 index 00000000000..2c237ba91a4 --- /dev/null +++ b/python/oneflow/test/modules/test_one_embedding_adam.py @@ -0,0 +1,200 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest +from collections import OrderedDict +import tempfile +import os +import numpy as np +from oneflow.test_utils.test_util import GenArgDict +from optimizer_test_util import clip_grad_norm_np + +import oneflow as flow +from oneflow.nn.parameter import Parameter + + +def compare_with_numpy_adam( + test_case, + weight_decay, + scale, + learning_rate, + train_iters, + do_bias_correction, + beta1, + beta2, +): + + num_rows = 500 + embedding_size = 128 + model_shape = (num_rows, embedding_size) + line_size = embedding_size * 3 + + num_valid_seq = np.random.randint(1, num_rows, (train_iters)) + skip_if_seq = [np.random.randint(2) for i in range(train_iters)] + + random_grad_seq = [] + for _ in range(train_iters): + random_grad_seq.append(np.random.uniform(size=model_shape).astype(np.float32)) + + init_value = np.random.uniform(size=(num_rows, line_size)).astype(np.float32) + + down_scale_by = 10 + epsilon = 1e-5 + + def adam_by_oneflow(): + unique_embeddings_tensor = flow.tensor(init_value, requires_grad=False).to( + "cuda" + ) + lr_tensor = flow.tensor( + np.array(learning_rate).reshape(1,).astype(np.float32) + ).to("cuda") + down_scale_by_tensor = flow.tensor( + np.array(down_scale_by).astype(np.float32) + ).to("cuda") + + def train_one_iter( + num_valid, + unique_embeddings, + embedding_grad, + skip_if, + bias_correction1, + bias_correction2, + ): + return flow._C.one_embedding_adam_update( + num_valid, + unique_embeddings, + embedding_grad, + lr_tensor, + down_scale_by_tensor, + skip_if, + bias_correction1, + bias_correction2, + scale, + weight_decay, + beta1, + beta2, + epsilon, + do_bias_correction, + ) + + for i in range(1, train_iters): + num_valid_tensor = flow.tensor( + np.array(num_valid_seq[i]).reshape(1,).astype(np.int32) + ).to("cuda") + grad_tensor = flow.tensor(random_grad_seq[i]).to("cuda") + skip_if_tensor = flow.tensor( + np.array(skip_if_seq[i]).reshape(1,).astype(np.int64) + ).to("cuda") + if do_bias_correction: + bias_correction1 = 1.0 - np.power(beta1, i) + bias_correction2 = 1.0 - np.power(beta2, i) + bias_correction1_tensor = flow.tensor( + np.array(bias_correction1).reshape(1,).astype(np.float32) + ).to("cuda") + bias_correction2_tensor = flow.tensor( + np.array(bias_correction2).reshape(1,).astype(np.float32) + ).to("cuda") + else: + bias_correction1_tensor = None + bias_correction2_tensor = None 
+ updated_tensor = train_one_iter( + num_valid_tensor, + unique_embeddings_tensor, + grad_tensor, + skip_if_tensor, + bias_correction1_tensor, + bias_correction2_tensor, + ) + unique_embeddings_tensor[0 : num_valid_seq[i]] = updated_tensor[ + 0 : num_valid_seq[i] + ] + return unique_embeddings_tensor + + def adam_by_numpy(): + x = init_value[:, 0:embedding_size] + m = init_value[:, embedding_size : 2 * embedding_size] + v = init_value[:, 2 * embedding_size : 3 * embedding_size] + + def np_train_one_iter(step, num_valid, grad, model, state_m, state_v): + grad[0:num_valid] = grad[0:num_valid] * (scale / down_scale_by) + + bias_correction1 = 1.0 + bias_correction2 = 1.0 + + if do_bias_correction: + bias_correction1 = 1.0 - np.power(beta1, step) + bias_correction2 = 1.0 - np.power(beta2, step) + + state_m[0:num_valid] = ( + beta1 * state_m[0:num_valid] + (1 - beta1) * grad[0:num_valid] + ) + state_v[0:num_valid] = ( + beta2 * state_v[0:num_valid] + + (1 - beta2) * grad[0:num_valid] * grad[0:num_valid] + ) + denom = np.sqrt(state_v[0:num_valid]) / np.sqrt(bias_correction2) + epsilon + + model[0:num_valid] = ( + model[0:num_valid] + - ((learning_rate / bias_correction1) * state_m[0:num_valid] / denom) + - learning_rate * weight_decay * model[0:num_valid] + ) + return (model, state_m, state_v) + + for i in range(1, train_iters): # if step = 0, bias_correction2 is 0 + if skip_if_seq[i] > 0: + pass + else: + (x, m, v) = np_train_one_iter( + i, int(num_valid_seq[i]), random_grad_seq[i], x, m, v + ) + return x, m, v + + oneflow_res = adam_by_oneflow().numpy() + of_model = oneflow_res[:, 0:embedding_size] + of_m = oneflow_res[:, embedding_size : 2 * embedding_size] + of_v = oneflow_res[:, 2 * embedding_size : 3 * embedding_size] + np_model, np_m, np_v = adam_by_numpy() + test_case.assertTrue( + np.allclose(of_model.flatten(), np_model.flatten(), rtol=0.001, atol=0.001) + ) + test_case.assertTrue( + np.allclose(of_m.flatten(), np_m.flatten(), rtol=0.001, atol=0.001) + ) + test_case.assertTrue( + np.allclose(of_v.flatten(), np_v.flatten(), rtol=0.001, atol=0.001) + ) + + +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@flow.unittest.skip_unless_1n1d() +class TestOptimizers(flow.unittest.TestCase): + def test_one_embedding_adam(test_case): + arg_dict = OrderedDict() + arg_dict["weight_decay"] = [0, 0.1] + arg_dict["scale"] = [1, 0.1] + arg_dict["learning_rate"] = [1, 1.5] + arg_dict["train_iters"] = [10] + arg_dict["do_bias_correction"] = [True, False] + arg_dict["beta1"] = [0.9, 0.8] + arg_dict["beta2"] = [0.9, 0.8] + + for arg in GenArgDict(arg_dict): + compare_with_numpy_adam(test_case, **arg) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/modules/test_one_embedding_sgd.py b/python/oneflow/test/modules/test_one_embedding_sgd.py new file mode 100644 index 00000000000..9b709229d30 --- /dev/null +++ b/python/oneflow/test/modules/test_one_embedding_sgd.py @@ -0,0 +1,146 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest +from collections import OrderedDict +import tempfile +import os +import numpy as np +from oneflow.test_utils.test_util import GenArgDict +from optimizer_test_util import clip_grad_norm_np + +import oneflow as flow +from oneflow.nn.parameter import Parameter + + +def compare_with_numpy_sgd( + test_case, momentum, weight_decay, scale, learning_rate, train_iters, +): + + num_rows = 500 + embedding_size = 128 + model_shape = (num_rows, embedding_size) + line_size = embedding_size * 2 if momentum > 0 else embedding_size + + num_valid_seq = np.random.randint(1, num_rows, (train_iters)) + skip_if_seq = [np.random.randint(2) for i in range(train_iters)] + + random_grad_seq = [] + for _ in range(train_iters): + random_grad_seq.append(np.random.uniform(size=model_shape).astype(np.float32)) + + init_value = np.random.uniform(size=(num_rows, line_size)).astype(np.float32) + + down_scale_by = 10 + + def sgd_by_oneflow(): + unique_embeddings_tensor = flow.tensor(init_value, requires_grad=False).to( + "cuda" + ) + lr_tensor = flow.tensor( + np.array(learning_rate).reshape(1,).astype(np.float32) + ).to("cuda") + down_scale_by_tensor = flow.tensor( + np.array(down_scale_by).astype(np.float32) + ).to("cuda") + + def train_one_iter(num_valid, unique_embeddings, embedding_grad, skip_if): + return flow._C.one_embedding_sgd_update( + num_valid, + unique_embeddings, + embedding_grad, + lr_tensor, + down_scale_by_tensor, + skip_if, + scale, + weight_decay, + momentum, + ) + + for i in range(train_iters): + num_valid_tensor = flow.tensor( + np.array(num_valid_seq[i]).reshape(1,).astype(np.int32) + ).to("cuda") + grad_tensor = flow.tensor(random_grad_seq[i]).to("cuda") + skip_if_tensor = flow.tensor( + np.array(skip_if_seq[i]).reshape(1,).astype(np.int64) + ).to("cuda") + updated_tensor = train_one_iter( + num_valid_tensor, unique_embeddings_tensor, grad_tensor, skip_if_tensor + ) + unique_embeddings_tensor[0 : num_valid_seq[i]] = updated_tensor[ + 0 : num_valid_seq[i] + ] + return unique_embeddings_tensor + + def sgd_by_numpy(): + x = init_value[:, 0:embedding_size] + vt = init_value[:, embedding_size:] + + def train_one_iter(num_valid, grad, model, state): + grad[0:num_valid] = grad[0:num_valid] * (scale / down_scale_by) + next_state = ( + momentum * state[0:num_valid] if momentum > 0 else 0 + ) - learning_rate * grad[0:num_valid] + if momentum > 0: + state[0:num_valid] = next_state + model[0:num_valid] = ( + model[0:num_valid] + + next_state + - learning_rate * weight_decay * model[0:num_valid] + ) + return (model, state) + + for i in range(train_iters): + if skip_if_seq[i] > 0: + pass + else: + (x, vt) = train_one_iter( + int(num_valid_seq[i]), random_grad_seq[i], x, vt + ) + return x, vt + + oneflow_res = sgd_by_oneflow().numpy() + of_model = oneflow_res[:, 0:embedding_size] + of_momentum = oneflow_res[:, embedding_size:] + np_model, np_momentum = sgd_by_numpy() + test_case.assertTrue( + np.allclose(of_model.flatten(), np_model.flatten(), rtol=0.001, atol=0.001) + ) + if momentum > 0: + test_case.assertTrue( + np.allclose( + of_momentum.flatten(), np_momentum.flatten(), rtol=0.001, atol=0.001 + ) + ) + + +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@flow.unittest.skip_unless_1n1d() +class TestOptimizers(flow.unittest.TestCase): + def test_one_embedding_sgd(test_case): + arg_dict = OrderedDict() + arg_dict["momentum"] = [0, 0.9] + arg_dict["weight_decay"] 
= [0, 0.1]
+        arg_dict["scale"] = [1, 0.1]
+        arg_dict["learning_rate"] = [1, 0.9]
+        arg_dict["train_iters"] = [10]
+        for arg in GenArgDict(arg_dict):
+            compare_with_numpy_sgd(test_case, **arg)
+
+
+if __name__ == "__main__":
+    unittest.main()
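
For reference (not part of the patch): the per-row update performed by the new
adagrad_embedding_update op, as encoded by the numpy reference in
test_one_embedding_adagrad.py, is sketched below. The helper name and the folding of
down_scale_by_tensor into `scale` are illustrative assumptions; with lr_decay=0 and
weight_decay=0 this reduces to the plain AdaGrad rule of AdagradUpdateFunctor.

    import numpy as np

    def adagrad_row_update(model, accum, grad, lr, train_step,
                           lr_decay=0.0, weight_decay=0.0, epsilon=1e-5, scale=1.0):
        # Minimal sketch of one adagrad row update, mirroring adagrad_by_numpy in the new test.
        grad = grad * scale                        # `scale` here also covers down_scale_by_tensor
        lr = lr / (1.0 + train_step * lr_decay)    # learning-rate decay driven by train_step
        accum = accum + grad * grad                # accumulator (second half of the embedding line)
        model = (model
                 - lr * grad / (np.sqrt(accum) + epsilon)
                 - lr * weight_decay * model)
        return model, accum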