OneEmbedding update kernel add adagrad (#7999)
* column to table

* multi column

* fix typo and use model update util

* fix import

* add doc

* refine

* one_embedding update use model_update kernel util

* add doc

* add adagrad

* refine

* revert update

* refine

* refine

* support for adagrad

* fix tidy

* set state_initializer in pass

* state_initializer in pass

* fix tidy

* skip predict_job_has_optimizer_state prefetch, address review

* add test sgd

* add adam test

* add adagrad test

* refine

* add l1 l2

* address review

* skip if cpu

* Update python/oneflow/test/modules/test_one_embedding_adam.py

* Update python/oneflow/test/modules/test_one_embedding_sgd.py

* Update python/oneflow/test/modules/test_one_embedding_adagrad.py

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
guo-ran and mergify[bot] authored Apr 20, 2022
1 parent ff84524 commit b6becb9
Showing 10 changed files with 878 additions and 11 deletions.
12 changes: 12 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -2086,6 +2086,18 @@
signature: "TensorTuple (Tensor keys, Tensor values=None, Int32 num_tables) => OneEmbeddingUniqueKeyValuePair"
bind_python: True

- name: "one_embedding_sgd_update"
signature: "Tensor (Tensor num_unique_ids, Tensor unique_embeddings, Tensor embedding_grad, Tensor learning_rate, Tensor down_scale_by_tensor, Tensor skip_if, Double scale, Float weight_decay, Float momentum) => OneEmbeddingSgdUpdate"
bind_python: True

- name: "one_embedding_adam_update"
signature: "Tensor (Tensor num_unique_ids, Tensor unique_embeddings, Tensor embedding_grad, Tensor learning_rate, Tensor down_scale_by_tensor, Tensor skip_if, Tensor bias_correction1=None, Tensor bias_correction2=None, Double scale=1.0, Float weight_decay=0.0, Float beta1=0.9, Float beta2=0.999, Float epsilon=0, Bool do_bias_correction=True) => OneEmbeddingAdamUpdate"
bind_python: True

- name: "one_embedding_adagrad_update"
signature: "Tensor (Tensor num_unique_ids, Tensor unique_embeddings, Tensor embedding_grad, Tensor learning_rate, Tensor down_scale_by_tensor, Tensor skip_if, Tensor train_step, Double scale=1.0, Float weight_decay=0.0, Float lr_decay=0.0, Float epsilon=0) => OneEmbeddingAdagradUpdate"
bind_python: True

- name: "einsum"
signature: "Tensor (String equation, TensorTuple operands) => EinSum"
bind_python: True
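The three entries above expose the new update kernels through the Python functional API. Below is a minimal sketch of driving the SGD variant from Python. It assumes the ops land under oneflow._C (the usual destination for bind_python: True entries), that dtypes are as commented, and that a CUDA device is available (the tests in this commit skip CPU); none of those specifics are spelled out in this diff.

    import oneflow as flow

    num_ids, embedding_size = 4, 8
    # Shapes and dtypes below are assumptions, not taken from this diff.
    num_unique_ids = flow.tensor([num_ids], dtype=flow.int32, device="cuda")
    unique_embeddings = flow.randn(num_ids, embedding_size, device="cuda")
    embedding_grad = flow.randn(num_ids, embedding_size, device="cuda")
    learning_rate = flow.tensor([0.1], dtype=flow.float32, device="cuda")
    down_scale_by = flow.tensor([1.0], dtype=flow.float32, device="cuda")
    skip_if = flow.tensor([0], dtype=flow.int64, device="cuda")

    updated = flow._C.one_embedding_sgd_update(
        num_unique_ids, unique_embeddings, embedding_grad, learning_rate,
        down_scale_by, skip_if, scale=1.0, weight_decay=0.0, momentum=0.0)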
161 changes: 161 additions & 0 deletions oneflow/core/functional/impl/nn_functor.cpp
@@ -2535,6 +2535,164 @@ class OneEmbeddingUniqueKeyValuePairFunctor {
std::shared_ptr<OpExpr> op_no_input_value_;
};

class OneEmbeddingSgdUpdateFunctor {
public:
OneEmbeddingSgdUpdateFunctor() {
// This functor is only used by unit tests.
sgd_op_ = CHECK_JUST(one::OpBuilder("sgd_embedding_update")
.Input("num_unique_ids")
.Input("unique_embeddings")
.Input("embedding_grad")
.Input("learning_rate")
.Input("down_scale_by_tensor")
.Input("skip_if")
.Output("updated_unique_embeddings")
.Build());
momentum_op_ = CHECK_JUST(one::OpBuilder("momentum_embedding_update")
.Input("num_unique_ids")
.Input("unique_embeddings")
.Input("embedding_grad")
.Input("learning_rate")
.Input("down_scale_by_tensor")
.Input("skip_if")
.Output("updated_unique_embeddings")
.Build());
}

Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& num_unique_ids,
const std::shared_ptr<one::Tensor>& unique_embeddings,
const std::shared_ptr<one::Tensor>& embedding_grad,
const std::shared_ptr<one::Tensor>& learning_rate,
const std::shared_ptr<one::Tensor>& down_scale_by_tensor,
const std::shared_ptr<one::Tensor>& skip_if, const double scale,
const float weight_decay, const float momentum) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<double>("scale", scale));
JUST(attrs.SetAttr<float>("weight_decay", weight_decay));
if (momentum == 0) {
return OpInterpUtil::Dispatch<Tensor>(*sgd_op_,
{num_unique_ids, unique_embeddings, embedding_grad,
learning_rate, down_scale_by_tensor, skip_if},
attrs);
} else {
JUST(attrs.SetAttr<float>("beta", momentum));
return OpInterpUtil::Dispatch<Tensor>(*momentum_op_,
{num_unique_ids, unique_embeddings, embedding_grad,
learning_rate, down_scale_by_tensor, skip_if},
attrs);
}
}

private:
std::shared_ptr<OpExpr> sgd_op_;
std::shared_ptr<OpExpr> momentum_op_;
};
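A note on the dispatch above: rather than exposing a separate momentum entry point, the functor routes on momentum. Exactly 0 selects the plain sgd_embedding_update op; any other value selects momentum_embedding_update and is forwarded as that op's beta attribute. A sketch of the two call shapes, reusing the hypothetical tensors from the earlier example (the wider state layout for the momentum path is an assumption):

    # momentum == 0 dispatches to sgd_embedding_update.
    sgd_out = flow._C.one_embedding_sgd_update(
        num_unique_ids, unique_embeddings, embedding_grad, learning_rate,
        down_scale_by, skip_if, scale=1.0, weight_decay=0.0, momentum=0.0)

    # momentum != 0 dispatches to momentum_embedding_update with beta = 0.9.
    # unique_embeddings must then also carry the momentum state alongside the
    # weights; 2 * embedding_size per row is an assumed layout.
    unique_embeddings_mom = flow.randn(num_ids, 2 * embedding_size, device="cuda")
    mom_out = flow._C.one_embedding_sgd_update(
        num_unique_ids, unique_embeddings_mom, embedding_grad, learning_rate,
        down_scale_by, skip_if, scale=1.0, weight_decay=0.0, momentum=0.9)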

class OneEmbeddingAdamUpdateFunctor {
public:
OneEmbeddingAdamUpdateFunctor() {
// This functor is only used by unit tests.
no_bias_correction_op_ = CHECK_JUST(one::OpBuilder("adam_embedding_update")
.Input("num_unique_ids")
.Input("unique_embeddings")
.Input("embedding_grad")
.Input("learning_rate")
.Input("down_scale_by_tensor")
.Input("skip_if")
.Output("updated_unique_embeddings")
.Build());
do_bias_correction_op_ = CHECK_JUST(one::OpBuilder("adam_embedding_update")
.Input("num_unique_ids")
.Input("unique_embeddings")
.Input("embedding_grad")
.Input("learning_rate")
.Input("down_scale_by_tensor")
.Input("skip_if")
.Input("bias_correction1")
.Input("bias_correction2")
.Output("updated_unique_embeddings")
.Build());
}

Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& num_unique_ids,
const std::shared_ptr<one::Tensor>& unique_embeddings,
const std::shared_ptr<one::Tensor>& embedding_grad,
const std::shared_ptr<one::Tensor>& learning_rate,
const std::shared_ptr<one::Tensor>& down_scale_by_tensor,
const std::shared_ptr<one::Tensor>& skip_if,
const Optional<one::Tensor>& bias_correction1,
const Optional<one::Tensor>& bias_correction2, const double scale,
const float weight_decay, const float beta1, const float beta2,
const float epsilon, const bool do_bias_correction) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<double>("scale", scale));
JUST(attrs.SetAttr<float>("weight_decay", weight_decay));
JUST(attrs.SetAttr<float>("beta1", beta1));
JUST(attrs.SetAttr<float>("beta2", beta2));
JUST(attrs.SetAttr<float>("epsilon", epsilon));
JUST(attrs.SetAttr<bool>("do_bias_correction", do_bias_correction));
if (do_bias_correction) {
CHECK(bias_correction1);
CHECK(bias_correction2);
return OpInterpUtil::Dispatch<Tensor>(
*do_bias_correction_op_,
{num_unique_ids, unique_embeddings, embedding_grad, learning_rate, down_scale_by_tensor,
skip_if, JUST(bias_correction1), JUST(bias_correction2)},
attrs);
} else {
return OpInterpUtil::Dispatch<Tensor>(*no_bias_correction_op_,
{num_unique_ids, unique_embeddings, embedding_grad,
learning_rate, down_scale_by_tensor, skip_if},
attrs);
}
}

private:
std::shared_ptr<OpExpr> no_bias_correction_op_;
std::shared_ptr<OpExpr> do_bias_correction_op_;
};
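A corresponding sketch for the Adam path with bias correction enabled. bias_correction1 and bias_correction2 are expected as precomputed scalar tensors (1 - beta1^t and 1 - beta2^t), and unique_embeddings must be wide enough to carry the m and v state alongside the weights; 3 * embedding_size per row is an assumed layout, not spelled out in this diff.

    unique_embeddings_adam = flow.randn(num_ids, 3 * embedding_size, device="cuda")
    bias_correction1 = flow.tensor([1 - 0.9 ** 1], dtype=flow.float32, device="cuda")
    bias_correction2 = flow.tensor([1 - 0.999 ** 1], dtype=flow.float32, device="cuda")

    adam_out = flow._C.one_embedding_adam_update(
        num_unique_ids, unique_embeddings_adam, embedding_grad, learning_rate,
        down_scale_by, skip_if, bias_correction1, bias_correction2,
        scale=1.0, weight_decay=0.0, beta1=0.9, beta2=0.999,
        epsilon=1e-8, do_bias_correction=True)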

class OneEmbeddingAdagradUpdateFunctor {
public:
OneEmbeddingAdagradUpdateFunctor() {
// This functor is only used by unit tests.
op_ = CHECK_JUST(one::OpBuilder("adagrad_embedding_update")
.Input("num_unique_ids")
.Input("unique_embeddings")
.Input("embedding_grad")
.Input("learning_rate")
.Input("down_scale_by_tensor")
.Input("skip_if")
.Input("train_step")
.Output("updated_unique_embeddings")
.Build());
}

Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& num_unique_ids,
const std::shared_ptr<one::Tensor>& unique_embeddings,
const std::shared_ptr<one::Tensor>& embedding_grad,
const std::shared_ptr<one::Tensor>& learning_rate,
const std::shared_ptr<one::Tensor>& down_scale_by_tensor,
const std::shared_ptr<one::Tensor>& skip_if,
const std::shared_ptr<one::Tensor>& train_step, const double scale,
const float weight_decay, const float lr_decay,
const float epsilon) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<double>("scale", scale));
JUST(attrs.SetAttr<float>("weight_decay", weight_decay));
JUST(attrs.SetAttr<float>("lr_decay", lr_decay));
JUST(attrs.SetAttr<float>("epsilon", epsilon));
return OpInterpUtil::Dispatch<Tensor>(
*op_,
{num_unique_ids, unique_embeddings, embedding_grad, learning_rate, down_scale_by_tensor,
skip_if, train_step},
attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};
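And the Adagrad variant, which additionally takes the global train_step so the kernel can apply lr_decay; the accumulator state suggests 2 * embedding_size per row (again an assumed layout):

    unique_embeddings_ada = flow.randn(num_ids, 2 * embedding_size, device="cuda")
    train_step = flow.tensor([0], dtype=flow.int64, device="cuda")

    ada_out = flow._C.one_embedding_adagrad_update(
        num_unique_ids, unique_embeddings_ada, embedding_grad, learning_rate,
        down_scale_by, skip_if, train_step,
        scale=1.0, weight_decay=0.0, lr_decay=0.0, epsilon=1e-10)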

class RocAucScoreFunctor {
public:
RocAucScoreFunctor() {
@@ -2628,6 +2786,9 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
"OneEmbeddingEmbeddingGradientShuffle");
m.add_functor<impl::OneEmbeddingLookupFunctor>("OneEmbeddingLookup");
m.add_functor<impl::OneEmbeddingUniqueKeyValuePairFunctor>("OneEmbeddingUniqueKeyValuePair");
m.add_functor<impl::OneEmbeddingSgdUpdateFunctor>("OneEmbeddingSgdUpdate");
m.add_functor<impl::OneEmbeddingAdamUpdateFunctor>("OneEmbeddingAdamUpdate");
m.add_functor<impl::OneEmbeddingAdagradUpdateFunctor>("OneEmbeddingAdagradUpdate");
m.add_functor<impl::RocAucScoreFunctor>("RocAucScore");
};

27 changes: 23 additions & 4 deletions oneflow/core/job_rewriter/replace_embedding_ops_pass.cpp
@@ -434,8 +434,8 @@ void MakeConstantInitializerAttr(const int64_t embedding_size, const int64_t lin

void BuildEmbeddingUpdate(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_builder,
const ParallelConf& parallel_conf, const int64_t embedding_size,
- const int64_t line_size, const std::string& embedding_name,
- const OptimizerConf& optimizer_conf,
+ const int64_t line_size, const float l1, const float l2,
+ const std::string& embedding_name, const OptimizerConf& optimizer_conf,
const user_op::UserOpConfWrapper& embedding_op,
const std::string& num_unique_ids_lbn, const std::string& unique_ids_lbn,
const std::string& unique_values_lbn,
@@ -482,7 +482,12 @@ void BuildEmbeddingUpdate(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder*
.Input("bias_correction2", bias_correction2_lbn);
}
} else if (optimizer_conf.has_adagrad_conf()) {
- UNIMPLEMENTED();
+ const AdagradModelUpdateConf& adagrad_conf = optimizer_conf.adagrad_conf();
+ state_constant_init_values.push_back(adagrad_conf.initial_accumulator_value());
+ embedding_update_op_builder.OpTypeName("adagrad_embedding_update")
+     .Input("train_step", train_conf.train_step_lbn())
+     .Attr<float>("lr_decay", adagrad_conf.lr_decay())
+     .Attr<float>("epsilon", adagrad_conf.epsilon());
} else {
UNIMPLEMENTED();
}
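For reference, the lr_decay and epsilon attributes wired in above feed the usual Adagrad learning-rate schedule. Assuming OneFlow follows the standard formulation (the kernel body is not part of this diff), with t the train step:

    \eta_t = \frac{\eta}{1 + t \cdot \text{lr\_decay}}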
@@ -493,6 +498,8 @@ void BuildEmbeddingUpdate(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder*
.Input("embedding_grad", embedding_grad_lbn)
.Input("learning_rate", learning_rate_lbn)
.Attr<float>("weight_decay", optimizer_conf.weight_decay_conf().weight_decay_rate())
.Attr<float>("l1", l1)
.Attr<float>("l2", l2)
.Output("updated_unique_embeddings");
double scale = GetLossInstanceNumScaleFactor(op_graph, job_builder);
if (train_conf.has_dynamic_loss_scale_policy()) {
@@ -690,12 +697,24 @@ Maybe<void> ReplaceEmbeddingOps::Apply(const OpGraph& op_graph, JobBuilder* job_
if (found_embedding_optimizer == true) { break; }
}
CHECK_EQ(found_embedding_optimizer, true);

const OpNode* shadow_node = op_graph.OpNode4OpName(shadow_op_name);
const VariableOpConf& shadow_variable_conf = shadow_node->op().op_conf().variable_conf();
float l1 = 0.0;
float l2 = 0.0;
if (shadow_variable_conf.has_regularizer()) {
const RegularizerConf& regularizer_conf = shadow_variable_conf.regularizer();
if (regularizer_conf.has_l1_l2_conf()) {
l1 = regularizer_conf.l1_l2_conf().l1();
l2 = regularizer_conf.l1_l2_conf().l2();
}
}
const std::string& learning_rate_lbn =
AddScheduleOp(op_graph, job_builder, embedding_optimizer_conf,
"System-Train-LearningRate-Scheduler_" + NewUniqueId());

BuildEmbeddingUpdate(ctx, op_graph, job_builder, op_node->parallel_desc().parallel_conf(),
- embedding_size, options.LineSize(), options.Name(),
+ embedding_size, options.LineSize(), l1, l2, options.Name(),
embedding_optimizer_conf, embedding_op, num_unique_ids_lbn,
unique_ids_lbn, unique_values_lbn, embedding_grad_lbn,
learning_rate_lbn, &state_initializer);
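The l1/l2 values read off the shadow variable's regularizer above are handed to BuildEmbeddingUpdate and become the update ops' new l1/l2 attributes. In OneFlow's model-update utilities the regularized gradient is conventionally formed as below (a hedged reading of CastScaleRegularizeGradientFunctor, whose body is not shown in this diff):

    g' = \text{scale} \cdot g + l_1 \cdot \operatorname{sign}(w) + l_2 \cdot w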
34 changes: 34 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -9548,6 +9548,8 @@ def OneFlow_SgdEmbeddingUpdateOp : OneFlow_BaseOp<"sgd_embedding_update", [AttrS
);
let attrs = (ins
DefaultValuedAttr<F64Attr, "1.">:$scale,
DefaultValuedAttr<F32Attr, "0.">:$l1,
DefaultValuedAttr<F32Attr, "0.">:$l2,
DefaultValuedAttr<F32Attr, "0.">:$weight_decay
);
let same_output_regst_num = 1;
@@ -9571,6 +9573,8 @@ def OneFlow_MomentumEmbeddingUpdateOp : OneFlow_BaseOp<"momentum_embedding_updat
);
let attrs = (ins
DefaultValuedAttr<F64Attr, "1.">:$scale,
DefaultValuedAttr<F32Attr, "0.">:$l1,
DefaultValuedAttr<F32Attr, "0.">:$l2,
DefaultValuedAttr<F32Attr, "0.">:$weight_decay,
DefaultValuedAttr<F32Attr, "0.9">:$beta
);
@@ -9597,6 +9601,8 @@ def OneFlow_AdamEmbeddingUpdateOp : OneFlow_BaseOp<"adam_embedding_update", [Att
);
let attrs = (ins
DefaultValuedAttr<F64Attr, "1.">:$scale,
DefaultValuedAttr<F32Attr, "0.">:$l1,
DefaultValuedAttr<F32Attr, "0.">:$l2,
DefaultValuedAttr<F32Attr, "0.">:$weight_decay,
DefaultValuedAttr<F32Attr, "0.9">:$beta1,
DefaultValuedAttr<F32Attr, "0.999">:$beta2,
@@ -9610,6 +9616,34 @@ def OneFlow_AdamEmbeddingUpdateOp : OneFlow_BaseOp<"adam_embedding_update", [Att
let has_data_type_infer_fn = 1;
}

def OneFlow_AdagradEmbeddingUpdateOp : OneFlow_BaseOp<"adagrad_embedding_update", [AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$num_unique_ids,
OneFlow_Tensor:$unique_embeddings,
OneFlow_Tensor:$embedding_grad,
OneFlow_Tensor:$learning_rate,
OneFlow_Tensor:$train_step,
Optional<OneFlow_Tensor>:$down_scale_by_tensor,
Optional<OneFlow_Tensor>:$skip_if
);
let output = (outs
OneFlow_Tensor:$updated_unique_embeddings
);
let attrs = (ins
DefaultValuedAttr<F64Attr, "1.">:$scale,
DefaultValuedAttr<F32Attr, "0.">:$l1,
DefaultValuedAttr<F32Attr, "0.">:$l2,
DefaultValuedAttr<F32Attr, "0.">:$weight_decay,
DefaultValuedAttr<F32Attr, "0.">:$lr_decay,
DefaultValuedAttr<F32Attr, "0.">:$epsilon
);
let same_output_regst_num = 1;
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_EmbeddingPutOp : OneFlow_BaseOp<"embedding_put", [DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$num_unique_ids,
3 changes: 2 additions & 1 deletion oneflow/user/kernels/model_update_kernel_util.h
@@ -113,7 +113,8 @@ struct AdagradUpdateFunctor {
CastScaleRegularizeGradientFunctor<T, G>()(*model_diff, model_val, scale, l1, l2);
const T next_sum = *sum + model_diff_t * model_diff_t;
*sum = next_sum;
- *model = model_val - learning_rate / (sqrt(next_sum) + epsilon) * model_diff_t;
+ *model = model_val - learning_rate / (sqrt(next_sum) + epsilon) * model_diff_t
+          - learning_rate * weight_decay * model_val;
}
};
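Written out per element, the functor now computes the following, with g the regularized gradient, S the accumulated squared gradient, and \lambda the weight_decay; the last term is the newly added decoupled weight decay:

    S_t = S_{t-1} + g_t^2, \qquad
    w_t = w_{t-1} - \eta \, \frac{g_t}{\sqrt{S_t} + \epsilon} - \eta \lambda \, w_{t-1}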

(Diffs for the remaining five files, including the new test_one_embedding_sgd.py, test_one_embedding_adam.py, and test_one_embedding_adagrad.py test modules, are not shown.)
