one_embedding amp default fp16 (#8174)
* fix data type mismatch error

* add embedding_lookup_placeholder to the AMP white list

Co-authored-by: Juncheng <liujuncheng1022@gmail.com>
Co-authored-by: ZZK <42901638+MARD1NO@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
4 people authored May 13, 2022
1 parent 5c0ff0d commit 944ad62
Showing 4 changed files with 32 additions and 29 deletions.
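Read together, the four diffs below make fp16 the default for one_embedding under AMP: embedding_lookup_placeholder joins the AMP white list, its output dtype is inferred from the shadow input instead of a fixed dtype attribute, the gradient-shuffle path only casts back to float when ONEFLOW_ONE_EMBEDDING_GRADIENT_SHUFFLE_USE_FP16 is explicitly disabled, and the add_n dtype check now names the offending op. As a standalone illustration (not OneFlow code; every name in it is made up for the sketch), the following shows why letting the output dtype follow the shadow input keeps add_n's inputs consistent under AMP:

#include <cassert>
#include <iostream>

enum class DataType { kFloat, kFloat16 };

// Old rule (sketch): embedding output dtype taken from a fixed "dtype" attribute.
DataType InferFromAttr(DataType attr_dtype) { return attr_dtype; }

// New rule (sketch): embedding output dtype follows the shadow input, which AMP
// has already cast to fp16 when mixed precision is enabled.
DataType InferFromShadow(DataType shadow_dtype) { return shadow_dtype; }

int main() {
  const DataType shadow_after_amp = DataType::kFloat16;
  const DataType other_addn_input = DataType::kFloat16;  // the other add_n input under AMP
  // add_n requires all inputs to share one dtype; the old rule could trip that check.
  assert(InferFromAttr(DataType::kFloat) != other_addn_input);
  assert(InferFromShadow(shadow_after_amp) == other_addn_input);
  std::cout << "add_n inputs agree when the output dtype follows the shadow input\n";
  return 0;
}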
3 changes: 2 additions & 1 deletion oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp
@@ -27,7 +27,8 @@ const AMPList& AutoMixedPrecisionLists::WhiteList() {
       "prelu",
       "tf_prelu",
       "cublas_fused_mlp",
-      "fused_dot_feature_interaction"};
+      "fused_dot_feature_interaction",
+      "embedding_lookup_placeholder"};
   return white_list;
 }

54 changes: 28 additions & 26 deletions oneflow/core/job_rewriter/replace_embedding_ops_pass.cpp
@@ -206,16 +206,17 @@ void BuildEmbeddingGradientShuffle(
     const bool has_clip_grad, std::string* cur_rank_unique_embedding_grad_lbn) {
   std::string update_embedding_grad_lbn = update_embedding_grad;
   if (ctx->job_desc().enable_auto_mixed_precision()
-      && ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_GRADIENT_SHUFFLE_USE_FP16", true)) {
-    LogicalBlobId embedding_grad_lbi = GenLogicalBlobId(update_embedding_grad_lbn);
-    const OpNode* cast_node = op_graph.OpNode4OpName(embedding_grad_lbi.op_name());
-    if (cast_node->op().op_conf().has_user_conf()) {
-      const user_op::UserOpConfWrapper cast_op_conf(cast_node->op().op_conf());
-      if (cast_op_conf.op_type_name() == "cast") {
-        update_embedding_grad_lbn = cast_op_conf.input("in", 0);
-        job_builder->DelOps({cast_op_conf.op_name()});
-      }
-    }
+      && !ParseBooleanFromEnv("ONEFLOW_ONE_EMBEDDING_GRADIENT_SHUFFLE_USE_FP16", true)) {
+    auto cast_op =
+        user_op::UserOpConfWrapperBuilder(embedding_op.op_name() + "_before_grad_shuffle_cast_h2f")
+            .Op("cast")
+            .Input("in", update_embedding_grad_lbn)
+            .Output("out")
+            .Attr<DataType>("dtype", DataType::kFloat)
+            .ScopeSymbolId(embedding_scope_symbol_id)
+            .Build();
+    job_builder->AddOps(embedding_parallel_conf, {cast_op.op_conf()});
+    update_embedding_grad_lbn = cast_op.output("out", 0);
   }
   if (use_system_gather) {
     const int64_t num_segments =
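One reading of the hunk above: previously, when fp16 gradient shuffle was enabled (the default), the pass looked for an existing cast feeding the gradient and deleted it; now the embedding gradient already reaches this point in fp16 under AMP, and a cast back to float is only inserted when ONEFLOW_ONE_EMBEDDING_GRADIENT_SHUFFLE_USE_FP16 is explicitly turned off. A minimal standalone sketch of that env-flag gate, assuming a ParseBooleanFromEnv-like helper (the real parsing rules may differ):

#include <cstdlib>
#include <iostream>
#include <string>

// Hypothetical stand-in for ParseBooleanFromEnv(name, default_value); the real
// OneFlow parsing rules may differ. An unset variable means "use the default".
bool ParseBooleanFromEnvSketch(const char* name, bool default_value) {
  const char* value = std::getenv(name);
  if (value == nullptr) { return default_value; }
  const std::string v(value);
  return !(v == "0" || v == "false" || v == "False");
}

int main() {
  const bool use_fp16_shuffle =
      ParseBooleanFromEnvSketch("ONEFLOW_ONE_EMBEDDING_GRADIENT_SHUFFLE_USE_FP16", /*default_value=*/true);
  if (use_fp16_shuffle) {
    std::cout << "gradient shuffle stays in fp16 (the default under AMP)\n";
  } else {
    std::cout << "would insert a cast-to-float op before the gradient shuffle\n";
  }
  return 0;
}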
@@ -914,9 +915,22 @@ Maybe<void> ReplaceEmbeddingOps::Apply(const OpGraph& op_graph, JobBuilder* job_
     if (!op_conf.has_user_conf()) { return; }
     if (!(op_conf.user_conf().op_type_name() == "embedding_lookup_placeholder")) { return; }
     const user_op::UserOpConfWrapper embedding_op(op_node->op().op_conf());
-    const LogicalBlobId& lbi = GenLogicalBlobId(embedding_op.input("shadow", 0));
-    const std::string& shadow_op_name = lbi.op_name();
-    // assert all embeddings same placement
+    const OpNode* shadow_producer =
+        op_graph.OpNode4OpName(GenLogicalBlobId(embedding_op.input("shadow", 0)).op_name());
+    std::string shadow_op_name;
+    if (shadow_producer->op().op_conf().has_variable_conf()) {
+      shadow_op_name = shadow_producer->op().op_name();
+    } else if (shadow_producer->op().op_conf().has_user_conf()
+               && shadow_producer->op().op_conf().user_conf().op_type_name() == "cast") {
+      const user_op::UserOpConfWrapper shadow_cast_op(shadow_producer->op().op_conf());
+      const OpNode* cast_producer =
+          op_graph.OpNode4OpName(GenLogicalBlobId(shadow_cast_op.input("in", 0)).op_name());
+      CHECK(cast_producer->op().op_conf().has_variable_conf()) << cast_producer->op().op_name();
+      shadow_op_name = cast_producer->op().op_name();
+    } else {
+      UNIMPLEMENTED() << "shadow must be variable or variable and cast";
+    }
+    // assume all embeddings have same placement
     embedding_scope_symbol_id = embedding_op.op_conf().scope_symbol_id();
     embedding_parallel_conf = op_node->parallel_desc().parallel_conf();
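The replacement above no longer assumes the shadow input is produced directly by a variable op: with embedding_lookup_placeholder on the AMP white list, a cast may sit between the variable and the lookup, so the pass accepts a variable, or a variable followed by a single cast, and fails loudly otherwise. A standalone sketch of that producer walk (not OneFlow code; the Node type and all names are hypothetical):

#include <iostream>
#include <stdexcept>
#include <string>

struct Node {
  std::string name;
  std::string kind;   // "variable", "cast", or anything else
  const Node* input;  // producer of this node's input, if any
};

// Accept a variable directly, or a variable behind a single cast; reject the rest.
std::string FindShadowVariableName(const Node& shadow_producer) {
  if (shadow_producer.kind == "variable") { return shadow_producer.name; }
  if (shadow_producer.kind == "cast" && shadow_producer.input != nullptr
      && shadow_producer.input->kind == "variable") {
    return shadow_producer.input->name;
  }
  throw std::runtime_error("shadow must be variable or variable and cast");
}

int main() {
  const Node var{"embedding_shadow_variable", "variable", nullptr};
  const Node cast{"amp_cast_f2h", "cast", &var};
  std::cout << FindShadowVariableName(var) << "\n";   // embedding_shadow_variable
  std::cout << FindShadowVariableName(cast) << "\n";  // embedding_shadow_variable
  return 0;
}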

@@ -977,20 +991,8 @@ Maybe<void> ReplaceEmbeddingOps::Apply(const OpGraph& op_graph, JobBuilder* job_
     const LogicalBlobId out = GenLogicalBlobId(embedding_op.output("embeddings", 0));
     for (const OpEdge* out_edge : op_node->out_edges()) {
       const OpNode* consumer = out_edge->dst_node();
-      if (consumer->op().op_conf().has_user_conf()
-          && consumer->op().op_conf().user_conf().op_type_name() == "cast") {
-        const user_op::UserOpConfWrapper cast_op_conf(consumer->op().op_conf());
-        delete_op_names.push_back(consumer->op().op_name());
-        for (const OpEdge* cast_out_edge : consumer->out_edges()) {
-          const OpNode* cast_consumer = cast_out_edge->dst_node();
-          const LogicalBlobId cast_out_lbi = GenLogicalBlobId(cast_op_conf.output("out", 0));
-          UpdateConsumerOpConf(cast_consumer, cast_out_lbi, new_embeddings_lbn, &op_name2op_conf);
-        }
-      } else {
-        UpdateConsumerOpConf(consumer, out, new_embeddings_lbn, &op_name2op_conf);
-      }
+      UpdateConsumerOpConf(consumer, out, new_embeddings_lbn, &op_name2op_conf);
     }
 
     std::string state_initializer;
     // find update op
     const OpNode* producer =
2 changes: 1 addition & 1 deletion oneflow/user/ops/add_n_op.cpp
@@ -52,7 +52,7 @@ namespace oneflow {
   CHECK_NOTNULL_OR_RETURN(out);
   for (const auto& pair : ctx->inputs()) {
     const auto& cur_in = ctx->InputTensorDesc(pair.first, pair.second);
-    CHECK_EQ_OR_RETURN(in_0.data_type(), cur_in.data_type());
+    CHECK_EQ_OR_RETURN(in_0.data_type(), cur_in.data_type()) << ctx->op_name();
   }
   *out->mut_data_type() = in_0.data_type();
   return Maybe<void>::Ok();
2 changes: 1 addition & 1 deletion oneflow/user/ops/one_embedding_ops.cpp
@@ -67,7 +67,7 @@ namespace oneflow {
 }
 
 /* static */ Maybe<void> EmbeddingLookupPlaceholderOp::InferDataType(user_op::InferContext* ctx) {
-  *ctx->OutputDType("embeddings", 0) = ctx->Attr<DataType>("dtype");
+  *ctx->OutputDType("embeddings", 0) = ctx->InputDType("shadow", 0);
   return Maybe<void>::Ok();
 }

