diff --git a/oneflow/core/autograd/gradient_funcs/logical_slice.cpp b/oneflow/core/autograd/gradient_funcs/logical_slice.cpp
new file mode 100644
index 00000000000..ccc06f1cc77
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/logical_slice.cpp
@@ -0,0 +1,150 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/framework/op_builder.h"
+#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
+#include "oneflow/core/framework/op_expr.h"
+#include "oneflow/core/functional/functional.h"
+
+namespace oneflow {
+namespace one {
+
+struct LogicalSliceCaptureState : public AutoGradCaptureState {
+  Shape like_shape;
+  std::vector<int64_t> start;
+  std::vector<int64_t> stop;
+  std::vector<int64_t> step;
+  Symbol<NdSbp> in_sbp;
+};
+
+class LogicalSlice : public OpExprGradFunction<LogicalSliceCaptureState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+    CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "LogicalSlice op_expr is null";
+    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Capture(LogicalSliceCaptureState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1) << "LogicalSlice input size must be 1";
+    CHECK_EQ_OR_RETURN(outputs.size(), 1) << "LogicalSlice output size must be 1";
+
+    ComposedAttrMap composed_attrs(attrs, base_attrs_);
+    ctx->start = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("start"));
+    ctx->stop = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("stop"));
+    ctx->step = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("step"));
+    ctx->like_shape = *(inputs[0]->shape());
+    ctx->in_sbp = JUST(inputs[0]->nd_sbp());
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const LogicalSliceCaptureState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    in_grads->resize(1);
+    std::shared_ptr<Tensor> zeros;
+    if (out_grads[0]->is_local()) {
+      zeros = JUST(functional::Constant(ctx->like_shape, 0, out_grads[0]->dtype(),
+                                        JUST(out_grads[0]->device())));
+    } else {
+      const auto& parallel_desc = JUST(out_grads[0]->parallel_desc());
+      zeros = JUST(functional::ConsistentConstant(ctx->like_shape, 0, out_grads[0]->dtype(),
+                                                  parallel_desc, *JUST(GetSbpList(ctx->in_sbp))));
+    }
+    (*in_grads)[0] =
+        JUST(functional::LogicalSliceAssign(zeros, out_grads[0], ctx->start, ctx->stop, ctx->step));
+    return Maybe<void>::Ok();
+  }
+
+ private:
+  AttrMap base_attrs_;
+};
+
+struct LogicalSliceAssignCaptureState : public AutoGradCaptureState {
+  bool requires_grad_ref = false;
+  bool requires_grad_value = false;
+  std::vector<int64_t> start;
+  std::vector<int64_t> stop;
+  std::vector<int64_t> step;
+  Shape value_shape;  // used to calculate ref gradient
+  Symbol<NdSbp> value_sbp;
+};
+
+class LogicalSliceAssign : public OpExprGradFunction<LogicalSliceAssignCaptureState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+    CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "LogicalSliceAssign op_expr is null";
+
+    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Capture(LogicalSliceAssignCaptureState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 2) << "LogicalSliceAssign input size must be 2";
+    CHECK_EQ_OR_RETURN(outputs.size(), 1) << "LogicalSliceAssign output size must be 1";
+    ctx->requires_grad_ref = inputs[0]->requires_grad();
+    ctx->requires_grad_value = inputs[1]->requires_grad();
+    if (!ctx->requires_grad_ref && !ctx->requires_grad_value) { return Maybe<void>::Ok(); }
+
+    ComposedAttrMap composed_attrs(attrs, base_attrs_);
+    ctx->start = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("start"));
+    ctx->stop = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("stop"));
+    ctx->step = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("step"));
+
+    if (ctx->requires_grad_ref) {
+      ctx->value_shape = *(inputs[1]->shape());
+      ctx->value_sbp = JUST(inputs[1]->nd_sbp());
+    }
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const LogicalSliceAssignCaptureState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    in_grads->resize(2);
+
+    if (ctx->requires_grad_ref) {
+      std::shared_ptr<Tensor> zeros;
+      if (out_grads[0]->is_local()) {
+        zeros = JUST(functional::Constant(ctx->value_shape, 0, out_grads[0]->dtype(),
+                                          JUST(out_grads[0]->device())));
+      } else {
+        const auto& parallel_desc = JUST(out_grads[0]->parallel_desc());
+        zeros =
+            JUST(functional::ConsistentConstant(ctx->value_shape, 0, out_grads[0]->dtype(),
+                                                parallel_desc, *JUST(GetSbpList(ctx->value_sbp))));
+      }
+      (*in_grads)[0] = JUST(functional::LogicalSliceAssign(
+          JUST(functional::Identity(out_grads[0])), zeros, ctx->start, ctx->stop, ctx->step));
+    }
+    if (ctx->requires_grad_value) {
+      (*in_grads)[1] = JUST(functional::LogicalSlice(out_grads[0], ctx->start, ctx->stop, ctx->step,
+                                                     /*enable_view_slice=*/false));
+    }
+    return Maybe<void>::Ok();
+  }
+
+ private:
+  AttrMap base_attrs_;
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("logical_slice_assign", LogicalSliceAssign);
+REGISTER_OP_EXPR_GRAD_FUNCTION("logical_slice", LogicalSlice);
+
+}  // namespace one
+}  // namespace oneflow
diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp
index 9a64d5637e3..8c2f90ade5d 100644
--- a/oneflow/core/functional/impl/array_functor.cpp
+++ b/oneflow/core/functional/impl/array_functor.cpp
@@ -2143,11 +2143,6 @@ class TensorSetItemFunctor {
         JUST(LogicalSliceAssign(x, value_tensor, start, end, step));
       }
     } else {
-      if (requires_grad) {
-        return Error::RuntimeError() << "Backward is not support for consistent tensor setitem,"
-                                        "please use oneflow.no_grad() to disable autograd "
-                                        "currently. We will fix this problem soon.";
-      }
       JUST(LogicalSliceAssign(x, value_tensor, start, end, step));
     }
   }
diff --git a/oneflow/user/ops/slice_op.cpp b/oneflow/user/ops/slice_op.cpp
index efda24c0e63..fac6d0ed57c 100644
--- a/oneflow/user/ops/slice_op.cpp
+++ b/oneflow/user/ops/slice_op.cpp
@@ -367,9 +367,76 @@ Maybe<void> GenSliceUpdateGradOp(user_op::BackwardOpConfContext* ctx) {
   return Maybe<void>::Ok();
 }
 
+Maybe<void> GenLogicalSliceAssignGradOp(user_op::BackwardOpConfContext* ctx) {
+  const std::string update_grad_op_name = ctx->FwOp().op_name() + "_value_grad";
+  ctx->DefineOp(update_grad_op_name, [&](user_op::BackwardOpBuilder& builder) {
+    return builder.OpTypeName("logical_slice")
+        .InputBind("x", ctx->FwOp().output_grad("y", 0))
+        .Attr("start", ctx->FwOp().attr<std::vector<int64_t>>("start"))
+        .Attr("stop", ctx->FwOp().attr<std::vector<int64_t>>("stop"))
+        .Attr("step", ctx->FwOp().attr<std::vector<int64_t>>("step"))
+        .Output("y")
+        .Build();
+  });
+  ctx->FwOp().InputGradBind(user_op::OpArg("value", 0), [&]() -> const std::string& {
+    return ctx->GetOp(update_grad_op_name).output("y", 0);
+  });
+
+  const std::string zero_grad_op_name = ctx->FwOp().op_name() + "_zero_grad";
+  ctx->DefineOp(zero_grad_op_name, [&](user_op::BackwardOpBuilder& builder) {
+    return builder.OpTypeName("zero_like")
+        .InputBind("like", ctx->FwOp().input("value", 0))
+        .Output("out")
+        .Build();
+  });
+  const std::string x_grad_op_name = ctx->FwOp().op_name() + "_x_grad";
+  ctx->DefineOp(x_grad_op_name, [&](user_op::BackwardOpBuilder& builder) {
+    return builder.OpTypeName("logical_slice_assign")
+        .InputBind("ref", ctx->FwOp().output_grad("y", 0))
+        .InputBind("value", ctx->GetOp(zero_grad_op_name).output("out", 0))
+        .Attr("start", ctx->FwOp().attr<std::vector<int64_t>>("start"))
+        .Attr("stop", ctx->FwOp().attr<std::vector<int64_t>>("stop"))
+        .Attr("step", ctx->FwOp().attr<std::vector<int64_t>>("step"))
+        .Output("y")
+        .Build();
+  });
+  ctx->FwOp().InputGradBind(user_op::OpArg("ref", 0), [&]() -> const std::string& {
+    return ctx->GetOp(x_grad_op_name).output("y", 0);
+  });
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> GenLogicalSliceGradOp(user_op::BackwardOpConfContext* ctx) {
+  const std::string zero_grad_op_name = ctx->FwOp().op_name() + "_zero_grad";
+  ctx->DefineOp(zero_grad_op_name, [&](user_op::BackwardOpBuilder& builder) {
+    return builder.OpTypeName("zero_like")
+        .InputBind("like", ctx->FwOp().input("x", 0))
+        .Output("out")
+        .Build();
+  });
+  const std::string x_grad_op_name = ctx->FwOp().op_name() + "_x_grad";
+  ctx->DefineOp(x_grad_op_name, [&](user_op::BackwardOpBuilder& builder) {
+    return builder.OpTypeName("logical_slice_assign")
+        .InputBind("ref", ctx->GetOp(zero_grad_op_name).output("out", 0))
+        .InputBind("value", ctx->FwOp().output_grad("y", 0))
+        .Attr("start", ctx->FwOp().attr<std::vector<int64_t>>("start"))
+        .Attr("stop", ctx->FwOp().attr<std::vector<int64_t>>("stop"))
+        .Attr("step", ctx->FwOp().attr<std::vector<int64_t>>("step"))
+        .Output("y")
+        .Build();
+  });
+  ctx->FwOp().InputGradBind(user_op::OpArg("x", 0), [&]() -> const std::string& {
+    return ctx->GetOp(x_grad_op_name).output("y", 0);
+  });
+
+  return Maybe<void>::Ok();
+}
+
 }  // namespace
 
 REGISTER_USER_OP_GRAD("slice").SetGenBackwardOpConfFn(GenSliceGradOp);
 REGISTER_USER_OP_GRAD("slice_update").SetBackwardOpConfGenFn(GenSliceUpdateGradOp);
+REGISTER_USER_OP_GRAD("logical_slice_assign").SetBackwardOpConfGenFn(GenLogicalSliceAssignGradOp);
+REGISTER_USER_OP_GRAD("logical_slice").SetBackwardOpConfGenFn(GenLogicalSliceGradOp);
 
 }  // namespace oneflow
diff --git a/python/oneflow/test/modules/test_consistent_slice.py b/python/oneflow/test/modules/test_consistent_slice.py
index c310cd667be..cc39410d200 100644
--- a/python/oneflow/test/modules/test_consistent_slice.py
+++ b/python/oneflow/test/modules/test_consistent_slice.py
@@ -90,14 +90,21 @@ def _test_slice_ellipsis_type(test_case, placement, sbp):
 
 
 def _test_logical_slice(test_case, placement, sbp):
-    x = random_tensor(2, 8, 8, requires_grad=False).oneflow
-    x_numpy = x.detach().cpu().numpy()
+    input = random_tensor(2, 8, 8, requires_grad=True).oneflow
+    x_numpy = input.detach().cpu().numpy()
 
-    x = x.to_global(placement=placement, sbp=sbp)
+    x = input.to_global(placement=placement, sbp=sbp)
     y = flow.logical_slice(x, slice_tup_list=[[0, 1, 1]])
 
+    # forward
     test_case.assertTrue(np.array_equal(y.numpy(), x_numpy[0:1:1]))
 
+    # backward
+    y.sum().backward()
+    input_grad_np = np.zeros((8, 8))
+    input_grad_np[0:1:1, :] = 1
+    test_case.assertTrue(np.array_equal(input.grad.numpy(), input_grad_np))
+
 
 def _test_logical_slice_with_bool(test_case, placement, sbp):
     x = random_tensor(2, 8, 8).oneflow > 0.5
@@ -109,6 +116,53 @@ def _test_logical_slice_with_bool(test_case, placement, sbp):
     test_case.assertTrue(np.array_equal(y.numpy(), x_numpy[0:1:1]))
 
 
+def _test_logical_slice_with_grad(test_case, placement, sbp):
+    x = random_tensor(2, 4, 4, requires_grad=True).oneflow
+    x_numpy = x.detach().cpu().numpy()
+
+    class LogicalSliceWithGrad(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.input_grad = flow.nn.Parameter(flow.zeros(4, 4))
+
+        def forward(self, input):
+            x = input + self.input_grad
+            x = x.to_global(placement, sbp)
+            return x[:, :2]
+
+    logical_slice_with_grad = LogicalSliceWithGrad().to_global(
+        placement, [flow.sbp.broadcast,] * len(sbp)
+    )
+
+    of_sgd = flow.optim.SGD(logical_slice_with_grad.parameters(), lr=1.0, momentum=0.0)
+
+    class LogicalSliceTrainGraph(flow.nn.Graph):
+        def __init__(self):
+            super().__init__()
+            self.module = logical_slice_with_grad
+            self.add_optimizer(of_sgd)
+
+        def build(self, x):
+            out = self.module(x)
+            z = out.sum()
+            z.backward()
+            return out
+
+    graph = LogicalSliceTrainGraph()
+
+    input = x.to_global(placement=placement, sbp=sbp)
+    y = graph(input)
+
+    # output
+    test_case.assertTrue(np.array_equal(y.numpy(), x_numpy[:, :2]))
+    # input_grad
+    x_grad_np = np.zeros((4, 4))
+    x_grad_np[:, :2] = 1
+    test_case.assertTrue(
+        np.array_equal(-graph.module.input_grad.origin.numpy(), x_grad_np)
+    )
+
+
 class TestSlice(flow.unittest.TestCase):
     @globaltest
     def test_slice(test_case):
@@ -128,6 +182,7 @@ def test_logical_slice(test_case):
             for sbp in all_sbp(placement, max_dim=2):
                 _test_logical_slice(test_case, placement, sbp)
                 _test_logical_slice_with_bool(test_case, placement, sbp)
+                _test_logical_slice_with_grad(test_case, placement, sbp)
 
 
 if __name__ == "__main__":
diff --git a/python/oneflow/test/modules/test_consistent_slice_assign.py b/python/oneflow/test/modules/test_consistent_slice_assign.py
index 2459b2f560d..b2088b6bdd2 100644
--- a/python/oneflow/test/modules/test_consistent_slice_assign.py
+++ b/python/oneflow/test/modules/test_consistent_slice_assign.py
@@ -23,44 +23,86 @@
 
 
 def _test_logical_slice_assign(test_case, placement, sbp):
-    x = random_tensor(2, 4, 4, requires_grad=False).oneflow
-    x_numpy = x.detach().cpu().numpy()
+    input = random_tensor(2, 4, 4, requires_grad=True).oneflow
+    x_numpy = input.detach().cpu().numpy()
 
-    x = x.to_global(placement=placement, sbp=sbp)
+    x = (input + 0).to_global(
+        placement=placement, sbp=sbp
+    )  # add 0 to change to non-leaf tensor
     x[:, :2] = 3
-    x_numpy[:, :2] = 3
 
+    # forward
+    x_numpy[:, :2] = 3
     test_case.assertTrue(x.sbp == sbp)
     test_case.assertTrue(np.array_equal(x.numpy(), x_numpy))
 
+    # backward
+    x.sum().backward()
+    input_grad_np = np.ones((4, 4))
+    input_grad_np[:, :2] = 0
+    test_case.assertTrue(np.array_equal(input.grad.numpy(), input_grad_np))
+
 
 def _test_graph_logical_slice_assign(test_case, placement, sbp):
-    x = random_tensor(2, 4, 4, requires_grad=False).oneflow
+    x = random_tensor(2, 4, 4, requires_grad=True).oneflow
     x_numpy = x.detach().cpu().numpy()
 
-    @flow.nn.Graph.to_graph
-    def test_func(x):
-        x[:, :2] = 3
-        return x
+    class LogicalSliceAssignWithGrad(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.input_grad = flow.nn.Parameter(flow.zeros(4, 4))
 
-    x = x.to_global(placement=placement, sbp=sbp)
+        def forward(self, input):
+            x = input + self.input_grad
+            x = x.to_global(placement, sbp)
+            x[:, :2] = 3
+            return x
 
-    y = test_func(x)
+    logical_slice_assign_with_grad = LogicalSliceAssignWithGrad().to_global(
+        placement, [flow.sbp.broadcast,] * len(sbp)
+    )
 
-    x_numpy[:, :2] = 3
+    of_sgd = flow.optim.SGD(
+        logical_slice_assign_with_grad.parameters(), lr=1.0, momentum=0.0
+    )
+
+    class LogicalSliceAssignTrainGraph(flow.nn.Graph):
+        def __init__(self):
+            super().__init__()
+            self.module = logical_slice_assign_with_grad
+            self.add_optimizer(of_sgd)
+
+        def build(self, x):
+            out = self.module(x)
+            z = out.sum()
+            z.backward()
+            return out
+
+    graph = LogicalSliceAssignTrainGraph()
+
+    input = x.to_global(placement=placement, sbp=sbp)
+    y = graph(input)
 
     test_case.assertTrue(y.sbp == sbp)
+
+    # output
+    x_numpy[:, :2] = 3
     test_case.assertTrue(np.array_equal(y.numpy(), x_numpy))
+    # input_grad
+    x_grad_np = np.ones((4, 4))
+    x_grad_np[:, :2] = 0
+    test_case.assertTrue(
+        np.array_equal(-graph.module.input_grad.origin.numpy(), x_grad_np)
+    )
 
 
 class TestGlobalLogicalSliceAssign(flow.unittest.TestCase):
     @globaltest
     def test_logical_slice_assign(test_case):
         for placement in all_placement():
-            for sbp in all_sbp(placement, max_dim=2, except_split=True):
+            for sbp in all_sbp(placement, max_dim=2):
                 if placement.ranks.size == 1:
                     continue
-                # logical slice assign only support broadcast and partial_sum currently
                 _test_logical_slice_assign(test_case, placement, sbp)
                 _test_graph_logical_slice_assign(test_case, placement, sbp)
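
Not part of the patch: a minimal eager-mode sketch of what the new `logical_slice` gradient enables, mirroring `_test_logical_slice` above. The single-rank CPU placement and the input values are illustrative assumptions, not something this diff prescribes.

```python
import numpy as np
import oneflow as flow

# Illustrative single-rank placement; the tests above sweep many placement/sbp combinations.
placement = flow.placement("cpu", ranks=[0])
sbp = flow.sbp.broadcast

x = flow.tensor(np.random.rand(8, 8), dtype=flow.float32, requires_grad=True)
x_global = x.to_global(placement=placement, sbp=sbp)

# logical_slice now has a registered backward: the input gradient is a zero tensor of the
# input's shape with the output gradient scattered back into the sliced region
# (implemented via logical_slice_assign in LogicalSlice::Apply above).
y = flow.logical_slice(x_global, slice_tup_list=[[0, 1, 1]])
y.sum().backward()

expected_grad = np.zeros((8, 8))
expected_grad[0:1:1, :] = 1
assert np.array_equal(x.grad.numpy(), expected_grad)
```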