From d7ef39fd0c92558b07a9d2c75f522dc0605f6219 Mon Sep 17 00:00:00 2001 From: leaves-zwx Date: Sat, 18 Jun 2022 12:47:14 +0800 Subject: [PATCH] Refactor NLLLoss to support split class dim (#8380) * refactor * RuntimeError * avoid atomic add * test * fixes * update test * update test * update test * fix kernel * improve backward * update test * out_weight to be required * address static analysis errer * fix static analysis error * fix static analysis error Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- oneflow/core/autograd/gradient_funcs/nll.cpp | 70 +++-- oneflow/core/functional/functional_api.yaml | 6 +- oneflow/core/functional/impl/nn_functor.cpp | 122 +++++---- .../core/functional/impl/nn_grad_functor.cpp | 30 +-- ...t_sparse_softmax_cross_entropy_op_pass.cpp | 6 +- oneflow/ir/include/OneFlow/OneFlowUserOps.td | 21 +- oneflow/user/kernels/nll_kernel.cpp | 254 +++++++++++------- oneflow/user/kernels/nll_kernel.cu | 207 -------------- oneflow/user/kernels/nll_kernel_util.cpp | 63 +++++ oneflow/user/kernels/nll_kernel_util.cu | 92 +++++++ oneflow/user/kernels/nll_kernel_util.h | 36 +++ oneflow/user/ops/nll_op.cpp | 227 ++++++++++------ python/oneflow/nn/modules/loss.py | 2 +- python/oneflow/test/modules/test_nll_loss.py | 134 +++++++++ 14 files changed, 768 insertions(+), 502 deletions(-) delete mode 100644 oneflow/user/kernels/nll_kernel.cu create mode 100644 oneflow/user/kernels/nll_kernel_util.cpp create mode 100644 oneflow/user/kernels/nll_kernel_util.cu create mode 100644 oneflow/user/kernels/nll_kernel_util.h create mode 100644 python/oneflow/test/modules/test_nll_loss.py diff --git a/oneflow/core/autograd/gradient_funcs/nll.cpp b/oneflow/core/autograd/gradient_funcs/nll.cpp index 20e1a67653c..430009b9dd2 100644 --- a/oneflow/core/autograd/gradient_funcs/nll.cpp +++ b/oneflow/core/autograd/gradient_funcs/nll.cpp @@ -15,68 +15,84 @@ limitations under the License. */ #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/functional/functional.h" +#include "oneflow/core/common/container_util.h" namespace oneflow { + namespace one { -struct NllCaptureState : public AutoGradCaptureState { + +struct NLLCaptureState : public AutoGradCaptureState { bool requires_grad = false; int64_t ignore_index = -100; }; -class Nll : public OpExprGradFunction { +class NLLGradFunction : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; - Maybe Capture(NllCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, + Maybe Capture(NLLCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override; - Maybe Apply(const NllCaptureState* ctx, const TensorTuple& out_grads, + Maybe Apply(const NLLCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; private: AttrMap base_attrs_; }; -Maybe Nll::Init(const OpExpr& op) { + +Maybe NLLGradFunction::Init(const OpExpr& op) { const auto* fw_op_expr = dynamic_cast(&op); CHECK_NOTNULL_OR_RETURN(fw_op_expr); base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); return Maybe::Ok(); } -Maybe Nll::Capture(NllCaptureState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrMap& attrs) const { - ctx->requires_grad = inputs.at(0)->requires_grad(); + +Maybe NLLGradFunction::Capture(NLLCaptureState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const { + auto input = JUST(VectorAt(inputs, 0)); + ctx->requires_grad = input->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } ComposedAttrMap composed_attrs(attrs, base_attrs_); ctx->ignore_index = JUST(composed_attrs.GetAttr("ignore_index")); - ctx->SaveTensorForBackward(inputs.at(0)); // input - ctx->SaveTensorForBackward(inputs.at(1)); // target - ctx->SaveTensorForBackward(outputs.at(1)); // total_weight + ctx->SaveTensorForBackward(input); // input + ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); // target if (inputs.size() == 3) { - ctx->SaveTensorForBackward(inputs.at(2)); // weight + ctx->SaveTensorForBackward(inputs[2]); // weight } return Maybe::Ok(); } -Maybe Nll::Apply(const NllCaptureState* ctx, const TensorTuple& out_grads, - TensorTuple* in_grads) const { + +Maybe NLLGradFunction::Apply(const NLLCaptureState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const { if (!ctx->requires_grad) { return Maybe::Ok(); } - CHECK_EQ_OR_RETURN(out_grads.size(), 2); - const auto& dy = out_grads.at(0); - const auto& input = ctx->SavedTensors().at(0); - const auto& target = ctx->SavedTensors().at(1); - const auto& total_weight = ctx->SavedTensors().at(2); + CHECK_EQ_OR_RETURN(out_grads.size(), 2) + << Error::RuntimeError() << "The number of out_grads is expected to be 2, got " + << out_grads.size(); + CHECK_GE_OR_RETURN(ctx->SavedTensors().size(), 2) + << Error::RuntimeError() + << "The number of saved tensors is expected to be greater than or equal to 2, got " + << ctx->SavedTensors().size(); + const auto& out_grad = out_grads[0]; + const auto& input = ctx->SavedTensors()[0]; + const auto& target = ctx->SavedTensors()[1]; - in_grads->resize(ctx->SavedTensors().size() - 1); + in_grads->resize(ctx->SavedTensors().size()); - if (ctx->SavedTensors().size() == 4) { - const auto& weight = ctx->SavedTensors().at(3); - in_grads->at(0) = - JUST(functional::NllLossGrad(dy, input, target, weight, total_weight, ctx->ignore_index)); + if (ctx->SavedTensors().size() == 2) { + JUST(VectorAt(*in_grads, 0)) = + JUST(functional::NLLGrad(out_grad, input, target, NullOpt, ctx->ignore_index)); } else { - in_grads->at(0) = - JUST(functional::NllLossGrad(dy, input, target, NullOpt, total_weight, ctx->ignore_index)); + // has weight + auto weight = JUST(VectorAt(ctx->SavedTensors(), 2)); + JUST(VectorAt(*in_grads, 0)) = + JUST(functional::NLLGrad(out_grad, input, target, weight, ctx->ignore_index)); } + return Maybe::Ok(); } -REGISTER_OP_EXPR_GRAD_FUNCTION("nll", Nll); + +REGISTER_OP_EXPR_GRAD_FUNCTION("nll", NLLGradFunction); + } // namespace one + } // namespace oneflow diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index fe62eb5f858..aecca3fdf54 100755 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -1027,11 +1027,11 @@ bind_python: False - name: "nll_loss" - signature: "Tensor(Tensor input, Tensor target, Tensor weight=None, Int64 ignore_index, String reduction) => NllLoss" + signature: "Tensor(Tensor input, Tensor target, Tensor weight=None, Int64 ignore_index, String reduction) => NLLLoss" bind_python: True -- name: "nll_loss_grad" - signature: "Tensor(Tensor dy, Tensor input, Tensor target, Tensor weight=None, Tensor total_target, Int64 ignore_index) => NllLossGrad" +- name: "nll_grad" + signature: "Tensor(Tensor out_grad, Tensor input, Tensor target, Tensor weight=None, Int64 ignore_index) => NLLGrad" bind_python: False - name: "binary_cross_entropy_loss" diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index a453cdb4dfe..84edaf218a8 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -1099,23 +1099,25 @@ class BinaryCrossEntropyWithLogitsLossFunctor : public LossFunctorBase { std::shared_ptr op_weight_pos_; }; -class NllLossFunctor { +class NLLLossFunctor { public: - NllLossFunctor() { + NLLLossFunctor() { op_ = CHECK_JUST(one::OpBuilder("nll") .Input("input") .Input("target") - .Output("out") - .Output("total_weight") + .Output("output") + .Output("out_weight") .Build()); + op_weight_ = CHECK_JUST(one::OpBuilder("nll") .Input("input") .Input("target") .Input("weight") - .Output("out") - .Output("total_weight") + .Output("output") + .Output("out_weight") .Build()); } + Maybe operator()(const std::shared_ptr& input, const std::shared_ptr& target, const Optional& weight, const int64_t& ignore_index, @@ -1124,42 +1126,65 @@ class NllLossFunctor { << Error::RuntimeError() << "Reduction should be none, sum or mean."; const auto& input_shape = input->shape(); + const int64_t K = input_shape->NumAxes(); + CHECK_GE_OR_RETURN(K, 2) << Error::RuntimeError() << "Expected 2 or more dimensions"; + const int64_t N = input_shape->At(0); + const int64_t C = input_shape->At(1); + const auto& target_shape = target->shape(); - CHECK_LE_OR_RETURN(input_shape->NumAxes(), 5) - << Error::RuntimeError() << "The number of input's axis should be less equal to 5. "; - CHECK_EQ_OR_RETURN(input_shape->NumAxes() - 1, target_shape->NumAxes()) - << Error::RuntimeError() - << "The number of input's axis should be equal to the number of target's axis - 1. "; + CHECK_EQ_OR_RETURN(target_shape->NumAxes(), K - 1) + << Error::RuntimeError() << "Expected target dimensions (" << K - 1 + << ") to match input dimensions (" << K << "), got " << target_shape->NumAxes(); + CHECK_EQ_OR_RETURN(target_shape->At(0), N) + << Error::RuntimeError() << "Expected input batch_size (" << N + << ") to match target batch_size (" << target_shape->At(0) << ")"; + + std::shared_ptr input_; + std::shared_ptr target_; + if (K > 2) { + DimVector idea_target_dim_vec; + idea_target_dim_vec.push_back(N); + for (int64_t i = 2; i < K; ++i) { idea_target_dim_vec.push_back(input_shape->At(i)); } + Shape idea_target_shape(idea_target_dim_vec); + CHECK_EQ_OR_RETURN(*target_shape, idea_target_shape) + << Error::RuntimeError() << "Expected target shape " << idea_target_shape.ToString() + << ", got " << target_shape->ToString(); + + std::vector perm(input_shape->dim_vec().size(), 0); + perm[perm.size() - 1] = 1; + for (size_t i = 1; i < perm.size() - 1; ++i) { perm[i] = i + 1; } + + input_ = JUST(sequence_function(functional::Transpose) + .then(std::bind(functional::Reshape, std::placeholders::_1, Shape({-1, C}))) + .call(input, perm)); + target_ = JUST(functional::Flatten(target, 0, K - 2)); + } else { + input_ = input; + target_ = target; + } MutableAttrMap attrs; JUST(attrs.SetAttr("ignore_index", ignore_index)); - std::vector input_perm(input_shape->dim_vec().size(), 0); - input_perm[input_perm.size() - 1] = 1; - for (size_t i = 1; i < input_perm.size() - 1; ++i) { input_perm[i] = i + 1; } - - const auto input_ = JUST(sequence_function(functional::Transpose) - .then(std::bind(functional::Reshape, std::placeholders::_1, - Shape({-1, input_shape->At(1)}))) - .call(input, input_perm)); - auto target_ = JUST(functional::Flatten(target, 0, target_shape->NumAxes() - 1)); - - std::shared_ptr kernel_result; - std::shared_ptr result; + std::shared_ptr nll_result; if (weight) { - kernel_result = JUST( + nll_result = JUST( OpInterpUtil::Dispatch(*op_weight_, {input_, target_, JUST(weight)}, attrs)); } else { - kernel_result = JUST(OpInterpUtil::Dispatch(*op_, {input_, target_}, attrs)); + nll_result = JUST(OpInterpUtil::Dispatch(*op_, {input_, target_}, attrs)); } - result = JUST(functional::Reshape(kernel_result->at(0), *target_shape)); - if (reduction == "none") { return result; } + auto output = JUST(VectorAt(*nll_result, 0)); + + if (K > 2) { output = JUST(functional::Reshape(output, *target_shape)); } + + if (reduction == "none") { return output; } - result = JUST(functional::ReduceSum(result, {}, false)); + auto sum = JUST(functional::ReduceSum(output, {}, false)); - if (reduction == "sum") { return result; } + if (reduction == "sum") { return sum; } - return functional::Div(result, kernel_result->at(1)); + auto total_weight = JUST(functional::ReduceSum(JUST(VectorAt(*nll_result, 1)), {}, false)); + return functional::Div(sum, total_weight); } private: @@ -1171,18 +1196,20 @@ class CrossEntropyFunctor { public: CrossEntropyFunctor() { op_log_softmax_ = CHECK_JUST(one::OpBuilder("log_softmax").Input("in").Output("prob").Build()); + op_nll_ = CHECK_JUST(one::OpBuilder("nll") .Input("input") .Input("target") - .Output("out") - .Output("total_weight") + .Output("output") + .Output("out_weight") .Build()); + op_nll_weight_ = CHECK_JUST(one::OpBuilder("nll") .Input("input") .Input("target") .Input("weight") - .Output("out") - .Output("total_weight") + .Output("output") + .Output("out_weight") .Build()); } Maybe operator()(const std::shared_ptr& input, @@ -1193,8 +1220,6 @@ class CrossEntropyFunctor { << Error::RuntimeError() << "Reduction should be none, sum or mean."; const auto& input_shape = input->shape(); const auto& target_shape = target->shape(); - MutableAttrMap attrs; - JUST(attrs.SetAttr("ignore_index", ignore_index)); std::vector input_perm(input_shape->dim_vec().size(), 0); input_perm[input_perm.size() - 1] = 1; @@ -1210,21 +1235,26 @@ class CrossEntropyFunctor { const auto target_ = JUST(functional::Flatten(target, 0, target->shape()->NumAxes() - 1)); - std::shared_ptr kernel_result; - std::shared_ptr result; + MutableAttrMap attrs; + JUST(attrs.SetAttr("ignore_index", ignore_index)); + + std::shared_ptr nll_result; if (weight) { - kernel_result = JUST(OpInterpUtil::Dispatch( + nll_result = JUST(OpInterpUtil::Dispatch( *op_nll_weight_, {input_, target_, JUST(weight)}, attrs)); } else { - kernel_result = JUST(OpInterpUtil::Dispatch(*op_nll_, {input_, target_}, attrs)); + nll_result = JUST(OpInterpUtil::Dispatch(*op_nll_, {input_, target_}, attrs)); } - result = JUST(functional::Reshape((*kernel_result)[0], *target_shape)); - if (reduction == "none") { return result; } - result = JUST(functional::ReduceSum(result, {}, false)); - if (reduction == "sum") { return result; } + auto output = JUST(VectorAt(*nll_result, 0)); + output = JUST(functional::Reshape(output, *target_shape)); + if (reduction == "none") { return output; } + + auto sum = JUST(functional::ReduceSum(output, {}, false)); + if (reduction == "sum") { return sum; } - return functional::Div(result, kernel_result->at(1)); + auto total_weight = JUST(functional::ReduceSum(JUST(VectorAt(*nll_result, 1)), {}, false)); + return functional::Div(sum, total_weight); } private: @@ -3340,7 +3370,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("L1Loss"); m.add_functor("MseLoss"); m.add_functor("KLDivLoss"); - m.add_functor("NllLoss"); + m.add_functor("NLLLoss"); m.add_functor("BinaryCrossEntropyLoss"); m.add_functor("BinaryCrossEntropyWithLogitsLoss"); m.add_functor("SparseCrossEntropy"); diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp index 8e43b83ddb1..5689710ac2b 100644 --- a/oneflow/core/functional/impl/nn_grad_functor.cpp +++ b/oneflow/core/functional/impl/nn_grad_functor.cpp @@ -363,39 +363,37 @@ class KLDivLossGradFunctor { std::shared_ptr op_; }; -class NllLossGradFunctor { +class NLLGradFunctor { public: - NllLossGradFunctor() { + NLLGradFunctor() { op_ = CHECK_JUST(one::OpBuilder("nll_grad") + .Input("out_grad") .Input("input") .Input("target") - .Input("total_weight") - .Input("dy") - .Output("dx") + .Output("in_grad") .Build()); + op_weight_ = CHECK_JUST(one::OpBuilder("nll_grad") + .Input("out_grad") .Input("input") .Input("target") - .Input("total_weight") .Input("weight") - .Input("dy") - .Output("dx") + .Output("in_grad") .Build()); } - Maybe operator()(const std::shared_ptr& dy, + + Maybe operator()(const std::shared_ptr& out_grad, const std::shared_ptr& input, const std::shared_ptr& target, - const Optional& weight, - const std::shared_ptr& total_weight, - const int64_t ignore_index) const { + const Optional& weight, const int64_t ignore_index) const { MutableAttrMap attrs; JUST(attrs.SetAttr("ignore_index", ignore_index)); if (weight) { - return OpInterpUtil::Dispatch( - *op_weight_, {input, target, total_weight, JUST(weight), dy}, attrs); + return OpInterpUtil::Dispatch(*op_weight_, + {out_grad, input, target, JUST(weight)}, attrs); } else { - return OpInterpUtil::Dispatch(*op_, {input, target, total_weight, dy}, attrs); + return OpInterpUtil::Dispatch(*op_, {out_grad, input, target}, attrs); } } @@ -1120,7 +1118,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("TFPoolNdGrad"); m.add_functor("AdaptivePoolNdGrad"); m.add_functor("KLDivLossGrad"); - m.add_functor("NllLossGrad"); + m.add_functor("NLLGrad"); m.add_functor("BinaryCrossEntropyLossGrad"); m.add_functor( "BinaryCrossEntropyWithLogitsLossGrad"); diff --git a/oneflow/core/job_rewriter/split_sparse_softmax_cross_entropy_op_pass.cpp b/oneflow/core/job_rewriter/split_sparse_softmax_cross_entropy_op_pass.cpp index 19851e21852..e9a0211ea62 100644 --- a/oneflow/core/job_rewriter/split_sparse_softmax_cross_entropy_op_pass.cpp +++ b/oneflow/core/job_rewriter/split_sparse_softmax_cross_entropy_op_pass.cpp @@ -213,8 +213,8 @@ Maybe SplitSparseSoftmaxCrossEntropyOpPass::Apply(const OpGraph& op_graph, .Op("nll") .Input("input", broadcast_sub_op.output("z", 0)) .Input("target", op_label_blob_name) - .Output("out") - .Output("total_weight") + .Output("output") + .Output("out_weight") .Attr("ignore_index", -100) .ScopeSymbolId(scope_symbol_id) .Build(); @@ -223,7 +223,7 @@ Maybe SplitSparseSoftmaxCrossEntropyOpPass::Apply(const OpGraph& op_graph, const std::string& prob_lbn = cur_op.output("prob", 0); const std::string& out_lbn = cur_op.output("out", 0); const std::string& new_prob_lbn = broadcast_div_op.output("z", 0); - const std::string& new_out_lbn = nll_op.output("out", 0); + const std::string& new_out_lbn = nll_op.output("output", 0); for (const OpEdge* out_edge : node->out_edges()) { const OpNode* consumer = out_edge->dst_node(); diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index 60d13342c1e..44a1861912c 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -4969,44 +4969,41 @@ def OneFlow_LocalMultiReduceMinAbsOp : OneFlow_BaseOp<"local_multi_reduce_min_ab let has_get_sbp_fn = 1; } -def OneFlow_NllOp : OneFlow_BaseOp<"nll", [NoSideEffect, DeclareOpInterfaceMethods]> { +def OneFlow_NLLOp : OneFlow_BaseOp<"nll", [NoSideEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$input, OneFlow_Tensor:$target, Optional:$weight ); let output = (outs - OneFlow_Tensor:$out, - OneFlow_Tensor:$total_weight + OneFlow_Tensor:$output, + OneFlow_Tensor:$out_weight ); let attrs = (ins DefaultValuedAttr:$ignore_index ); + let has_data_type_infer_fn = 1; let has_logical_tensor_desc_infer_fn = 1; - let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; - let has_data_type_infer_fn = 1; let has_input_arg_modify_fn = 1; } -def OneFlow_NllGradOp : OneFlow_BaseOp<"nll_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { +def OneFlow_NLLGradOp : OneFlow_BaseOp<"nll_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { let input = (ins + OneFlow_Tensor:$out_grad, OneFlow_Tensor:$input, OneFlow_Tensor:$target, - OneFlow_Tensor:$total_weight, - Optional:$weight, - OneFlow_Tensor:$dy + Optional:$weight ); let output = (outs - OneFlow_Tensor:$dx + OneFlow_Tensor:$in_grad ); let attrs = (ins DefaultValuedAttr:$ignore_index ); + let has_data_type_infer_fn = 1; let has_logical_tensor_desc_infer_fn = 1; - let has_physical_tensor_desc_infer_fn = 1; let has_get_sbp_fn = 1; - let has_data_type_infer_fn = 1; } def OneFlow_PowXGradOp : OneFlow_BaseOp<"pow_x_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { diff --git a/oneflow/user/kernels/nll_kernel.cpp b/oneflow/user/kernels/nll_kernel.cpp index f71df661167..01abf5565b1 100644 --- a/oneflow/user/kernels/nll_kernel.cpp +++ b/oneflow/user/kernels/nll_kernel.cpp @@ -14,130 +14,180 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" -#include "oneflow/core/kernel/new_kernel_util.h" -#include "oneflow/core/ndarray/ndarray_util.h" -#include "oneflow/user/kernels/loss_kernel_util.h" +#include "oneflow/core/framework/nd_sbp.h" +#include "oneflow/core/job/nd_sbp_util.h" +#include "oneflow/user/kernels/nll_kernel_util.h" namespace oneflow { -namespace user_op { + namespace { -using namespace loss; - -template -void ComputeNllOut(int64_t num_instances, K num_classes, K ignore_index, const T* input, - const K* target, T* out, const T* weight, T* total_weight) { - *total_weight = 0; - FOR_RANGE(int64_t, i, 0, num_instances) { - K label = target[i]; - if (label == ignore_index) { - out[i] = 0; - continue; +class NLLKernelCache final : public user_op::OpKernelCache { + public: + NLLKernelCache(int64_t class_start, int64_t num_classes) + : class_start_(class_start), num_classes_(num_classes) {} + ~NLLKernelCache() override = default; + + int64_t class_start() const { return class_start_; } + int64_t num_classes() const { return num_classes_; } + + private: + const int64_t class_start_; + const int64_t num_classes_; +}; + +std::shared_ptr CreateNLLKernelCache(user_op::KernelCacheContext* ctx) { + CHECK_GT(ctx->parallel_ctx().parallel_num(), 0) << ctx->op_name() << ": invalid parallel_ctx"; + if (ctx->parallel_ctx().parallel_num() == 1) { return nullptr; } + + const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("input", 0); + const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); + CHECK_EQ(nd_sbp.sbp_parallel_size(), hierarchy.NumAxes()) + << ctx->op_name() << ": Expected input sbp " << NdSbpToString(nd_sbp) << " match hierarchy " + << hierarchy.ToString(); + + const Shape& shape = ctx->LogicalTensorDesc4ArgNameAndIndex("input", 0)->shape(); + const int64_t class_axis = shape.NumAxes() - 1; + + bool split_class_dim = false; + for (const auto& sbp : nd_sbp.sbp_parallel()) { + if (sbp.has_split_parallel() && sbp.split_parallel().axis() == class_axis) { + split_class_dim = true; + break; } - CHECK_GE(label, 0); - CHECK_LT(label, num_classes); - T cur_weight = weight == nullptr ? 1 : weight[label]; - *total_weight += cur_weight; - out[i] = -input[i * num_classes + label] * cur_weight; - } -} -template -void ComputeNllGradOut(int64_t num_instances, K num_classes, K ignore_index, const K* target, - const T* dy, T* dx, const T* weight, const T* total_weight) { - FOR_RANGE(int64_t, i, 0, num_instances) { - K label = target[i]; - if (label == ignore_index) { continue; } - CHECK_GE(label, 0); - CHECK_LT(label, num_classes); - T cur_weight = weight == nullptr ? -1 : -weight[label]; - dx[i * num_classes + label] = dy[i] * cur_weight; } + + if (!split_class_dim) { return nullptr; } + + TensorSliceView view = + GetTensorSliceView4ParallelId(hierarchy, nd_sbp, shape, ctx->parallel_ctx().parallel_id()); + return std::make_shared(view.At(class_axis).begin(), view.At(class_axis).size()); } -template -class NllKernel final : public user_op::OpKernel { + +} // namespace + +template +class NLLKernel final : public user_op::OpKernel { public: - NllKernel() = default; - ~NllKernel() = default; + NLLKernel() = default; + ~NLLKernel() override = default; + + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateNLLKernelCache(ctx); + } private: - using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto* input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); - const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); - auto* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); - auto* total_weight_blob = ctx->Tensor4ArgNameAndIndex("total_weight", 0); - - const int64_t num_instances = target_blob->shape().elem_cnt(); - CHECK_EQ(input_blob->shape().elem_cnt() % num_instances, 0); - const K num_classes = static_cast(input_blob->shape().elem_cnt() / num_instances); + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache* cache) const override { + const auto* input = ctx->Tensor4ArgNameAndIndex("input", 0); + const auto* target = ctx->Tensor4ArgNameAndIndex("target", 0); + auto* output = ctx->Tensor4ArgNameAndIndex("output", 0); + auto* out_weight = ctx->Tensor4ArgNameAndIndex("out_weight", 0); + + const int64_t N = target->shape().elem_cnt(); + const int64_t C = input->shape().At(input->shape().NumAxes() - 1); + CHECK_LE(N, std::numeric_limits::max()) + << "Expected batch size not exceed int32 numeric limits"; + + K class_start = 0; + if (cache) { + const auto* spec_cache = dynamic_cast(cache); + CHECK_NOTNULL(spec_cache); + CHECK_EQ(spec_cache->num_classes(), C) << ctx->op_name() << ": expected num_classes " << C + << ", got " << spec_cache->num_classes(); + class_start = spec_cache->class_start(); + } + const K ignore_index = static_cast(ctx->Attr("ignore_index")); - const T* input = input_blob->dptr(); - const K* target = target_blob->dptr(); - T* out = out_blob->mut_dptr(); - T* total_weight = total_weight_blob->mut_dptr(); - const T* weight = - ctx->has_input("weight", 0) ? ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr() : nullptr; + const T* weight_dptr = nullptr; + if (ctx->has_input("weight", 0)) { + weight_dptr = CHECK_NOTNULL(ctx->Tensor4ArgNameAndIndex("weight", 0))->dptr(); + } - ComputeNllOut(num_instances, num_classes, ignore_index, input, target, out, weight, - total_weight); + NLLKernelUtil::Forward(ctx->stream(), static_cast(N), + static_cast(C), class_start, ignore_index, + input->dptr(), target->dptr(), weight_dptr, + output->mut_dptr(), out_weight->mut_dptr()); } - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; -template -class NllGradKernel final : public user_op::OpKernel { +template +class NLLGradKernel final : public user_op::OpKernel { public: - NllGradKernel() = default; - ~NllGradKernel() = default; + NLLGradKernel() = default; + ~NLLGradKernel() override = default; + + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateNLLKernelCache(ctx); + } private: - using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto* input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); - const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); - const auto* dy_blob = ctx->Tensor4ArgNameAndIndex("dy", 0); - auto* dx_blob = ctx->Tensor4ArgNameAndIndex("dx", 0); - auto* total_weight_blob = ctx->Tensor4ArgNameAndIndex("total_weight", 0); - - const int64_t num_instances = target_blob->shape().elem_cnt(); - const int64_t input_elem_cnt = input_blob->shape().elem_cnt(); - CHECK_EQ(input_elem_cnt % num_instances, 0); - const K num_classes = static_cast(input_elem_cnt / num_instances); + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache* cache) const override { + const auto* target = ctx->Tensor4ArgNameAndIndex("target", 0); + const auto* out_grad = ctx->Tensor4ArgNameAndIndex("out_grad", 0); + auto* in_grad = ctx->Tensor4ArgNameAndIndex("in_grad", 0); + + const int64_t N = target->shape().elem_cnt(); + const int64_t C = in_grad->shape().At(in_grad->shape().NumAxes() - 1); + CHECK_LE(N, std::numeric_limits::max()) + << "Expected batch size not exceed int32 numeric limits"; + + K class_start = 0; + if (cache) { + const auto* spec_cache = dynamic_cast(cache); + CHECK_NOTNULL(spec_cache); + CHECK_EQ(spec_cache->num_classes(), C) << ctx->op_name() << ": expected num_classes " << C + << ", got " << spec_cache->num_classes(); + class_start = spec_cache->class_start(); + } + const K ignore_index = static_cast(ctx->Attr("ignore_index")); - const T* dy = dy_blob->dptr(); - const K* target = target_blob->dptr(); - const T* total_weight = total_weight_blob->dptr(); - T* dx = dx_blob->mut_dptr(); - const T* weight = - ctx->has_input("weight", 0) ? ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr() : nullptr; - Memset(ctx->stream(), dx, 0, GetCudaAlignedSize(input_elem_cnt * sizeof(T))); - ComputeNllGradOut(num_instances, num_classes, ignore_index, target, dy, dx, weight, - total_weight); + const T* weight_dptr = nullptr; + if (ctx->has_input("weight", 0)) { + weight_dptr = CHECK_NOTNULL(ctx->Tensor4ArgNameAndIndex("weight", 0))->dptr(); + } + + NLLKernelUtil::Backward( + ctx->stream(), static_cast(N), static_cast(C), class_start, ignore_index, + out_grad->dptr(), target->dptr(), weight_dptr, in_grad->mut_dptr()); } - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; -} // namespace -#define REGISTER_NLL_KERNEL(dtype_pair, ltype_pair) \ - REGISTER_USER_KERNEL("nll") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ - && (user_op::HobDataType("target", 0) == OF_PP_PAIR_SECOND(ltype_pair)) \ - && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))); - -#define REGISTER_NLL_GRAD_KERNEL(dtype_pair, ltype_pair) \ - REGISTER_USER_KERNEL("nll_grad") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ - && (user_op::HobDataType("target", 0) == OF_PP_PAIR_SECOND(ltype_pair)) \ - && (user_op::HobDataType("dy", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \ - && (user_op::HobDataType("dx", 0) == OF_PP_PAIR_SECOND(dtype_pair))); - -OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_NLL_KERNEL, FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) - -OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_NLL_GRAD_KERNEL, FLOATING_DATA_TYPE_SEQ, - INDEX_DATA_TYPE_SEQ) -} // namespace user_op +#define REGISTER_NLL_KERNELS(device, dtype, ltype) \ + REGISTER_USER_KERNEL("nll").SetCreateFn>().SetIsMatchedHob( \ + (user_op::HobDeviceType() == device) \ + && (user_op::HobDataType("input", 0) == GetDataType::value) \ + && (user_op::HobDataType("target", 0) == GetDataType::value)); \ + REGISTER_USER_KERNEL("nll_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == device) \ + && (user_op::HobDataType("input", 0) == GetDataType::value) \ + && (user_op::HobDataType("target", 0) == GetDataType::value) \ + && (user_op::HobDataType("out_grad", 0) == GetDataType::value)) + +REGISTER_NLL_KERNELS(DeviceType::kCPU, float, int32_t); +REGISTER_NLL_KERNELS(DeviceType::kCPU, float, int64_t); +REGISTER_NLL_KERNELS(DeviceType::kCPU, double, int32_t); +REGISTER_NLL_KERNELS(DeviceType::kCPU, double, int64_t); + +#ifdef WITH_CUDA + +REGISTER_NLL_KERNELS(DeviceType::kCUDA, float, int32_t); +REGISTER_NLL_KERNELS(DeviceType::kCUDA, float, int64_t); +REGISTER_NLL_KERNELS(DeviceType::kCUDA, double, int32_t); +REGISTER_NLL_KERNELS(DeviceType::kCUDA, double, int64_t); +REGISTER_NLL_KERNELS(DeviceType::kCUDA, half, int32_t); +REGISTER_NLL_KERNELS(DeviceType::kCUDA, half, int64_t); + +#endif // WITH_CUDA + } // namespace oneflow diff --git a/oneflow/user/kernels/nll_kernel.cu b/oneflow/user/kernels/nll_kernel.cu deleted file mode 100644 index 9e78cf52257..00000000000 --- a/oneflow/user/kernels/nll_kernel.cu +++ /dev/null @@ -1,207 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#include -#include "oneflow/core/cuda/atomic.cuh" -#include "oneflow/core/framework/framework.h" -#include "oneflow/core/kernel/new_kernel_util.h" -#include "oneflow/user/kernels/loss_kernel_util.h" -#include "oneflow/core/ep/cuda/cuda_stream.h" - -namespace oneflow { -namespace user_op { -namespace { - -using namespace loss; - -#define RETURN_VOID_IF_NOT_HALF typename std::enable_if_t::value, void> -#define RETURN_VOID_IF_HALF typename std::enable_if_t::value, void> - -template -__global__ RETURN_VOID_IF_NOT_HALF ComputeNllOutNone(const int64_t num_instances, - const K num_classes, const K ignore_index, - const T* input, const K* target, T* out, - const T* weight, T* total_weight) { - const T zero_val = GetZeroVal(); - const T one_val = GetOneVal(); - CUDA_1D_KERNEL_LOOP(i, num_instances) { - K label = target[i]; - if (label == ignore_index) { - out[i] = zero_val; - continue; - } - assert(label >= 0); - assert(label < num_classes); - const T cur_weight = weight == nullptr ? one_val : weight[label]; - cuda::atomic::Add(total_weight, cur_weight); - out[i] = -input[i * num_classes + label] * cur_weight; - } -} - -template -__global__ RETURN_VOID_IF_HALF ComputeNllOutNone(const int64_t num_instances, const K num_classes, - const K ignore_index, const T* input, - const K* target, T* out, const T* weight, - T* total_weight) { -#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) - const T zero_val = __float2half(0.0); - const T one_val = __float2half(1.0); - CUDA_1D_KERNEL_LOOP(i, num_instances) { - K label = target[i]; - if (label == ignore_index) { - out[i] = zero_val; - continue; - } - assert(label >= 0); - assert(label < num_classes); - const half cur_weight = weight == nullptr ? one_val : weight[label]; - cuda::atomic::Add(total_weight, cur_weight); - out[i] = __float2half(-__half2float(input[i * num_classes + label] * cur_weight)); - } -#else - printf("use half need nvcc arch >= 530"); - assert(false); -#endif /* __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)*/ -} - -template -__global__ RETURN_VOID_IF_NOT_HALF ComputeNllGradOut(const int64_t num_instances, - const K num_classes, const K ignore_index, - const K* target, const T* dy, T* dx, - const T* weight, const T* total_weight) { - CUDA_1D_KERNEL_LOOP(i, num_instances) { - K label = target[i]; - if (label == ignore_index) { continue; } - assert(label >= 0); - assert(label < num_classes); - const T cur_weight = weight == nullptr ? -GetOneVal() : -weight[label]; - dx[i * num_classes + label] = dy[i] * cur_weight; - } -} - -template -__global__ RETURN_VOID_IF_HALF ComputeNllGradOut(const int64_t num_instances, const K num_classes, - const K ignore_index, const K* target, const T* dy, - T* dx, const T* weight, const T* total_weight) { -#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) - CUDA_1D_KERNEL_LOOP(i, num_instances) { - K label = target[i]; - if (label == ignore_index) { continue; } - assert(label >= 0); - assert(label < num_classes); - const half cur_weight = weight == nullptr ? __float2half(-1.0) : __hneg(weight[label]); - dx[i * num_classes + label] = __hmul(dy[i], cur_weight); - } -#else - printf("use half need nvcc arch >= 530"); - assert(false); -#endif /* __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)*/ -} - -template -class NllKernel final : public user_op::OpKernel { - public: - NllKernel() = default; - ~NllKernel() = default; - - private: - using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto* input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); - const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); - auto* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); - auto* total_weight_blob = ctx->Tensor4ArgNameAndIndex("total_weight", 0); - - const int64_t num_instances = target_blob->shape().elem_cnt(); - CHECK_EQ(input_blob->shape().elem_cnt() % num_instances, 0); - const K num_classes = static_cast(input_blob->shape().elem_cnt() / num_instances); - const K ignore_index = static_cast(ctx->Attr("ignore_index")); - - const T* input = input_blob->dptr(); - const K* target = target_blob->dptr(); - T* out = out_blob->mut_dptr(); - T* total_weight = total_weight_blob->mut_dptr(); - const T* weight = - ctx->has_input("weight", 0) ? ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr() : nullptr; - Memset(ctx->stream(), total_weight, 0, sizeof(T)); - - ComputeNllOutNone<<stream()->As()->cuda_stream()>>>( - num_instances, num_classes, ignore_index, input, target, out, weight, total_weight); - } - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } -}; - -template -class NllGradKernel final : public user_op::OpKernel { - public: - NllGradKernel() = default; - ~NllGradKernel() = default; - - private: - using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto* input_blob = ctx->Tensor4ArgNameAndIndex("input", 0); - const auto* target_blob = ctx->Tensor4ArgNameAndIndex("target", 0); - const auto* dy_blob = ctx->Tensor4ArgNameAndIndex("dy", 0); - auto* dx_blob = ctx->Tensor4ArgNameAndIndex("dx", 0); - auto* total_weight_blob = ctx->Tensor4ArgNameAndIndex("total_weight", 0); - - const int64_t num_instances = target_blob->shape().elem_cnt(); - const int64_t input_elem_cnt = input_blob->shape().elem_cnt(); - CHECK_EQ(input_elem_cnt % num_instances, 0); - const K num_classes = static_cast(input_elem_cnt / num_instances); - const K ignore_index = static_cast(ctx->Attr("ignore_index")); - - const T* dy = dy_blob->dptr(); - const K* target = target_blob->dptr(); - const T* total_weight = total_weight_blob->dptr(); - T* dx = dx_blob->mut_dptr(); - const T* weight = - ctx->has_input("weight", 0) ? ctx->Tensor4ArgNameAndIndex("weight", 0)->dptr() : nullptr; - - Memset(ctx->stream(), dx, 0, input_elem_cnt * sizeof(T)); - - ComputeNllGradOut<<stream()->As()->cuda_stream()>>>( - num_instances, num_classes, ignore_index, target, dy, dx, weight, total_weight); - } - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } -}; - -} // namespace -#define REGISTER_NLL_KERNEL(dtype_pair, ltype_pair) \ - REGISTER_USER_KERNEL("nll") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ - && (user_op::HobDataType("target", 0) == OF_PP_PAIR_SECOND(ltype_pair)) \ - && (user_op::HobDataType("out", 0) == OF_PP_PAIR_SECOND(dtype_pair))); - -#define REGISTER_NLL_GRAD_KERNEL(dtype_pair, ltype_pair) \ - REGISTER_USER_KERNEL("nll_grad") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ - && (user_op::HobDataType("target", 0) == OF_PP_PAIR_SECOND(ltype_pair)) \ - && (user_op::HobDataType("dy", 0) == OF_PP_PAIR_SECOND(dtype_pair)) \ - && (user_op::HobDataType("dx", 0) == OF_PP_PAIR_SECOND(dtype_pair))); - -OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_NLL_KERNEL, FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, - INDEX_DATA_TYPE_SEQ) - -OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_NLL_GRAD_KERNEL, - FLOATING_DATA_TYPE_SEQ HALF_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) - -} // namespace user_op -} // namespace oneflow diff --git a/oneflow/user/kernels/nll_kernel_util.cpp b/oneflow/user/kernels/nll_kernel_util.cpp new file mode 100644 index 00000000000..bbaf4265975 --- /dev/null +++ b/oneflow/user/kernels/nll_kernel_util.cpp @@ -0,0 +1,63 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/user/kernels/nll_kernel_util.h" + +namespace oneflow { + +template +struct NLLKernelUtil { + static void Forward(ep::Stream* stream, const int32_t num_samples, const K num_classes, + const K class_start, const K ignore_index, const T* input, const K* target, + const T* weight, T* out, T* out_weight) { + FOR_RANGE(int32_t, i, 0, num_samples) { + K label = target[i]; + T w = T{0}; + T y = T{0}; + if (label != ignore_index) { + label -= class_start; + if (label >= 0 && label < num_classes) { + w = weight ? weight[label] : T{1}; + y = -(input[i * num_classes + label] * w); + } + } + out[i] = y; + out_weight[i] = w; + } + } + + static void Backward(ep::Stream* stream, const int32_t num_samples, const K num_classes, + const K class_start, const K ignore_index, const T* out_grad, + const K* target, const T* weight, T* in_grad) { + Memset(stream, in_grad, 0, + RoundUp(num_samples * num_classes * sizeof(T), kBlobBodyAlignSize)); + FOR_RANGE(int32_t, i, 0, num_samples) { + K label = target[i]; + if (label == ignore_index) { continue; } + label -= class_start; + if (label >= 0 && label < num_classes) { + const T w = weight ? -weight[label] : T(-1); + in_grad[i * num_classes + label] = out_grad[i] * w; + } + } + } +}; + +template struct NLLKernelUtil; +template struct NLLKernelUtil; +template struct NLLKernelUtil; +template struct NLLKernelUtil; + +} // namespace oneflow diff --git a/oneflow/user/kernels/nll_kernel_util.cu b/oneflow/user/kernels/nll_kernel_util.cu new file mode 100644 index 00000000000..5e01b7697d1 --- /dev/null +++ b/oneflow/user/kernels/nll_kernel_util.cu @@ -0,0 +1,92 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/user/kernels/nll_kernel_util.h" +#include "oneflow/core/cuda/atomic.cuh" + +namespace oneflow { + +namespace { + +template +__global__ void NLLForward(const int32_t num_samples, const K num_classes, const K class_start, + const K ignore_index, const T* input, const K* target, const T* weight, + T* out, T* out_weight) { + const T zero = GetZeroVal(); + const T one = GetOneVal(); + CUDA_1D_KERNEL_LOOP(i, num_samples) { + K label = target[i]; + T w = zero; + T y = zero; + if (label != ignore_index) { + label -= class_start; + if (label >= 0 && label < num_classes) { + w = weight ? weight[label] : one; + y = -(input[i * num_classes + label] * w); + } + } + out[i] = y; + out_weight[i] = w; + } +} + +template +__global__ void NLLBackward(const int32_t num_samples, const K num_classes, const K class_start, + const K ignore_index, const T* out_grad, const K* target, + const T* weight, T* in_grad) { + const T one = GetOneVal(); + const T zero = GetZeroVal(); + CUDA_1D_KERNEL_LOOP_T(K, i, num_samples * num_classes) { + const K n = i / num_classes; + const K idx = i - n * num_classes; + const K label = target[n]; + if (label != ignore_index && idx == label - class_start) { + in_grad[i] = out_grad[n] * (weight ? -weight[idx] : -one); + } else { + in_grad[i] = zero; + } + } +} + +} // namespace + +template +struct NLLKernelUtil { + static void Forward(ep::Stream* stream, const int32_t num_samples, const K num_classes, + const K class_start, const K ignore_index, const T* input, const K* target, + const T* weight, T* out, T* out_weight) { + NLLForward<<As()->cuda_stream()>>>(num_samples, num_classes, + class_start, ignore_index, input, + target, weight, out, out_weight); + } + + static void Backward(ep::Stream* stream, const int32_t num_samples, const K num_classes, + const K class_start, const K ignore_index, const T* out_grad, + const K* target, const T* weight, T* in_grad) { + NLLBackward<<As()->cuda_stream()>>>( + num_samples, num_classes, class_start, ignore_index, out_grad, target, weight, in_grad); + } +}; + +template struct NLLKernelUtil; +template struct NLLKernelUtil; +template struct NLLKernelUtil; +template struct NLLKernelUtil; +template struct NLLKernelUtil; +template struct NLLKernelUtil; + +} // namespace oneflow diff --git a/oneflow/user/kernels/nll_kernel_util.h b/oneflow/user/kernels/nll_kernel_util.h new file mode 100644 index 00000000000..25953d9b64f --- /dev/null +++ b/oneflow/user/kernels/nll_kernel_util.h @@ -0,0 +1,36 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_USER_KERNELS_NLL_KERNEL_UTIL_H_ +#define ONEFLOW_USER_KERNELS_NLL_KERNEL_UTIL_H_ + +#include "oneflow/core/kernel/kernel_util.h" + +namespace oneflow { + +template +struct NLLKernelUtil { + static void Forward(ep::Stream* stream, const int32_t num_samples, const K num_classes, + const K class_start, const K ignore_index, const T* input, const K* target, + const T* weight, T* out, T* out_weight); + + static void Backward(ep::Stream* stream, const int32_t num_samples, const K num_classes, + const K class_start, const K ignore_index, const T* out_grad, + const K* target, const T* weight, T* in_grad); +}; + +} // namespace oneflow + +#endif // ONEFLOW_USER_KERNELS_NLL_KERNEL_UTIL_H_ diff --git a/oneflow/user/ops/nll_op.cpp b/oneflow/user/ops/nll_op.cpp index b170194aff4..1afffc2c16b 100644 --- a/oneflow/user/ops/nll_op.cpp +++ b/oneflow/user/ops/nll_op.cpp @@ -14,125 +14,183 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" -#include "oneflow/user/ops/loss_op_util.h" #include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe NLLOp::InferDataType(user_op::InferContext* ctx) { + CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType("target", 0))) + << ctx->op_name() << ": expected target being integer type"; -Maybe InferTensorDescFn(user_op::InferContext* ctx) { - const auto& input_desc = ctx->InputTensorDesc("input", 0); - const auto& target_desc = ctx->InputTensorDesc("target", 0); - CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic()); - CHECK_GE_OR_RETURN(input_desc.shape().NumAxes(), 2); - CHECK_EQ_OR_RETURN(target_desc.shape().NumAxes(), 1); - CHECK_EQ_OR_RETURN(input_desc.shape().At(0), target_desc.shape().At(0)); + auto input_dtype = ctx->InputDType("input", 0); if (ctx->has_input("weight", 0)) { - const auto& weight_desc = ctx->InputTensorDesc("weight", 0); - CHECK_EQ_OR_RETURN(weight_desc.is_dynamic(), input_desc.is_dynamic()); - CHECK_EQ_OR_RETURN(weight_desc.shape(), Shape({input_desc.shape().At(1)})); + auto weight_dtype = ctx->InputDType("weight", 0); + CHECK_EQ_OR_RETURN(weight_dtype, input_dtype) << ctx->op_name() << ": expected weight dtype " + << input_dtype << ", but got " << weight_dtype; } - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_is_dynamic() = input_desc.is_dynamic(); - *out_desc->mut_shape() = target_desc.shape(); - - user_op::TensorDesc* total_weight_desc = ctx->OutputTensorDesc("total_weight", 0); - *total_weight_desc->mut_is_dynamic() = input_desc.is_dynamic(); - *total_weight_desc->mut_shape() = Shape({}); - - return Maybe::Ok(); -} - -Maybe NllInferDataType(user_op::InferContext* ctx) { - const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); - CHECK_OR_RETURN(IsIndexDataType(target_desc.data_type())); - - *ctx->OutputDType("out", 0) = ctx->InputDType("input", 0); - *ctx->OutputDType("total_weight", 0) = ctx->InputDType("input", 0); + *ctx->OutputDType("output", 0) = input_dtype; + *ctx->OutputDType("out_weight", 0) = input_dtype; return Maybe::Ok(); } -Maybe InferGradTensorDescFn(user_op::InferContext* ctx) { +/* static */ Maybe NLLOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); - const auto& total_weight_desc = ctx->InputTensorDesc("total_weight", 0); - const auto& dy_desc = ctx->InputTensorDesc("dy", 0); - CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic()); - CHECK_GE_OR_RETURN(input_desc.shape().NumAxes(), 2); - CHECK_EQ_OR_RETURN(target_desc.shape().NumAxes(), 1); - CHECK_EQ_OR_RETURN(input_desc.shape().At(0), target_desc.shape().At(0)); - CHECK_EQ_OR_RETURN(dy_desc.shape(), target_desc.shape()); - CHECK_EQ_OR_RETURN(total_weight_desc.shape(), Shape({})); + + const bool is_dynamic = input_desc.is_dynamic(); + CHECK_EQ_OR_RETURN(target_desc.is_dynamic(), is_dynamic) + << ctx->op_name() << ": expected the same dynamic with input and target"; + const int64_t K = input_desc.shape().NumAxes(); + CHECK_GE_OR_RETURN(K, 2) << ctx->op_name() << ": expected 2 or more dimensions for input"; + CHECK_EQ_OR_RETURN(target_desc.shape().NumAxes(), K - 1) + << ctx->op_name() << ": expected 1 less diemensions than input for target"; + const int64_t N = target_desc.shape().elem_cnt(); + const int64_t C = input_desc.shape().At(input_desc.shape().NumAxes() - 1); + CHECK_EQ_OR_RETURN(input_desc.shape().elem_cnt(), N * C) + << ctx->op_name() << ": expected input size " << input_desc.shape().ToString() + << " to match target size " << target_desc.shape().ToString(); + if (ctx->has_input("weight", 0)) { const auto& weight_desc = ctx->InputTensorDesc("weight", 0); - CHECK_EQ_OR_RETURN(weight_desc.is_dynamic(), input_desc.is_dynamic()); - CHECK_EQ_OR_RETURN(weight_desc.shape(), Shape({input_desc.shape().At(1)})); + CHECK_EQ_OR_RETURN(weight_desc.is_dynamic(), is_dynamic) + << ctx->op_name() << ": expected the same dynamic with input and weight"; + CHECK_EQ_OR_RETURN(weight_desc.shape().elem_cnt(), C) + << ctx->op_name() << ": expected weight size " << C << ", got " + << weight_desc.shape().ToString(); } - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); - *dx_desc->mut_is_dynamic() = input_desc.is_dynamic(); - *dx_desc->mut_shape() = input_desc.shape(); + user_op::TensorDesc* output_desc = ctx->OutputTensorDesc("output", 0); + *output_desc->mut_is_dynamic() = is_dynamic; + *output_desc->mut_shape() = Shape({N}); - return Maybe::Ok(); -} - -Maybe InferGradDataType(user_op::InferContext* ctx) { - const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); - CHECK_OR_RETURN(IsIndexDataType(target_desc.data_type())); - - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + user_op::TensorDesc* out_weight_desc = ctx->OutputTensorDesc("out_weight", 0); + *out_weight_desc->mut_is_dynamic() = is_dynamic; + *out_weight_desc->mut_shape() = Shape({N}); return Maybe::Ok(); } -} // namespace -/* static */ Maybe NllOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - return InferTensorDescFn(ctx); -} - -/*static*/ Maybe NllOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { - return InferLogicalTensorDesc(ctx); -} +/* static */ Maybe NLLOp::GetSbp(user_op::SbpContext* ctx) { + // split batch dim + auto builder1 = ctx->NewBuilder() + .Split(user_op::OpArg("input", 0), 0) + .Split(user_op::OpArg("target", 0), 0) + .Split(user_op::OpArg("output", 0), 0) + .Split(user_op::OpArg("out_weight", 0), 0); + if (ctx->user_op_conf().has_input("weight", 0)) { + builder1.Broadcast(user_op::OpArg("weight", 0)); + } + builder1.Build(); + + // split class dim + const auto& shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); + auto builder2 = ctx->NewBuilder() + .Split(user_op::OpArg("input", 0), shape.NumAxes() - 1) + .Broadcast(user_op::OpArg("target", 0)) + .PartialSum(user_op::OpArg("output", 0)) + .PartialSum(user_op::OpArg("out_weight", 0)); + if (ctx->user_op_conf().has_input("weight", 0)) { + builder2.Split(user_op::OpArg("weight", 0), 0); + } + builder2.Build(); -/* static */ Maybe NllOp::GetSbp(user_op::SbpContext* ctx) { - return GenLossForwardDefaultGetSbpFn( - [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) { - builder.PartialSum(user_op::OpArg("total_weight", 0)); - })(ctx); + return Maybe::Ok(); } -/* static */ Maybe NllOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn, +/* static */ Maybe NLLOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); CHECK_OR_RETURN(target_modifier != nullptr); target_modifier->set_requires_grad(false); + if (conf.has_input("weight", 0)) { + auto* weight_modifier = GetInputArgModifierFn("weight", 0); + if (weight_modifier) { weight_modifier->set_requires_grad(false); } + } return Maybe::Ok(); } -/* static */ Maybe NllOp::InferDataType(user_op::InferContext* ctx) { - return NllInferDataType(ctx); -} +/* static */ Maybe NLLGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType("target", 0))) + << ctx->op_name() << ": expected target being integer type"; -/* static */ Maybe NllGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - return InferGradTensorDescFn(ctx); -} + auto input_dtype = ctx->InputDType("input", 0); + CHECK_EQ_OR_RETURN(ctx->InputDType("out_grad", 0), input_dtype) + << ctx->op_name() << ": expected out_grad dtype " << input_dtype << ", got " + << ctx->InputDType("out_grad", 0); + + if (ctx->has_input("weight", 0)) { + CHECK_EQ_OR_RETURN(ctx->InputDType("weight", 0), input_dtype) + << ctx->op_name() << ": expected weight dtype " << input_dtype << ", got " + << ctx->InputDType("weight", 0); + } + + *ctx->OutputDType("in_grad", 0) = input_dtype; -/*static*/ Maybe NllGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { - return InferLogicalTensorDesc(ctx); + return Maybe::Ok(); } -/* static */ Maybe NllGradOp::GetSbp(user_op::SbpContext* ctx) { - return GenLossBackwardDefaultGetSbpFn( - [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) { - builder.PartialSum(user_op::OpArg("total_weight", 0)); - })(ctx); +/* static */ Maybe NLLGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const auto& input_desc = ctx->InputTensorDesc("input", 0); + const auto& target_desc = ctx->InputTensorDesc("target", 0); + const auto& out_grad_desc = ctx->InputTensorDesc("out_grad", 0); + + bool is_dynamic = input_desc.is_dynamic(); + CHECK_EQ_OR_RETURN(target_desc.is_dynamic(), is_dynamic) + << ctx->op_name() << ": expected target dynamic " << is_dynamic; + CHECK_EQ_OR_RETURN(out_grad_desc.is_dynamic(), is_dynamic) + << ctx->op_name() << ": expected out_grad dynamic " << is_dynamic; + + const int64_t N = target_desc.shape().elem_cnt(); + CHECK_EQ_OR_RETURN(out_grad_desc.shape().elem_cnt(), N) + << ctx->op_name() << ": expected out_grad size " << N << ", got " + << out_grad_desc.shape().ToString(); + + const int64_t C = input_desc.shape().At(input_desc.shape().NumAxes() - 1); + CHECK_EQ_OR_RETURN(input_desc.shape().elem_cnt(), N * C) + << ctx->op_name() << ": expected input size " << N << ", got " + << input_desc.shape().ToString(); + + if (ctx->has_input("weight", 0)) { + const auto& weight_desc = ctx->InputTensorDesc("weight", 0); + CHECK_EQ_OR_RETURN(weight_desc.shape().elem_cnt(), C) + << ctx->op_name() << ": expected weight size " << C << ", got " + << weight_desc.shape().ToString(); + } + + user_op::TensorDesc* in_grad_desc = ctx->OutputTensorDesc("in_grad", 0); + *in_grad_desc->mut_is_dynamic() = is_dynamic; + *in_grad_desc->mut_shape() = input_desc.shape(); + + return Maybe::Ok(); } -/* static */ Maybe NllGradOp::InferDataType(user_op::InferContext* ctx) { - return InferGradDataType(ctx); +/* static */ Maybe NLLGradOp::GetSbp(user_op::SbpContext* ctx) { + // split batch dim + auto builder1 = ctx->NewBuilder() + .Split(user_op::OpArg("input", 0), 0) + .Split(user_op::OpArg("target", 0), 0) + .Split(user_op::OpArg("out_grad", 0), 0) + .Split(user_op::OpArg("in_grad", 0), 0); + if (ctx->user_op_conf().has_input("weight", 0)) { + builder1.Broadcast(user_op::OpArg("weight", 0)); + } + builder1.Build(); + + // split class dim + const auto& shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); + auto builder2 = ctx->NewBuilder() + .Split(user_op::OpArg("input", 0), shape.NumAxes() - 1) + .Broadcast(user_op::OpArg("target", 0)) + .Broadcast(user_op::OpArg("out_grad", 0)) + .Split(user_op::OpArg("in_grad", 0), shape.NumAxes() - 1); + if (ctx->user_op_conf().has_input("weight", 0)) { + builder2.Split(user_op::OpArg("weight", 0), 0); + } + builder2.Build(); + + return Maybe::Ok(); } REGISTER_USER_OP_GRAD("nll").SetGenBackwardOpConfFn( @@ -142,15 +200,14 @@ REGISTER_USER_OP_GRAD("nll").SetGenBackwardOpConfFn( builder.Op("nll_grad") .Input("input", op.input("input", 0)) .Input("target", op.input("target", 0)) - .Input("total_weight", op.output("total_weight", 0)) - .Input("dy", op.GetGradTensorWithOpOutput("out", 0)) - .Output("dx") + .Input("out_grad", op.GetGradTensorWithOpOutput("output", 0)) + .Output("in_grad") .Attr("ignore_index", op.attr("ignore_index")); if (op.user_op_conf().has_input("weight", 0)) { builder.Input("weight", op.input("weight", 0)); } - user_op::UserOpConfWrapper grad_op = builder.Build(); - op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "input", 0); + auto grad_op = builder.Build(); + op.BindGradTensorWithOpInput(grad_op.output("in_grad", 0), "input", 0); AddOp(grad_op); } return Maybe::Ok(); diff --git a/python/oneflow/nn/modules/loss.py b/python/oneflow/nn/modules/loss.py index 1a0310b3f78..a03087cf8fb 100644 --- a/python/oneflow/nn/modules/loss.py +++ b/python/oneflow/nn/modules/loss.py @@ -33,7 +33,7 @@ def __init__( self, weight: Optional[Tensor] = None, reduction: str = "mean" ) -> None: super(_WeightedLoss, self).__init__(reduction=reduction) - self.weight = weight + self.register_buffer("weight", weight) class L1Loss(_Loss): diff --git a/python/oneflow/test/modules/test_nll_loss.py b/python/oneflow/test/modules/test_nll_loss.py new file mode 100644 index 00000000000..301c3bc901a --- /dev/null +++ b/python/oneflow/test/modules/test_nll_loss.py @@ -0,0 +1,134 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import numpy as np +import unittest + +import oneflow as flow +import oneflow.unittest + +from oneflow.test_utils.automated_test_util import * + + +@autotest(n=1) +def _test_nll_loss( + test_case, has_weight=False, split_batch_dim=False, split_class_dim=False +): + N = random(1, 4) * 2 + C = random(1, 10) * 2 + ndim = random(2, 5).to(int).value() + dims = [random(2, 10) for i in range(ndim - 2)] + input_dims = [N, C] + dims + target_dims = [N] + dims + input = random_tensor(ndim, *input_dims) + target = random_tensor( + ndim - 1, *target_dims, low=0, high=C, dtype=int, requires_grad=False + ) + weight = None + if has_weight: + weight = random_tensor(1, C, requires_grad=False) + + device = random_device().value() + if not split_class_dim and not split_batch_dim: + input = input.to(device) + target = target.to(device) + if has_weight: + weight = weight.to(device) + else: + rank = flow.env.get_rank() + world_size = flow.env.get_world_size() + assert world_size % 2 == 0 + ranks = np.array(range(world_size)) + + if split_batch_dim and split_class_dim: + placement = flow.placement(device, ranks.reshape((ranks.size // 2, 2))) + input_sbp = [flow.sbp.split(0), flow.sbp.split(1)] + target_sbp = [flow.sbp.split(0), flow.sbp.broadcast()] + weight_sbp = [flow.sbp.broadcast(), flow.sbp.split(0)] + elif split_batch_dim: + placement = flow.placement(device, ranks) + input_sbp = flow.sbp.split(0) + target_sbp = flow.sbp.split(0) + weight_sbp = flow.sbp.broadcast() + else: + placement = flow.placement(device, ranks) + input_sbp = flow.sbp.split(1) + target_sbp = flow.sbp.broadcast() + weight_sbp = flow.sbp.split(0) + + input = input.to_global(placement=placement, sbp=input_sbp) + target = target.to_global(placement=placement, sbp=target_sbp) + # print( + # f"**[{rank}] input: {input.oneflow.shape} {input.oneflow.placement} {input.oneflow.sbp}" + # ) + # print( + # f"**[{rank}] target: {target.oneflow.shape} {target.oneflow.placement} {target.oneflow.sbp}" + # ) + if has_weight: + # print(f"**[{rank}] weight: {weight.oneflow.numpy()}") + weight = weight.to_global(placement=placement, sbp=weight_sbp) + + reduction = oneof("none", "sum", "mean") + if has_weight: + nll = torch.nn.NLLLoss(weight=weight, reduction=reduction) + else: + nll = torch.nn.NLLLoss(reduction=reduction) + return nll(input, target) + + +@flow.unittest.skip_unless_1n1d() +class NLLLossTestCase(flow.unittest.TestCase): + def test_local(test_case): + _test_nll_loss(test_case) + + def test_weighted(test_case): + _test_nll_loss(test_case, has_weight=True) + + +@flow.unittest.skip_unless_1n2d() +class ParallelNLLLossTestCase(flow.unittest.TestCase): + @globaltest + def test_data_parallel(test_case): + _test_nll_loss(test_case, split_batch_dim=True) + + @globaltest + def test_data_parallel_weighted(test_case): + _test_nll_loss(test_case, has_weight=True, split_batch_dim=True) + + @globaltest + def test_model_parallel(test_case): + _test_nll_loss(test_case, split_class_dim=True) + + @globaltest + def test_model_parallel_weighted(test_case): + _test_nll_loss(test_case, has_weight=True, split_class_dim=True) + + +@flow.unittest.skip_unless_1n4d() +class TowDParallelNLLLossTestCase(flow.unittest.TestCase): + @globaltest + def test_2d_parallel(test_case): + _test_nll_loss(test_case, split_batch_dim=True, split_class_dim=True) + + @globaltest + def test_2d_parallel_weighted(test_case): + _test_nll_loss( + test_case, has_weight=True, split_batch_dim=True, split_class_dim=True + ) + + +if __name__ == "__main__": + unittest.main()