Refactor NLLLoss to support split class dim (#8380)
* refactor

* RuntimeError

* avoid atomic add

* test

* fixes

* update test

* update test

* update test

* fix kernel

* improve backward

* update test

* out_weight to be required

* address static analysis error

* fix static analysis error

* fix static analysis error

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
leaves-zwx and mergify[bot] authored Jun 18, 2022
1 parent f7532fd commit d7ef39f
Showing 14 changed files with 768 additions and 502 deletions.
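
As read from the diff below, the refactor renames the `nll` op's outputs from `out`/`total_weight` to `output`/`out_weight`, keeps both outputs per-element so the class dimension can be split across devices, moves the `sum`/`mean` reduction out of the kernel and into the functor, and renames the backward functional from `NllLossGrad` to `NLLGrad`, dropping its `total_weight` input. The following NumPy sketch illustrates the assumed per-element semantics of the forward op; the function and variable names are illustrative, not the OneFlow kernel API.

import numpy as np

def nll_forward(input_2d, target, weight=None, ignore_index=-100):
    # input_2d: (N, C) log-probabilities, target: (N,) class indices
    N, C = input_2d.shape
    w = np.ones(C) if weight is None else weight
    output = np.zeros(N)        # per-sample loss ("output")
    out_weight = np.zeros(N)    # per-sample weight ("out_weight")
    for i in range(N):
        t = target[i]
        if t == ignore_index:
            continue            # ignored samples contribute no loss and no weight
        output[i] = -w[t] * input_2d[i, t]
        out_weight[i] = w[t]
    return output, out_weight

x = np.log(np.full((4, 3), 1.0 / 3))   # uniform log-probabilities
y = np.array([0, 2, 1, -100])          # the last sample is ignored
out, out_w = nll_forward(x, y)
mean_loss = out.sum() / out_w.sum()    # "mean" reduction, applied later by the functor
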
70 changes: 43 additions & 27 deletions oneflow/core/autograd/gradient_funcs/nll.cpp
@@ -15,68 +15,84 @@ limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"

namespace oneflow {

namespace one {
struct NllCaptureState : public AutoGradCaptureState {

struct NLLCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
int64_t ignore_index = -100;
};

class Nll : public OpExprGradFunction<NllCaptureState> {
class NLLGradFunction : public OpExprGradFunction<NLLCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(NllCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
Maybe<void> Capture(NLLCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const NllCaptureState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const NLLCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

private:
AttrMap base_attrs_;
};
Maybe<void> Nll::Init(const OpExpr& op) {

Maybe<void> NLLGradFunction::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Nll::Capture(NllCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();

Maybe<void> NLLGradFunction::Capture(NLLCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
auto input = JUST(VectorAt(inputs, 0));
ctx->requires_grad = input->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->ignore_index = JUST(composed_attrs.GetAttr<int64_t>("ignore_index"));
ctx->SaveTensorForBackward(inputs.at(0)); // input
ctx->SaveTensorForBackward(inputs.at(1)); // target
ctx->SaveTensorForBackward(outputs.at(1)); // total_weight
ctx->SaveTensorForBackward(input); // input
ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); // target
if (inputs.size() == 3) {
ctx->SaveTensorForBackward(inputs.at(2)); // weight
ctx->SaveTensorForBackward(inputs[2]); // weight
}
return Maybe<void>::Ok();
}
Maybe<void> Nll::Apply(const NllCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {

Maybe<void> NLLGradFunction::Apply(const NLLCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

CHECK_EQ_OR_RETURN(out_grads.size(), 2);
const auto& dy = out_grads.at(0);
const auto& input = ctx->SavedTensors().at(0);
const auto& target = ctx->SavedTensors().at(1);
const auto& total_weight = ctx->SavedTensors().at(2);
CHECK_EQ_OR_RETURN(out_grads.size(), 2)
<< Error::RuntimeError() << "The number of out_grads is expected to be 2, got "
<< out_grads.size();
CHECK_GE_OR_RETURN(ctx->SavedTensors().size(), 2)
<< Error::RuntimeError()
<< "The number of saved tensors is expected to be greater than or equal to 2, got "
<< ctx->SavedTensors().size();
const auto& out_grad = out_grads[0];
const auto& input = ctx->SavedTensors()[0];
const auto& target = ctx->SavedTensors()[1];

in_grads->resize(ctx->SavedTensors().size() - 1);
in_grads->resize(ctx->SavedTensors().size());

if (ctx->SavedTensors().size() == 4) {
const auto& weight = ctx->SavedTensors().at(3);
in_grads->at(0) =
JUST(functional::NllLossGrad(dy, input, target, weight, total_weight, ctx->ignore_index));
if (ctx->SavedTensors().size() == 2) {
JUST(VectorAt(*in_grads, 0)) =
JUST(functional::NLLGrad(out_grad, input, target, NullOpt, ctx->ignore_index));
} else {
in_grads->at(0) =
JUST(functional::NllLossGrad(dy, input, target, NullOpt, total_weight, ctx->ignore_index));
// has weight
auto weight = JUST(VectorAt(ctx->SavedTensors(), 2));
JUST(VectorAt(*in_grads, 0)) =
JUST(functional::NLLGrad(out_grad, input, target, weight, ctx->ignore_index));
}

return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("nll", Nll);

REGISTER_OP_EXPR_GRAD_FUNCTION("nll", NLLGradFunction);

} // namespace one

} // namespace oneflow
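
A note on the backward wiring above, as read from the diff: `NLLGradFunction` now saves only `input`, `target`, and optionally `weight`, and forwards the incoming per-element `out_grad` directly to `functional::NLLGrad`. The `total_weight` tensor no longer needs to be threaded through, because the `mean` reduction is now built from `ReduceSum`/`Div` in the functor, whose own autograd supplies the `1/total_weight` scaling. A minimal NumPy sketch of that chain rule, assuming `mean` is computed as `sum(output) / sum(out_weight)`:

import numpy as np

out_weight = np.array([1.0, 1.0, 2.0, 0.0])   # per-sample weights; the last sample is ignored
dy = 1.0                                       # upstream gradient of the scalar mean loss
total_weight = out_weight.sum()
# d(mean)/d(output[i]) = dy / total_weight for every sample i; this is the
# per-element out_grad that Apply() passes to functional::NLLGrad.
out_grad = np.full_like(out_weight, dy / total_weight)
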
6 changes: 3 additions & 3 deletions oneflow/core/functional/functional_api.yaml
@@ -1027,11 +1027,11 @@
bind_python: False

- name: "nll_loss"
signature: "Tensor(Tensor input, Tensor target, Tensor weight=None, Int64 ignore_index, String reduction) => NllLoss"
signature: "Tensor(Tensor input, Tensor target, Tensor weight=None, Int64 ignore_index, String reduction) => NLLLoss"
bind_python: True

- name: "nll_loss_grad"
signature: "Tensor(Tensor dy, Tensor input, Tensor target, Tensor weight=None, Tensor total_target, Int64 ignore_index) => NllLossGrad"
- name: "nll_grad"
signature: "Tensor(Tensor out_grad, Tensor input, Tensor target, Tensor weight=None, Int64 ignore_index) => NLLGrad"
bind_python: False

- name: "binary_cross_entropy_loss"
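
The Python-facing signature of `nll_loss` is unchanged apart from the functor rename; only the internal gradient functional drops the extra total-weight tensor argument. A hedged usage sketch, assuming the PyTorch-aligned wrappers `flow.nn.LogSoftmax` and `flow.nn.NLLLoss` in this OneFlow version route to the functors bound here:

import oneflow as flow

log_probs = flow.nn.LogSoftmax(dim=1)(flow.randn(8, 5, 10, 10))   # (N, C, H, W)
target = flow.randint(0, 5, (8, 10, 10))                          # (N, H, W) class indices
loss = flow.nn.NLLLoss(reduction="mean", ignore_index=-100)(log_probs, target)
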
122 changes: 76 additions & 46 deletions oneflow/core/functional/impl/nn_functor.cpp
@@ -1099,23 +1099,25 @@ class BinaryCrossEntropyWithLogitsLossFunctor : public LossFunctorBase {
std::shared_ptr<OpExpr> op_weight_pos_;
};

class NllLossFunctor {
class NLLLossFunctor {
public:
NllLossFunctor() {
NLLLossFunctor() {
op_ = CHECK_JUST(one::OpBuilder("nll")
.Input("input")
.Input("target")
.Output("out")
.Output("total_weight")
.Output("output")
.Output("out_weight")
.Build());

op_weight_ = CHECK_JUST(one::OpBuilder("nll")
.Input("input")
.Input("target")
.Input("weight")
.Output("out")
.Output("total_weight")
.Output("output")
.Output("out_weight")
.Build());
}

Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& target,
const Optional<one::Tensor>& weight, const int64_t& ignore_index,
@@ -1124,42 +1126,65 @@ class NllLossFunctor {
<< Error::RuntimeError() << "Reduction should be none, sum or mean.";

const auto& input_shape = input->shape();
const int64_t K = input_shape->NumAxes();
CHECK_GE_OR_RETURN(K, 2) << Error::RuntimeError() << "Expected 2 or more dimensions";
const int64_t N = input_shape->At(0);
const int64_t C = input_shape->At(1);

const auto& target_shape = target->shape();
CHECK_LE_OR_RETURN(input_shape->NumAxes(), 5)
<< Error::RuntimeError() << "The number of input's axis should be less equal to 5. ";
CHECK_EQ_OR_RETURN(input_shape->NumAxes() - 1, target_shape->NumAxes())
<< Error::RuntimeError()
<< "The number of input's axis should be equal to the number of target's axis - 1. ";
CHECK_EQ_OR_RETURN(target_shape->NumAxes(), K - 1)
<< Error::RuntimeError() << "Expected target dimensions (" << K - 1
<< ") to match input dimensions (" << K << "), got " << target_shape->NumAxes();
CHECK_EQ_OR_RETURN(target_shape->At(0), N)
<< Error::RuntimeError() << "Expected input batch_size (" << N
<< ") to match target batch_size (" << target_shape->At(0) << ")";

std::shared_ptr<one::Tensor> input_;
std::shared_ptr<one::Tensor> target_;
if (K > 2) {
DimVector idea_target_dim_vec;
idea_target_dim_vec.push_back(N);
for (int64_t i = 2; i < K; ++i) { idea_target_dim_vec.push_back(input_shape->At(i)); }
Shape idea_target_shape(idea_target_dim_vec);
CHECK_EQ_OR_RETURN(*target_shape, idea_target_shape)
<< Error::RuntimeError() << "Expected target shape " << idea_target_shape.ToString()
<< ", got " << target_shape->ToString();

std::vector<int> perm(input_shape->dim_vec().size(), 0);
perm[perm.size() - 1] = 1;
for (size_t i = 1; i < perm.size() - 1; ++i) { perm[i] = i + 1; }

input_ = JUST(sequence_function(functional::Transpose)
.then(std::bind(functional::Reshape, std::placeholders::_1, Shape({-1, C})))
.call(input, perm));
target_ = JUST(functional::Flatten(target, 0, K - 2));
} else {
input_ = input;
target_ = target;
}

MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("ignore_index", ignore_index));

std::vector<int> input_perm(input_shape->dim_vec().size(), 0);
input_perm[input_perm.size() - 1] = 1;
for (size_t i = 1; i < input_perm.size() - 1; ++i) { input_perm[i] = i + 1; }

const auto input_ = JUST(sequence_function(functional::Transpose)
.then(std::bind(functional::Reshape, std::placeholders::_1,
Shape({-1, input_shape->At(1)})))
.call(input, input_perm));
auto target_ = JUST(functional::Flatten(target, 0, target_shape->NumAxes() - 1));

std::shared_ptr<TensorTuple> kernel_result;
std::shared_ptr<Tensor> result;
std::shared_ptr<TensorTuple> nll_result;
if (weight) {
kernel_result = JUST(
nll_result = JUST(
OpInterpUtil::Dispatch<TensorTuple>(*op_weight_, {input_, target_, JUST(weight)}, attrs));
} else {
kernel_result = JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_, {input_, target_}, attrs));
nll_result = JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_, {input_, target_}, attrs));
}
result = JUST(functional::Reshape(kernel_result->at(0), *target_shape));
if (reduction == "none") { return result; }
auto output = JUST(VectorAt(*nll_result, 0));

if (K > 2) { output = JUST(functional::Reshape(output, *target_shape)); }

if (reduction == "none") { return output; }

result = JUST(functional::ReduceSum(result, {}, false));
auto sum = JUST(functional::ReduceSum(output, {}, false));

if (reduction == "sum") { return result; }
if (reduction == "sum") { return sum; }

return functional::Div(result, kernel_result->at(1));
auto total_weight = JUST(functional::ReduceSum(JUST(VectorAt(*nll_result, 1)), {}, false));
return functional::Div(sum, total_weight);
}

private:
@@ -1171,18 +1196,20 @@ class CrossEntropyFunctor {
public:
CrossEntropyFunctor() {
op_log_softmax_ = CHECK_JUST(one::OpBuilder("log_softmax").Input("in").Output("prob").Build());

op_nll_ = CHECK_JUST(one::OpBuilder("nll")
.Input("input")
.Input("target")
.Output("out")
.Output("total_weight")
.Output("output")
.Output("out_weight")
.Build());

op_nll_weight_ = CHECK_JUST(one::OpBuilder("nll")
.Input("input")
.Input("target")
.Input("weight")
.Output("out")
.Output("total_weight")
.Output("output")
.Output("out_weight")
.Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
@@ -1193,8 +1220,6 @@ class CrossEntropyFunctor {
<< Error::RuntimeError() << "Reduction should be none, sum or mean.";
const auto& input_shape = input->shape();
const auto& target_shape = target->shape();
MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("ignore_index", ignore_index));

std::vector<int> input_perm(input_shape->dim_vec().size(), 0);
input_perm[input_perm.size() - 1] = 1;
@@ -1210,21 +1235,26 @@ class CrossEntropyFunctor {

const auto target_ = JUST(functional::Flatten(target, 0, target->shape()->NumAxes() - 1));

std::shared_ptr<TensorTuple> kernel_result;
std::shared_ptr<Tensor> result;
MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("ignore_index", ignore_index));

std::shared_ptr<TensorTuple> nll_result;
if (weight) {
kernel_result = JUST(OpInterpUtil::Dispatch<TensorTuple>(
nll_result = JUST(OpInterpUtil::Dispatch<TensorTuple>(
*op_nll_weight_, {input_, target_, JUST(weight)}, attrs));
} else {
kernel_result = JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_nll_, {input_, target_}, attrs));
nll_result = JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_nll_, {input_, target_}, attrs));
}
result = JUST(functional::Reshape((*kernel_result)[0], *target_shape));
if (reduction == "none") { return result; }

result = JUST(functional::ReduceSum(result, {}, false));
if (reduction == "sum") { return result; }
auto output = JUST(VectorAt(*nll_result, 0));
output = JUST(functional::Reshape(output, *target_shape));
if (reduction == "none") { return output; }

auto sum = JUST(functional::ReduceSum(output, {}, false));
if (reduction == "sum") { return sum; }

return functional::Div(result, kernel_result->at(1));
auto total_weight = JUST(functional::ReduceSum(JUST(VectorAt(*nll_result, 1)), {}, false));
return functional::Div(sum, total_weight);
}

private:
@@ -3340,7 +3370,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::L1LossFunctor>("L1Loss");
m.add_functor<impl::MseLossFunctor>("MseLoss");
m.add_functor<impl::KLDivLossFunctor>("KLDivLoss");
m.add_functor<impl::NllLossFunctor>("NllLoss");
m.add_functor<impl::NLLLossFunctor>("NLLLoss");
m.add_functor<impl::BinaryCrossEntropyLossFunctor>("BinaryCrossEntropyLoss");
m.add_functor<impl::BinaryCrossEntropyWithLogitsLossFunctor>("BinaryCrossEntropyWithLogitsLoss");
m.add_functor<impl::SparseCrossEntropyFunctor>("SparseCrossEntropy");
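
The main behavioral addition in `NLLLossFunctor` is the `K > 2` path: the class dimension `C` is permuted to the last axis, the input is flattened to `(-1, C)` and the target to a 1-D vector, the op runs on the 2-D view, and the per-element output is reshaped back to the target shape before reduction. A NumPy sketch of that layout trick, with illustrative shapes and names rather than the OneFlow code:

import numpy as np

def flatten_for_nll(input_nd, target_nd):
    # input_nd: (N, C, d1, ..., dk) log-probabilities, target_nd: (N, d1, ..., dk)
    K = input_nd.ndim
    C = input_nd.shape[1]
    perm = [0] + list(range(2, K)) + [1]          # move the class axis to the end
    input_2d = input_nd.transpose(perm).reshape(-1, C)
    target_1d = target_nd.reshape(-1)
    return input_2d, target_1d

x = np.random.randn(2, 3, 4, 5)                   # (N, C, H, W)
t = np.random.randint(0, 3, size=(2, 4, 5))       # (N, H, W)
x2d, t1d = flatten_for_nll(x, t)
assert x2d.shape == (2 * 4 * 5, 3) and t1d.shape == (2 * 4 * 5,)
# the (N*H*W,) per-element losses are reshaped back to t.shape when reduction == "none"
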
30 changes: 14 additions & 16 deletions oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -363,39 +363,37 @@ class KLDivLossGradFunctor {
std::shared_ptr<OpExpr> op_;
};

class NllLossGradFunctor {
class NLLGradFunctor {
public:
NllLossGradFunctor() {
NLLGradFunctor() {
op_ = CHECK_JUST(one::OpBuilder("nll_grad")
.Input("out_grad")
.Input("input")
.Input("target")
.Input("total_weight")
.Input("dy")
.Output("dx")
.Output("in_grad")
.Build());

op_weight_ = CHECK_JUST(one::OpBuilder("nll_grad")
.Input("out_grad")
.Input("input")
.Input("target")
.Input("total_weight")
.Input("weight")
.Input("dy")
.Output("dx")
.Output("in_grad")
.Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,

Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& out_grad,
const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& target,
const Optional<one::Tensor>& weight,
const std::shared_ptr<one::Tensor>& total_weight,
const int64_t ignore_index) const {
const Optional<one::Tensor>& weight, const int64_t ignore_index) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("ignore_index", ignore_index));

if (weight) {
return OpInterpUtil::Dispatch<one::Tensor>(
*op_weight_, {input, target, total_weight, JUST(weight), dy}, attrs);
return OpInterpUtil::Dispatch<one::Tensor>(*op_weight_,
{out_grad, input, target, JUST(weight)}, attrs);
} else {
return OpInterpUtil::Dispatch<one::Tensor>(*op_, {input, target, total_weight, dy}, attrs);
return OpInterpUtil::Dispatch<one::Tensor>(*op_, {out_grad, input, target}, attrs);
}
}

@@ -1120,7 +1118,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::TFPoolNdGradFunctor>("TFPoolNdGrad");
m.add_functor<impl::AdaptivePoolNdGradFunctor>("AdaptivePoolNdGrad");
m.add_functor<impl::KLDivLossGradFunctor>("KLDivLossGrad");
m.add_functor<impl::NllLossGradFunctor>("NllLossGrad");
m.add_functor<impl::NLLGradFunctor>("NLLGrad");
m.add_functor<impl::BinaryCrossEntropyLossGradFunctor>("BinaryCrossEntropyLossGrad");
m.add_functor<impl::BinaryCrossEntropyWithLogitsLossGradFunctor>(
"BinaryCrossEntropyWithLogitsLossGrad");
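
For reference, the renamed `nll_grad` op now takes `out_grad`, `input`, `target` (plus an optional `weight`) and produces a single `in_grad`. A NumPy sketch of the per-element gradient it is expected to compute — an interpretation of the diff, not the actual CPU/CUDA kernel:

import numpy as np

def nll_grad(out_grad, input_2d, target, weight=None, ignore_index=-100):
    # out_grad: (N,) per-sample upstream gradient, input_2d: (N, C), target: (N,)
    N, C = input_2d.shape
    w = np.ones(C) if weight is None else weight
    in_grad = np.zeros_like(input_2d)
    for i in range(N):
        t = target[i]
        if t == ignore_index:
            continue                              # ignored samples get zero gradient
        in_grad[i, t] = -w[t] * out_grad[i]
    return in_grad

x = np.log(np.full((3, 4), 0.25))                 # (N, C) log-probabilities
t = np.array([1, 3, -100])                        # the last sample is ignored
g = nll_grad(np.ones(3), x, t)                    # e.g. upstream grads from a "sum" reduction
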
(Diffs for the remaining 10 changed files are not shown.)
