[OneEmbedding] Dev ftrl optimizer (#8008)
* column to table
* multi column
* fix typo and use model update util
* fix import
* add doc
* refine
* one_embedding update use model_update kernel util
* add doc
* add adagrad
* refine
* revert update
* add ftrl optimizer
* refine
* Add python impl
* refine
* Add cpp dispatch ftrl
* support for adagrad
* fix param group initial value bug
* support eager ftrl optimizer
* use model_update_util and weight decay
* support graph ftrl optimizer
* fix tidy
* support ftrl embedding update
* set state_initializer in pass
* state_initializer in pass
* fix tidy
* skip predict_job_has_optimizer_state prefetch, address review
* add test sgd
* add adam test
* add adagrad test
* refine
* add l1 l2
* add embedding update ftrl
* support more optimizer to fuse update
* Add docs
* fix initial value for ftrl update optimizer
* Fix guoran comment
* add new line and fix format
* fix juncheng comment
* fix merge conflict
* Fix clang analysis
* fix adagrad
* fix format
* only test weight_decay = 0.0
* fix cpu module test

Co-authored-by: guo-ran <360112263@qq.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Juncheng <liujuncheng1022@gmail.com>
Co-authored-by: Shenghang Tsai <jackalcooper@gmail.com>
5 people authored Apr 27, 2022
1 parent cb63722 commit 8040eb9
Showing 25 changed files with 1,574 additions and 7 deletions.
2 changes: 2 additions & 0 deletions docs/source/one_embedding.rst
@@ -20,5 +20,7 @@ OneFlow one_embedding operations.
.. autofunction:: oneflow.one_embedding.make_normal_initializer
.. autofunction:: oneflow.one_embedding.make_table_options
.. autofunction:: oneflow.one_embedding.make_table
.. automodule:: oneflow.one_embedding
    :members: Ftrl
.. autofunction:: oneflow.one_embedding.make_persistent_table_reader
.. autofunction:: oneflow.one_embedding.make_persistent_table_writer
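
The Ftrl class documented above is the eager-mode entry point added by this PR. Below is a minimal usage sketch, not taken from the diff: the constructor arguments are assumed to mirror the FtrlModelUpdateConf fields introduced further down (initial_accumulator_value, lr_power, lambda1, lambda2, beta), and the optimizer is assumed to follow the usual flow.optim step/zero_grad interface.

# Hedged sketch of eager-mode FTRL usage; argument names and defaults are assumptions.
import oneflow as flow
import oneflow.nn as nn

model = nn.Linear(16, 1)
optimizer = flow.one_embedding.Ftrl(
    model.parameters(),
    lr=0.1,
    weight_decay=0.0,               # the tests in this PR only exercise weight_decay = 0.0
    lr_power=0.5,                   # exponent applied to the accumulated squared gradients
    initial_accumulator_value=0.1,  # starting value of the "accumulate" state
    lambda1=0.0,                    # L1-style shrinkage on the z state
    lambda2=0.0,                    # L2-style term in the denominator
    beta=0.0,
)

x = flow.randn(4, 16)
loss = model(x).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()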
17 changes: 17 additions & 0 deletions oneflow/api/python/functional/dispatch_stateful_ops.cpp
@@ -508,6 +508,23 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
return Maybe<void>::Ok();
});
m.add_functor("DispatchFtrlUpdate",
[](const std::shared_ptr<OpExpr>& op, const TensorTuple& inputs,
float learning_rate, double scale, float l1, float l2, float lr_power,
float lambda1, float lambda2, float beta, float weight_decay) -> Maybe<void> {
MutableAttrMap attrs;
JUST(attrs.SetAttr("learning_rate_val", learning_rate));
JUST(attrs.SetAttr("scale", scale));
JUST(attrs.SetAttr("l1", l1));
JUST(attrs.SetAttr("l2", l2));
JUST(attrs.SetAttr("lr_power", lr_power));
JUST(attrs.SetAttr("lambda1", lambda1));
JUST(attrs.SetAttr("lambda2", lambda2));
JUST(attrs.SetAttr("beta", beta));
JUST(attrs.SetAttr("weight_decay", weight_decay));
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op, inputs, attrs));
return Maybe<void>::Ok();
});
m.add_functor("DispatchEagerNcclAllReduce",
[](const std::shared_ptr<OpExpr>& op, const std::shared_ptr<Tensor>& input,
const std::string& parallel_conf, bool async_launch) -> Maybe<Tensor> {
4 changes: 4 additions & 0 deletions oneflow/api/python/functional/dispatch_stateful_ops.yaml
@@ -148,6 +148,10 @@
signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Float bias_correction1=1.0, Float bias_correction2=1.0, Double scale=1.0, Float l1=0, Float l2=0, Float beta1=0.9, Float beta2=0.999, Float epsilon=1e-8, Float weight_decay=0, Bool do_bias_correction=True) => DispatchLambUpdate"
bind_python: True

- name: "dispatch_ftrl_update"
signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float lr_power, Float lambda1, Float lambda2, Float beta, Float weight_decay=0) => DispatchFtrlUpdate"
bind_python: True

- name: "dispatch_eager_nccl_all_reduce"
signature: "Tensor (OpExpr op, Tensor input, String parallel_conf, Bool async_launch=False) => DispatchEagerNcclAllReduce"
bind_python: True
4 changes: 4 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -2110,6 +2110,10 @@
signature: "Tensor (Tensor num_unique_ids, Tensor unique_embeddings, Tensor embedding_grad, Tensor learning_rate, Tensor down_scale_by_tensor, Tensor skip_if, Tensor train_step, Double scale=1.0, Float weight_decay=0.0, Float lr_decay=0.0, Float epsilon=0) => OneEmbeddingAdagradUpdate"
bind_python: True

- name: "one_embedding_ftrl_update"
signature: "Tensor (Tensor num_unique_ids, Tensor unique_embeddings, Tensor embedding_grad, Tensor learning_rate, Tensor down_scale_by_tensor, Tensor skip_if, Double scale, Float weight_decay, Float lr_power, Float lambda1, Float lambda2, Float beta) => OneEmbeddingFtrlUpdate"
bind_python: True

- name: "einsum"
signature: "Tensor (String equation, TensorTuple operands) => EinSum"
bind_python: True
43 changes: 42 additions & 1 deletion oneflow/core/functional/impl/nn_functor.cpp
@@ -2700,6 +2700,46 @@ class OneEmbeddingAdagradUpdateFunctor {
std::shared_ptr<OpExpr> op_;
};

class OneEmbeddingFtrlUpdateFunctor {
public:
OneEmbeddingFtrlUpdateFunctor() {
// This functor is just for unittest
op_ = CHECK_JUST(one::OpBuilder("ftrl_embedding_update")
.Input("num_unique_ids")
.Input("unique_embeddings")
.Input("embedding_grad")
.Input("learning_rate")
.Input("down_scale_by_tensor")
.Input("skip_if")
.Output("updated_unique_embeddings")
.Build());
}

Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& num_unique_ids,
const std::shared_ptr<one::Tensor>& unique_embeddings,
const std::shared_ptr<one::Tensor>& embedding_grad,
const std::shared_ptr<one::Tensor>& learning_rate,
const std::shared_ptr<one::Tensor>& down_scale_by_tensor,
const std::shared_ptr<one::Tensor>& skip_if, const double scale,
const float weight_decay, const float lr_power, const float lambda1,
const float lambda2, const float beta) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<double>("scale", scale));
JUST(attrs.SetAttr<float>("weight_decay", weight_decay));
JUST(attrs.SetAttr<float>("lr_power", lr_power));
JUST(attrs.SetAttr<float>("lambda1", lambda1));
JUST(attrs.SetAttr<float>("lambda2", lambda2));
JUST(attrs.SetAttr<float>("beta", beta));
return OpInterpUtil::Dispatch<Tensor>(*op_,
{num_unique_ids, unique_embeddings, embedding_grad,
learning_rate, down_scale_by_tensor, skip_if},
attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

class RocAucScoreFunctor {
public:
RocAucScoreFunctor() {
@@ -2796,8 +2836,9 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::OneEmbeddingSgdUpdateFunctor>("OneEmbeddingSgdUpdate");
m.add_functor<impl::OneEmbeddingAdamUpdateFunctor>("OneEmbeddingAdamUpdate");
m.add_functor<impl::OneEmbeddingAdagradUpdateFunctor>("OneEmbeddingAdagradUpdate");
m.add_functor<impl::OneEmbeddingFtrlUpdateFunctor>("OneEmbeddingFtrlUpdate");
m.add_functor<impl::RocAucScoreFunctor>("RocAucScore");
};
}

} // namespace functional
} // namespace one
10 changes: 10 additions & 0 deletions oneflow/core/job/job_conf.proto
@@ -59,6 +59,14 @@ message AdagradModelUpdateConf {
required float epsilon = 3 [default = 1e-10];
}

message FtrlModelUpdateConf {
required float initial_accumulator_value = 1 [default = 0.1];
required float lr_power = 2 [default = 0.5];
optional float lambda1 = 3 [default = 0.0];
optional float lambda2 = 4 [default = 0.0];
optional float beta = 5 [default = 0.0];
}

message ClipByGlobalNormConf {
optional float max_norm = 1 [default = 1.0];
optional double norm_type = 2 [default = 2.0];
@@ -98,6 +106,7 @@ message OptimizerConf {
LazyAdamModelUpdateConf lazy_adam_conf = 1005;
LambModelUpdateConf lamb_conf = 1006;
AdagradModelUpdateConf adagrad_conf = 1007;
FtrlModelUpdateConf ftrl_conf = 1008;
}
}

@@ -114,6 +123,7 @@ message NormalModelUpdateOpUserConf {
LazyAdamModelUpdateConf lazy_adam_conf = 1005;
LambModelUpdateConf lamb_conf = 1006;
AdagradModelUpdateConf adagrad_conf = 1007;
FtrlModelUpdateConf ftrl_conf = 1008;
}
}

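For reference, the fields above parameterize the standard FTRL-Proximal per-coordinate update. The kernel itself is not part of this diff, so the following is a hedged sketch of the usual rule rather than the exact OneFlow implementation. Write n for the "accumulate" state (initialized to initial_accumulator_value), z for the second state (initialized to 0), w for the model, eta for the learning rate, p for lr_power, and g_t for the already scaled and regularized gradient:

\[
n_{t+1} = n_t + g_t^2, \qquad
\sigma_t = \frac{n_{t+1}^{\,p} - n_t^{\,p}}{\eta}, \qquad
z_{t+1} = z_t + g_t - \sigma_t\, w_t
\]
\[
w_{t+1} =
\begin{cases}
0, & |z_{t+1}| \le \lambda_1,\\[4pt]
-\dfrac{z_{t+1} - \operatorname{sign}(z_{t+1})\,\lambda_1}{(\beta + n_{t+1}^{\,p})/\eta + \lambda_2}, & \text{otherwise.}
\end{cases}
\]

With p = 1/2 this is the classic rule in which n^p = sqrt(n); lambda1 yields exact zeros for coordinates with |z| <= lambda1, while lambda2 and beta strengthen the denominator.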
96 changes: 96 additions & 0 deletions oneflow/core/job_rewriter/ftrl_optm.cpp
@@ -0,0 +1,96 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/user_op_conf.h"
#include "oneflow/core/job/initializer_conf.pb.h"
#include "oneflow/core/job/job_builder.h"
#include "oneflow/core/job/job_conf.pb.h"
#include "oneflow/core/job_rewriter/job_pass.h"
#include "oneflow/core/job_rewriter/optimizer.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/operator/variable_op.h"

namespace oneflow {

namespace {

std::string GenVariableOutputLbn(const OperatorConf& op_conf) {
CHECK(op_conf.has_variable_conf());
return GenLogicalBlobName(op_conf.name(), op_conf.variable_conf().out());
}

OperatorConf GenerateFtrlHelperVariableConf(const VariableOp& op, const std::string& name,
const float initial_value) {
OperatorConf helper_variable_op(op.op_conf());
helper_variable_op.set_name(op.op_name() + "-" + name);
helper_variable_op.mutable_variable_conf()->set_out("out");
InitializerConf constant_initializer;
constant_initializer.mutable_constant_conf()->set_value(initial_value);
*(helper_variable_op.mutable_variable_conf()->mutable_initializer()) = constant_initializer;
helper_variable_op.set_scope_symbol_id(op.op_conf().scope_symbol_id());
return helper_variable_op;
}

void GenerateFtrlOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node,
const std::string& model_diff_lbn,
const OptimizerConf& optimizer_conf, JobBuilder* job_builder) {
const VariableOp* var_op = dynamic_cast<const VariableOp*>(&var_op_node.op());
CHECK_NOTNULL(var_op);

user_op::UserOpConfWrapperBuilder ftrl_update_op_builder(var_op->op_name() + "_optimizer");
float lr_power = 0.0;
float initial_accumulator_value = 0.0;
float lambda1 = 0.0;
float lambda2 = 0.0;
float beta = 0.0;

const FtrlModelUpdateConf& ftrl_conf = optimizer_conf.ftrl_conf();
lr_power = ftrl_conf.lr_power();
initial_accumulator_value = ftrl_conf.initial_accumulator_value();
lambda1 = ftrl_conf.lambda1();
lambda2 = ftrl_conf.lambda2();
beta = ftrl_conf.beta();

const std::string& learning_rate_lbn = optimizer_conf.learning_rate_lbn();
OperatorConf accumulator_var(
GenerateFtrlHelperVariableConf(*var_op, "accumulate", initial_accumulator_value));
OperatorConf z_var(GenerateFtrlHelperVariableConf(*var_op, "z", 0.0));
job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {accumulator_var, z_var});

ftrl_update_op_builder.OpTypeName("ftrl_update")
.Input("model", GenLogicalBlobName(var_op->BnInOp2Lbi("out")))
.Input("model_diff", model_diff_lbn)
.Input("learning_rate", learning_rate_lbn)
.Input("accumulate", GenVariableOutputLbn(accumulator_var))
.Input("z", GenVariableOutputLbn(z_var))
.Attr<float>("lr_power", lr_power)
.Attr<float>("lambda1", lambda1)
.Attr<float>("lambda2", lambda2)
.Attr<float>("beta", beta)
.Attr<float>("weight_decay", GetOptimizerWeightDecayRate(optimizer_conf, *var_op))
.ScopeSymbolId(var_op->op_conf().scope_symbol_id());

SetDynamicLossScaleSkipIf(ctx, &ftrl_update_op_builder);
const auto ftrl_update_op = ftrl_update_op_builder.Build();
job_builder->AddOps(var_op_node.parallel_desc().parallel_conf(), {ftrl_update_op.op_conf()});
}

} // namespace

REGISTER_OPTIMIZER(OptimizerConf::kFtrlConf, &GenerateFtrlOptimizerOpConf);

} // namespace oneflow
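
The new pass above is what nn.Graph training reaches when a variable's OptimizerConf carries ftrl_conf: it creates the two helper variables ("accumulate", initialized to initial_accumulator_value, and "z", initialized to 0) and wires them, together with the learning rate and weight decay, into a ftrl_update op. A hedged Python-side sketch of driving this path follows; it assumes flow.one_embedding.Ftrl plugs into nn.Graph.add_optimizer the same way the existing optimizers do.

# Hedged sketch: graph-mode training that would trigger GenerateFtrlOptimizerOpConf.
import oneflow as flow
import oneflow.nn as nn

model = nn.Linear(16, 1)
ftrl = flow.one_embedding.Ftrl(model.parameters(), lr=0.1)  # hyperparameters as assumed above

class TrainGraph(nn.Graph):
    def __init__(self):
        super().__init__()
        self.model = model
        self.add_optimizer(ftrl)  # emits OptimizerConf.ftrl_conf for the job rewriter

    def build(self, x):
        loss = self.model(x).sum()
        loss.backward()
        return loss

graph = TrainGraph()
loss = graph(flow.randn(4, 16))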
31 changes: 30 additions & 1 deletion oneflow/core/job_rewriter/fuse_update_ops_pass.cpp
@@ -69,7 +69,10 @@ Maybe<void> FuseUpdateOpsPass::Apply(const OpGraph& op_graph, JobBuilder* job_bu
&& user_op_conf.op_type_name() != "momentum_update"
&& user_op_conf.op_type_name() != "adam_update"
&& user_op_conf.op_type_name() != "rmsprop_update"
&& user_op_conf.op_type_name() != "lars_update") {
&& user_op_conf.op_type_name() != "lars_update"
&& user_op_conf.op_type_name() != "adagrad_update"
&& user_op_conf.op_type_name() != "lamb_update"
&& user_op_conf.op_type_name() != "ftrl_update") {
return;
}
if (user_op_conf.attr<double>("scale") != 1.0 || user_op_conf.attr<float>("l1") != 0.0f
@@ -146,6 +149,8 @@ Maybe<void> FuseUpdateOpsPass::Apply(const OpGraph& op_graph, JobBuilder* job_bu

if (!fused) { return; }

const TrainConf& train_conf = job_builder->job().job_conf().train_conf();

user_op::UserOpConfWrapperBuilder fused_op_builder(user_op_conf.op_name());
fused_op_builder.OpTypeName(user_op_conf.op_type_name())
.Input("model", user_op_conf.input("model", 0))
@@ -195,6 +200,30 @@ Maybe<void> FuseUpdateOpsPass::Apply(const OpGraph& op_graph, JobBuilder* job_bu
.Attr<float>("momentum_beta", user_op_conf.attr<float>("momentum_beta"))
.Attr<float>("epsilon", user_op_conf.attr<float>("epsilon"))
.Attr<float>("lars_coefficient", user_op_conf.attr<float>("lars_coefficient"));
} else if (user_op_conf.op_type_name() == "adagrad_update") {
fused_op_builder.Input("sum", user_op_conf.input("sum", 0))
.Input("train_step", train_conf.train_step_lbn())
.Attr<float>("lr_decay", user_op_conf.attr<float>("lr_decay"))
.Attr<float>("epsilon", user_op_conf.attr<float>("epsilon"));
} else if (user_op_conf.op_type_name() == "lamb_update") {
fused_op_builder.Input("m", user_op_conf.input("m", 0))
.Input("v", user_op_conf.input("v", 0))
.Attr<float>("beta1", user_op_conf.attr<float>("beta1"))
.Attr<float>("beta2", user_op_conf.attr<float>("beta2"))
.Attr<float>("epsilon", user_op_conf.attr<float>("epsilon"));
if (user_op_conf.has_input("bias_correction1", 0)) {
fused_op_builder.Input("bias_correction1", user_op_conf.input("bias_correction1", 0));
}
if (user_op_conf.has_input("bias_correction2", 0)) {
fused_op_builder.Input("bias_correction2", user_op_conf.input("bias_correction2", 0));
}
} else if (user_op_conf.op_type_name() == "ftrl_update") {
fused_op_builder.Input("accumulate", user_op_conf.input("accumulate", 0))
.Input("z", user_op_conf.input("z", 0))
.Attr<float>("lr_power", user_op_conf.attr<float>("lr_power"))
.Attr<float>("lambda1", user_op_conf.attr<float>("lambda1"))
.Attr<float>("lambda2", user_op_conf.attr<float>("lambda2"))
.Attr<float>("beta", user_op_conf.attr<float>("beta"));
} else {
UNIMPLEMENTED();
}
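The ftrl_update branch above lets the fusion pass fold the gradient preprocessing (the scale factor from dynamic loss scaling and the l1/l2 regularization terms) into the update op instead of leaving separate multiply and regularize ops in the graph. As a reference for what the fused op then computes per element, here is a self-contained numpy sketch; the real kernel is not shown in this diff, so the preprocessing order and formulas are assumptions based on the standard model-update convention, and weight_decay is left out because the tests in this PR only exercise weight_decay = 0.0.

# Hedged numpy reference for one fused FTRL step.
import numpy as np

def fused_ftrl_step(model, accumulate, z, grad, lr,
                    scale=1.0, l1=0.0, l2=0.0,
                    lr_power=0.5, lambda1=0.0, lambda2=0.0, beta=0.0):
    # Preprocessing folded into the op by the fusion pass:
    # down-scale the raw gradient and add l1/l2 regularization of the model.
    g = grad * scale + l1 * np.sign(model) + l2 * model

    new_accumulate = accumulate + g * g
    sigma = (new_accumulate ** lr_power - accumulate ** lr_power) / lr
    new_z = z + g - sigma * model

    new_model = np.where(
        np.abs(new_z) <= lambda1,
        np.zeros_like(model),
        -(new_z - np.sign(new_z) * lambda1)
        / ((beta + new_accumulate ** lr_power) / lr + lambda2),
    )
    return new_model, new_accumulate, new_z

w = np.zeros(4, dtype=np.float32)
n = np.full(4, 0.1, dtype=np.float32)  # initial_accumulator_value
z = np.zeros(4, dtype=np.float32)
g = np.array([0.3, -0.1, 0.05, 0.2], dtype=np.float32)
w, n, z = fused_ftrl_step(w, n, z, g, lr=0.1)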
10 changes: 10 additions & 0 deletions oneflow/core/job_rewriter/replace_embedding_ops_pass.cpp
@@ -474,6 +474,16 @@ void BuildEmbeddingUpdate(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder*
.Input("train_step", train_conf.train_step_lbn())
.Attr<float>("lr_decay", adagrad_conf.lr_decay())
.Attr<float>("epsilon", adagrad_conf.epsilon());
} else if (optimizer_conf.has_ftrl_conf()) {
const FtrlModelUpdateConf& ftrl_conf = optimizer_conf.ftrl_conf();
state_constant_init_values.push_back(ftrl_conf.initial_accumulator_value());
// For `z`, its init value is 0.0.
state_constant_init_values.push_back(0.0);
embedding_update_op_builder.OpTypeName("ftrl_embedding_update")
.Attr<float>("lr_power", ftrl_conf.lr_power())
.Attr<float>("lambda1", ftrl_conf.lambda1())
.Attr<float>("lambda2", ftrl_conf.lambda2())
.Attr<float>("beta", ftrl_conf.beta());
} else {
UNIMPLEMENTED();
}