Implement oneflow.embedding op (#8110)
* impl embedding op

* refine

* fix the norm logic

* rename kernel

* refine

* impl in python

* refine

* add unit tests and fix bug

* refine

* refine kernel

* fix typo

* fix

* move functor

* use primitive memset

* fix int64

* refine renorm kernel and unit test

* fix typo

* add op dev states

* use check_graph

* fix graph backward

* refine unit test

* refine unit test

* add global test

* debug

* fix sparse unit test

* fix

* fix

* auto format by CI

* fix clang error

* fix unit test

* auto format by CI

* fix

* refine unit test

* refine kernel

* refine

* fix

* refine renorm kernel

* add sync

* refine

* rm test case(tmp)

* add unit test

* fix

* add sync

* try another test case

* auto format by CI

* fix

* auto format by CI

* remove autotest case for scale_by_grad

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
3 people authored Jun 6, 2022
1 parent 445e236 commit 45cfcb5
Showing 15 changed files with 1,286 additions and 87 deletions.
84 changes: 84 additions & 0 deletions oneflow/core/autograd/gradient_funcs/embedding.cpp
@@ -0,0 +1,84 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include "oneflow/core/common/container_util.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

struct EmbeddingCaptureState : public AutoGradCaptureState {
int64_t padding_idx = -1;
bool scale_grad_by_freq = false;
bool requires_grad = false;
};

class Embedding : public OpExprGradFunction<EmbeddingCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(EmbeddingCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const EmbeddingCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

private:
AttrMap base_attrs_;
};

Maybe<void> Embedding::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "Forward op must not be null";
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}

Maybe<void> Embedding::Capture(EmbeddingCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0)));
ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 1)));

ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->padding_idx = JUST(composed_attrs.GetAttr<int64_t>("padding_idx"));
ctx->scale_grad_by_freq = JUST(composed_attrs.GetAttr<bool>("scale_grad_by_freq"));
return Maybe<void>::Ok();
}

Maybe<void> Embedding::Apply(const EmbeddingCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

in_grads->resize(ctx->SavedTensors().size());
const auto& weight = JUST(oneflow::VectorAt(ctx->SavedTensors(), 0));
const auto& indices = JUST(oneflow::VectorAt(ctx->SavedTensors(), 1));
int64_t padding_idx = ctx->padding_idx;
bool scale_grad_by_freq = ctx->scale_grad_by_freq;

JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::EmbeddingGrad(
JUST(oneflow::VectorAt(out_grads, 0)), weight, indices, padding_idx, scale_grad_by_freq));
return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("embedding", Embedding);

} // namespace one
} // namespace oneflow
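
A minimal, hypothetical sketch (not part of this diff) of the eager autograd path served by the "embedding" gradient function registered above. It assumes the Python binding generated from functional_api.yaml below is reachable as flow._C.embedding:

import oneflow as flow

weight = flow.randn(10, 3, requires_grad=True)  # 10-row embedding table
indices = flow.tensor([1, 2, 4, 4])             # row 4 is looked up twice
out = flow._C.embedding(weight, indices, None, False)
out.sum().backward()
# Expected: weight.grad is 1.0 in rows 1 and 2, 2.0 in row 4, 0 elsewhere.

Capture stores the weight and indices plus the padding_idx / scale_grad_by_freq attrs; Apply replays them into functional::EmbeddingGrad to produce the single input gradient.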
12 changes: 12 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -928,6 +928,18 @@
signature: " Tensor (Tensor input, Int64 dim, Tensor index, Bool sparse_grad=False) => DimGather"
bind_python: True

- name: "embedding_renorm_"
signature: " Tensor (Tensor in, Tensor indices, Double max_norm, Double norm_type) => EmbeddingReNorm"
bind_python: True

- name: "embedding"
signature: " Tensor (Tensor weight, Tensor indices, Int64 padding_idx=None, Bool scale_grad_by_freq=False) => Embedding"
bind_python: True

- name: "embedding_grad"
signature: " Tensor (Tensor dy, Tensor weight, Tensor indices, Int64 padding_idx, Bool scale_grad_by_freq=False) => EmbeddingGrad"
bind_python: False

- name: "arg_sort"
signature: "Tensor (Tensor in, String direction) => ArgSort"
bind_python: True
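
A hedged usage sketch of the two Python-visible signatures above (binding names assumed from the yaml; not code from this commit):

import oneflow as flow

weight = flow.randn(8, 4)
idx = flow.tensor([0, 3, 3, 5])

# In-place renorm: every row of `weight` referenced by `idx` whose 2-norm
# exceeds 1.0 is rescaled so that its norm becomes 1.0.
flow._C.embedding_renorm_(weight, idx, 1.0, 2.0)

# Lookup: out[i] = weight[idx[i]], giving shape (4, 4) here.
out = flow._C.embedding(weight, idx, None, False)

embedding_grad has bind_python: False, so it is reachable only from the C++ gradient function, not from user code.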
53 changes: 53 additions & 0 deletions oneflow/core/functional/impl/nn_functor.cpp
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

#include "oneflow/core/common/container_util.h"
#include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/common/error.h"
#include "oneflow/core/common/maybe.h"
@@ -206,6 +207,56 @@ class DeConv3dFunctor : public DeConvBaseFunctor {
}
};

class EmbeddingReNormFunctor {
public:
EmbeddingReNormFunctor() {
op_ = CHECK_JUST(
one::OpBuilder("embedding_renorm").Input("in").Input("indices").Output("out").Build());
}

Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in,
const std::shared_ptr<one::Tensor>& indices, const double& max_norm,
const double& norm_type) const {
CHECK_EQ_OR_RETURN(in->ndim(), 2)
<< Error::RuntimeError() << "The dimension of input should be 2.";
std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);
JUST(oneflow::VectorAt(*outputs, 0)) = in;

MutableAttrMap attrs;
JUST(attrs.SetAttr<double>("max_norm", max_norm));
JUST(attrs.SetAttr<double>("norm_type", norm_type));

JUST(OpInterpUtil::Dispatch(*op_, {in, indices}, outputs.get(), attrs));
return JUST(oneflow::VectorAt(*outputs, 0));
}

private:
std::shared_ptr<OpExpr> op_;
};

class EmbeddingFunctor {
public:
EmbeddingFunctor() {
op_ = CHECK_JUST(
one::OpBuilder("embedding").Input("weight").Input("indices").Output("out").Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& weight,
const std::shared_ptr<one::Tensor>& indices,
const Optional<int64_t>& padding_idx,
const bool& scale_grad_by_freq) const {
CHECK_EQ_OR_RETURN(weight->ndim(), 2) << "The dimension of weight should be 2";
int64_t new_padding_idx = -1;
if (padding_idx.has_value()) { new_padding_idx = JUST(padding_idx); }
MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("padding_idx", new_padding_idx));
JUST(attrs.SetAttr<bool>("scale_grad_by_freq", scale_grad_by_freq));
return OpInterpUtil::Dispatch<Tensor>(*op_, {weight, indices}, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

class MatMulFunctor {
public:
MatMulFunctor() {
@@ -3106,6 +3157,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::DeConv1dFunctor>("Deconv1d");
m.add_functor<impl::DeConv2dFunctor>("Deconv2d");
m.add_functor<impl::DeConv3dFunctor>("Deconv3d");
m.add_functor<impl::EmbeddingReNormFunctor>("EmbeddingReNorm");
m.add_functor<impl::EmbeddingFunctor>("Embedding");
m.add_functor<impl::MatMulFunctor>("MatMul");
m.add_functor<impl::BatchMatMulFunctor>("BatchMatMul");
m.add_functor<impl::TensorDotFunctor>("TensorDot");
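
Note the in-place trick in EmbeddingReNormFunctor: it pre-binds outputs[0] to the input tensor before dispatch, so the op writes the rescaled rows back into the weight itself. Below is a NumPy reference for the renorm semantics, sketched under the assumption that the kernel follows torch.embedding_renorm_ (illustrative only, not this commit's kernel):

import numpy as np

def embedding_renorm_ref(weight, indices, max_norm, norm_type):
    # Each unique row of the table touched by `indices` is rescaled if its
    # norm_type-norm exceeds max_norm.
    for row in np.unique(indices):
        norm = np.linalg.norm(weight[row], ord=norm_type)
        if norm > max_norm:
            # Small epsilon guards against division by zero, as in PyTorch.
            weight[row] *= max_norm / (norm + 1e-7)
    return weight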
25 changes: 25 additions & 0 deletions oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -110,6 +110,30 @@ class ConvDataGradFunctor {
std::shared_ptr<OpExpr> op_;
};

class EmbeddingGradFunctor {
public:
EmbeddingGradFunctor() {
op_ = CHECK_JUST(one::OpBuilder("embedding_grad")
.Input("dy")
.Input("weight")
.Input("indices")
.Output("dx")
.Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
const std::shared_ptr<one::Tensor>& weight,
const std::shared_ptr<one::Tensor>& indices, const int64_t& padding_idx,
const bool& scale_grad_by_freq) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("padding_idx", padding_idx));
JUST(attrs.SetAttr<bool>("scale_grad_by_freq", scale_grad_by_freq));
return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, weight, indices}, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

class MaxPoolNdGradFunctor {
public:
MaxPoolNdGradFunctor() {
@@ -1013,6 +1037,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::ConvBiasGradFunctor>("ConvBiasGrad");
m.add_functor<impl::ConvFilterGradFunctor>("ConvFilterGrad");
m.add_functor<impl::ConvDataGradFunctor>("ConvDataGrad");
m.add_functor<impl::EmbeddingGradFunctor>("EmbeddingGrad");
m.add_functor<impl::TFPoolNdGradFunctor>("TFPoolNdGrad");
m.add_functor<impl::AdaptivePoolNdGradFunctor>("AdaptivePoolNdGrad");
m.add_functor<impl::KLDivLossGradFunctor>("KLDivLossGrad");
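
A NumPy sketch of what EmbeddingGrad is expected to compute, assuming it mirrors torch's embedding backward (illustrative, not this commit's kernel): scatter-add dy rows into dx at `indices`, optionally divide each touched row by its index frequency, and zero the padding row.

import numpy as np

def embedding_grad_ref(dy, weight, indices, padding_idx, scale_grad_by_freq):
    dx = np.zeros_like(weight)
    flat_idx = indices.reshape(-1)
    flat_dy = dy.reshape(-1, weight.shape[1])
    # Accumulate: rows looked up repeatedly receive the sum of their grads.
    for i, row in enumerate(flat_idx):
        dx[row] += flat_dy[i]
    if scale_grad_by_freq:
        counts = np.bincount(flat_idx, minlength=weight.shape[0])
        dx[counts > 0] /= counts[counts > 0][:, None]
    if padding_idx >= 0:
        # The padding row never receives gradient.
        dx[padding_idx] = 0.0
    return dx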
3 changes: 2 additions & 1 deletion oneflow/core/job/job_build_and_infer_ctx.cpp
@@ -1141,7 +1141,8 @@ Maybe<void> JobBuildAndInferCtx::InferBlobBackwardSignature(
};
const auto& maybe_ok =
TRY(GenerateBackwardOpConfIf(op, &bw_op_confs, DiffLbi4BnInOp, LogicalBlobDesc4BnInOp));
-  CHECK(maybe_ok.IsOk() || maybe_ok.error()->has_gradient_function_not_found_error());
+  CHECK(maybe_ok.IsOk() || maybe_ok.error()->has_gradient_function_not_found_error())
+      << GetFormatedSerializedError(::oneflow::private_details::JustGetError(maybe_ok));
// find backward used logical blob ids
auto backward_used_lbis = std::make_shared<HashSet<LogicalBlobId>>();
for (const auto& bw_op_conf : bw_op_confs) {
61 changes: 59 additions & 2 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -2797,8 +2797,8 @@ def OneFlow_ImageResizeToFixedOp : OneFlow_BaseOp<"image_resize_to_fixed", [NoSi
#endif // GET_ONEFLOW_IMAGE_OP_DEFINITIONS

// Group: INDICES
-// arg_sort, argmax, argwhere, batch_gather, dim_gather, dim_scatter_add, dim_scatter_add_like, dim_scatter_add_scalar, dim_scatter_mul, dim_scatter_mul_scalar, dim_scatter_update, dim_scatter_update_scalar, gather, gather_nd, generate_random_batch_permutation_indices, image_target_resize, logical_slice, scatter_nd, scatter_nd_like, slice, slice_grad, tensor_scatter_nd_add, tensor_scatter_nd_update, unsorted_batch_segment_sum, unsorted_segment_sum, unsorted_segment_sum_like, where, where_scalar_x, where_scalar_xy, where_scalar_y, median, searchsorted, searchsorted_scalar
-// Total: 33
+// arg_sort, argmax, argwhere, batch_gather, dim_gather, dim_scatter_add, dim_scatter_add_like, dim_scatter_add_scalar, dim_scatter_mul, dim_scatter_mul_scalar, dim_scatter_update, dim_scatter_update_scalar, embedding_renorm, embedding, embedding_grad, gather, gather_nd, generate_random_batch_permutation_indices, image_target_resize, logical_slice, scatter_nd, scatter_nd_like, slice, slice_grad, tensor_scatter_nd_add, tensor_scatter_nd_update, unsorted_batch_segment_sum, unsorted_segment_sum, unsorted_segment_sum_like, where, where_scalar_x, where_scalar_xy, where_scalar_y, median, searchsorted, searchsorted_scalar
+// Total: 36

#ifdef GET_ONEFLOW_INDICES_OP_DEFINITIONS

@@ -3015,6 +3015,63 @@ def OneFlow_DimScatterUpdateScalarOp : OneFlow_BaseOp<"dim_scatter_update_scalar
let has_input_arg_modify_fn = 1;
}

def OneFlow_EmbeddingRenormOp : OneFlow_BaseOp<"embedding_renorm", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$in,
OneFlow_Tensor:$indices
);
let output = (outs
OneFlow_Tensor:$out
);
let attrs = (ins
DefaultValuedAttr<F64Attr, "0.">:$max_norm,
DefaultValuedAttr<F64Attr, "2.">:$norm_type
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_EmbeddingOp : OneFlow_BaseOp<"embedding", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$weight,
OneFlow_Tensor:$indices
);
let output = (outs
OneFlow_Tensor:$out
);
let attrs = (ins
DefaultValuedAttr<SI64Attr, "-1">:$padding_idx,
DefaultValuedAttr<BoolAttr, "false">:$scale_grad_by_freq
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
let has_input_arg_modify_fn = 1;
}

def OneFlow_EmbeddingGradOp : OneFlow_BaseOp<"embedding_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$dy,
OneFlow_Tensor:$weight,
OneFlow_Tensor:$indices
);
let output = (outs
OneFlow_Tensor:$dx
);
let attrs = (ins
DefaultValuedAttr<SI64Attr, "-1">:$padding_idx,
DefaultValuedAttr<BoolAttr, "false">:$scale_grad_by_freq
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
let has_input_arg_modify_fn = 1;
}

def OneFlow_GatherOp : OneFlow_BaseOp<"gather", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$in,