Feat logical slice ops grad (#8337)
* feat(SliceOp): slice ops support 2d sbp

* fix(SliceOp): fix [B, P] 2d sbp bug

* refine error message

* fix bug in parallel_num == 1

* add comment

* add warning and format

* add NOLINT for boxing check

* feat(LogicalSliceOps): support all nd_sbp

* feat(LogicalSlice): support nd_sbp

* add error message

* fix(AutoTest): fix auto_test bug in module.parameter pass

* auto format by CI

* fix(LogicalSliceAssign): skip test when 1n1d

* fix SliceParams memset error

* remove memset

* add CHECK_JUST

* fix(*): make sure split_axis >= 0 or equal to SPLIT_AXIS_FOR_NON_SPLIT

* remove memset

* fix split_info.axis bug

* feat(LogicalSliceOps): support grad

* add logical_slice gradient_funcs

* modify as clang-tidy

* LogicalSlice ops grad use input nd_sbp

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Houjiang Chen <chenhoujiangcug@gmail.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
4 people authored Jun 6, 2022
1 parent 45cfcb5 commit 83ed0ba
Showing 5 changed files with 331 additions and 22 deletions.
150 changes: 150 additions & 0 deletions oneflow/core/autograd/gradient_funcs/logical_slice.cpp
@@ -0,0 +1,150 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

struct LogicalSliceCaptureState : public AutoGradCaptureState {
  Shape like_shape;
  std::vector<int64_t> start;
  std::vector<int64_t> stop;
  std::vector<int64_t> step;
  Symbol<NdSbp> in_sbp;
};

class LogicalSlice : public OpExprGradFunction<LogicalSliceCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "LogicalSlice op_expr is null";
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(LogicalSliceCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1) << "LogicalSlice input size must be 1";
    CHECK_EQ_OR_RETURN(outputs.size(), 1) << "LogicalSlice output size must be 1";

    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->start = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("start"));
    ctx->stop = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("stop"));
    ctx->step = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("step"));
    ctx->like_shape = *(inputs[0]->shape());
    ctx->in_sbp = JUST(inputs[0]->nd_sbp());
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const LogicalSliceCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(1);
    std::shared_ptr<Tensor> zeros;
    if (out_grads[0]->is_local()) {
      zeros = JUST(functional::Constant(ctx->like_shape, 0, out_grads[0]->dtype(),
                                        JUST(out_grads[0]->device())));
    } else {
      const auto& parallel_desc = JUST(out_grads[0]->parallel_desc());
      zeros = JUST(functional::ConsistentConstant(ctx->like_shape, 0, out_grads[0]->dtype(),
                                                  parallel_desc, *JUST(GetSbpList(ctx->in_sbp))));
    }
    (*in_grads)[0] =
        JUST(functional::LogicalSliceAssign(zeros, out_grads[0], ctx->start, ctx->stop, ctx->step));
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

struct LogicalSliceAssignCaptureState : public AutoGradCaptureState {
  bool requires_grad_ref = false;
  bool requires_grad_value = false;
  std::vector<int64_t> start;
  std::vector<int64_t> stop;
  std::vector<int64_t> step;
  Shape value_shape;  // used to calculate ref gradient
  Symbol<NdSbp> value_sbp;
};

class LogicalSliceAssign : public OpExprGradFunction<LogicalSliceAssignCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "LogicalSliceAssign op_expr is null";

    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(LogicalSliceAssignCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 2) << "LogicalSliceAssign input size must be 2";
    CHECK_EQ_OR_RETURN(outputs.size(), 1) << "LogicalSliceAssign output size must be 1";
    ctx->requires_grad_ref = inputs[0]->requires_grad();
    ctx->requires_grad_value = inputs[1]->requires_grad();
    if (!ctx->requires_grad_ref && !ctx->requires_grad_value) { return Maybe<void>::Ok(); }

    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->start = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("start"));
    ctx->stop = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("stop"));
    ctx->step = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("step"));

    if (ctx->requires_grad_ref) {
      ctx->value_shape = *(inputs[1]->shape());
      ctx->value_sbp = JUST(inputs[1]->nd_sbp());
    }
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const LogicalSliceAssignCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    in_grads->resize(2);

    if (ctx->requires_grad_ref) {
      std::shared_ptr<Tensor> zeros;
      if (out_grads[0]->is_local()) {
        zeros = JUST(functional::Constant(ctx->value_shape, 0, out_grads[0]->dtype(),
                                          JUST(out_grads[0]->device())));
      } else {
        const auto& parallel_desc = JUST(out_grads[0]->parallel_desc());
        zeros =
            JUST(functional::ConsistentConstant(ctx->value_shape, 0, out_grads[0]->dtype(),
                                                parallel_desc, *JUST(GetSbpList(ctx->value_sbp))));
      }
      (*in_grads)[0] = JUST(functional::LogicalSliceAssign(
          JUST(functional::Identity(out_grads[0])), zeros, ctx->start, ctx->stop, ctx->step));
    }
    if (ctx->requires_grad_value) {
      (*in_grads)[1] = JUST(functional::LogicalSlice(out_grads[0], ctx->start, ctx->stop, ctx->step,
                                                     /*enable_view_slice=*/false));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("logical_slice_assign", LogicalSliceAssign);
REGISTER_OP_EXPR_GRAD_FUNCTION("logical_slice", LogicalSlice);

}  // namespace one
}  // namespace oneflow
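
For intuition, the gradient rule that LogicalSlice::Apply implements can be sketched in a few lines of NumPy (an illustrative sketch, not the OneFlow API; the shapes and slice bounds below are made up for the example): the input gradient is a zero tensor with the input's shape, and the incoming output gradient is assigned into the sliced region.

import numpy as np

def logical_slice_grad(dy, like_shape, start, stop, step):
    # Mirrors LogicalSlice::Apply: zeros shaped like the input, then assign dy into the sliced region.
    dx = np.zeros(like_shape, dtype=dy.dtype)
    region = tuple(slice(b, e, s) for b, e, s in zip(start, stop, step))
    dx[region] = dy
    return dx

# y = x[0:1] on an (8, 8) input: dL/dx equals dy in row 0 and zero elsewhere.
dx = logical_slice_grad(np.ones((1, 8)), like_shape=(8, 8), start=[0, 0], stop=[1, 8], step=[1, 1])
assert dx[0].sum() == 8 and dx[1:].sum() == 0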
5 changes: 0 additions & 5 deletions oneflow/core/functional/impl/array_functor.cpp
@@ -2143,11 +2143,6 @@ class TensorSetItemFunctor {
        JUST(LogicalSliceAssign(x, value_tensor, start, end, step));
      }
    } else {
      if (requires_grad) {
        return Error::RuntimeError() << "Backward is not support for consistent tensor setitem,"
                                        "please use oneflow.no_grad() to disable autograd "
                                        "currently. We will fix this problem soon.";
      }
      JUST(LogicalSliceAssign(x, value_tensor, start, end, step));
    }
  }
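
With this guard removed, setitem on a global (consistent) tensor no longer needs to be wrapped in oneflow.no_grad(); the new logical_slice / logical_slice_assign gradients handle the backward pass. A rough usage sketch under assumed conditions (a single-rank CPU placement chosen for illustration; the real tests sweep placements and sbp):

import oneflow as flow

# Assumption: one CPU rank; placement/sbp here are illustrative, not taken from this diff.
placement = flow.placement("cpu", ranks=[0])
x = flow.ones(4, 4, requires_grad=True)           # local leaf tensor
y = x.to_global(placement, [flow.sbp.broadcast])  # global view, still tracked by autograd
z = y + 0                                         # non-leaf copy so in-place setitem is allowed
z[:, :2] = 0.0                                    # previously raised RuntimeError under autograd
z.sum().backward()
# x.grad is expected to be 0 in the assigned columns and 1 elsewhere.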
67 changes: 67 additions & 0 deletions oneflow/user/ops/slice_op.cpp
@@ -367,9 +367,76 @@ Maybe<void> GenSliceUpdateGradOp(user_op::BackwardOpConfContext* ctx) {
  return Maybe<void>::Ok();
}

Maybe<void> GenLogicalSliceAssignGradOp(user_op::BackwardOpConfContext* ctx) {
  const std::string update_grad_op_name = ctx->FwOp().op_name() + "_value_grad";
  ctx->DefineOp(update_grad_op_name, [&](user_op::BackwardOpBuilder& builder) {
    return builder.OpTypeName("logical_slice")
        .InputBind("x", ctx->FwOp().output_grad("y", 0))
        .Attr("start", ctx->FwOp().attr<std::vector<int64_t>>("start"))
        .Attr("stop", ctx->FwOp().attr<std::vector<int64_t>>("stop"))
        .Attr("step", ctx->FwOp().attr<std::vector<int64_t>>("step"))
        .Output("y")
        .Build();
  });
  ctx->FwOp().InputGradBind(user_op::OpArg("value", 0), [&]() -> const std::string& {
    return ctx->GetOp(update_grad_op_name).output("y", 0);
  });

  const std::string zero_grad_op_name = ctx->FwOp().op_name() + "_zero_grad";
  ctx->DefineOp(zero_grad_op_name, [&](user_op::BackwardOpBuilder& builder) {
    return builder.OpTypeName("zero_like")
        .InputBind("like", ctx->FwOp().input("value", 0))
        .Output("out")
        .Build();
  });
  const std::string x_grad_op_name = ctx->FwOp().op_name() + "_x_grad";
  ctx->DefineOp(x_grad_op_name, [&](user_op::BackwardOpBuilder& builder) {
    return builder.OpTypeName("logical_slice_assign")
        .InputBind("ref", ctx->FwOp().output_grad("y", 0))
        .InputBind("value", ctx->GetOp(zero_grad_op_name).output("out", 0))
        .Attr("start", ctx->FwOp().attr<std::vector<int64_t>>("start"))
        .Attr("stop", ctx->FwOp().attr<std::vector<int64_t>>("stop"))
        .Attr("step", ctx->FwOp().attr<std::vector<int64_t>>("step"))
        .Output("y")
        .Build();
  });
  ctx->FwOp().InputGradBind(user_op::OpArg("ref", 0), [&]() -> const std::string& {
    return ctx->GetOp(x_grad_op_name).output("y", 0);
  });
  return Maybe<void>::Ok();
}

Maybe<void> GenLogicalSliceGradOp(user_op::BackwardOpConfContext* ctx) {
  const std::string zero_grad_op_name = ctx->FwOp().op_name() + "_zero_grad";
  ctx->DefineOp(zero_grad_op_name, [&](user_op::BackwardOpBuilder& builder) {
    return builder.OpTypeName("zero_like")
        .InputBind("like", ctx->FwOp().input("x", 0))
        .Output("out")
        .Build();
  });
  const std::string x_grad_op_name = ctx->FwOp().op_name() + "_x_grad";
  ctx->DefineOp(x_grad_op_name, [&](user_op::BackwardOpBuilder& builder) {
    return builder.OpTypeName("logical_slice_assign")
        .InputBind("ref", ctx->GetOp(zero_grad_op_name).output("out", 0))
        .InputBind("value", ctx->FwOp().output_grad("y", 0))
        .Attr("start", ctx->FwOp().attr<std::vector<int64_t>>("start"))
        .Attr("stop", ctx->FwOp().attr<std::vector<int64_t>>("stop"))
        .Attr("step", ctx->FwOp().attr<std::vector<int64_t>>("step"))
        .Output("y")
        .Build();
  });
  ctx->FwOp().InputGradBind(user_op::OpArg("x", 0), [&]() -> const std::string& {
    return ctx->GetOp(x_grad_op_name).output("y", 0);
  });

  return Maybe<void>::Ok();
}

} // namespace

REGISTER_USER_OP_GRAD("slice").SetGenBackwardOpConfFn(GenSliceGradOp);
REGISTER_USER_OP_GRAD("slice_update").SetBackwardOpConfGenFn(GenSliceUpdateGradOp);
REGISTER_USER_OP_GRAD("logical_slice_assign").SetBackwardOpConfGenFn(GenLogicalSliceAssignGradOp);
REGISTER_USER_OP_GRAD("logical_slice").SetBackwardOpConfGenFn(GenLogicalSliceGradOp);

} // namespace oneflow
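
The graph-mode backward registered here composes the same primitives. Roughly, and again as a NumPy sketch rather than the actual user-op graph (the helper name below is illustrative): the value gradient of logical_slice_assign is the logical_slice of dy over the assigned region, and the ref gradient is dy with zero_like(value) assigned into that region.

import numpy as np

def logical_slice_assign_grads(dy, value_shape, start, stop, step):
    # Mirrors GenLogicalSliceAssignGradOp: slice dy for the value grad,
    # and overwrite the assigned region of dy with zeros for the ref grad.
    region = tuple(slice(b, e, s) for b, e, s in zip(start, stop, step))
    dvalue = dy[region]
    dref = dy.copy()
    dref[region] = np.zeros(value_shape, dtype=dy.dtype)
    return dref, dvalue

dref, dvalue = logical_slice_assign_grads(
    np.ones((4, 4)), value_shape=(4, 2), start=[0, 0], stop=[4, 2], step=[1, 1]
)
assert dvalue.shape == (4, 2) and dref[:, :2].sum() == 0 and dref[:, 2:].sum() == 8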
61 changes: 58 additions & 3 deletions python/oneflow/test/modules/test_consistent_slice.py
@@ -90,14 +90,21 @@ def _test_slice_ellipsis_type(test_case, placement, sbp):


def _test_logical_slice(test_case, placement, sbp):
    x = random_tensor(2, 8, 8, requires_grad=False).oneflow
    x_numpy = x.detach().cpu().numpy()
    input = random_tensor(2, 8, 8, requires_grad=True).oneflow
    x_numpy = input.detach().cpu().numpy()

    x = x.to_global(placement=placement, sbp=sbp)
    x = input.to_global(placement=placement, sbp=sbp)
    y = flow.logical_slice(x, slice_tup_list=[[0, 1, 1]])

    # forward
    test_case.assertTrue(np.array_equal(y.numpy(), x_numpy[0:1:1]))

    # backward
    y.sum().backward()
    input_grad_np = np.zeros((8, 8))
    input_grad_np[0:1:1, :] = 1
    test_case.assertTrue(np.array_equal(input.grad.numpy(), input_grad_np))


def _test_logical_slice_with_bool(test_case, placement, sbp):
    x = random_tensor(2, 8, 8).oneflow > 0.5
@@ -109,6 +116,53 @@ def _test_logical_slice_with_bool(test_case, placement, sbp):
    test_case.assertTrue(np.array_equal(y.numpy(), x_numpy[0:1:1]))


def _test_logical_slice_with_grad(test_case, placement, sbp):
    x = random_tensor(2, 4, 4, requires_grad=True).oneflow
    x_numpy = x.detach().cpu().numpy()

    class LogicalSliceWithGrad(flow.nn.Module):
        def __init__(self):
            super().__init__()
            self.input_grad = flow.nn.Parameter(flow.zeros(4, 4))

        def forward(self, input):
            x = input + self.input_grad
            x = x.to_global(placement, sbp)
            return x[:, :2]

    logical_slice_with_grad = LogicalSliceWithGrad().to_global(
        placement, [flow.sbp.broadcast,] * len(sbp)
    )

    of_sgd = flow.optim.SGD(logical_slice_with_grad.parameters(), lr=1.0, momentum=0.0)

    class LogicalSliceTrainGraph(flow.nn.Graph):
        def __init__(self):
            super().__init__()
            self.module = logical_slice_with_grad
            self.add_optimizer(of_sgd)

        def build(self, x):
            out = self.module(x)
            z = out.sum()
            z.backward()
            return out

    graph = LogicalSliceTrainGraph()

    input = x.to_global(placement=placement, sbp=sbp)
    y = graph(input)

    # output
    test_case.assertTrue(np.array_equal(y.numpy(), x_numpy[:, :2]))
    # input_grad
    x_grad_np = np.zeros((4, 4))
    x_grad_np[:, :2] = 1
    test_case.assertTrue(
        np.array_equal(-graph.module.input_grad.origin.numpy(), x_grad_np)
    )


class TestSlice(flow.unittest.TestCase):
    @globaltest
    def test_slice(test_case):
@@ -128,6 +182,7 @@ def test_logical_slice(test_case):
            for sbp in all_sbp(placement, max_dim=2):
                _test_logical_slice(test_case, placement, sbp)
                _test_logical_slice_with_bool(test_case, placement, sbp)
                _test_logical_slice_with_grad(test_case, placement, sbp)


if __name__ == "__main__":