Merge slice and logical slice #8416

Merged
Merged 36 commits from merge-slice_and_logical_slice into master on Jun 20, 2022

Commits (36)
a111947
remove Slice, SliceUpdate, SliceGrad op
wyg1997 Jun 10, 2022
4ebe3dd
rename logical_slice to slice and logical_slice_assign to slice_update
wyg1997 Jun 13, 2022
677bd20
move gradient_func logical_slice.cpp to slice.cpp
wyg1997 Jun 13, 2022
c4ba67d
fix some bug and refine local test
wyg1997 Jun 13, 2022
6f7db6b
feat(SliceUpdate): support 0size tensor
wyg1997 Jun 13, 2022
337b006
test(Slice): refine consistent slice test
wyg1997 Jun 13, 2022
c47dd64
test(SliceUpdate): refine consistent slice_update test
wyg1997 Jun 13, 2022
1cdc1db
Merge branch 'master' into merge-slice_and_logical_slice
wyg1997 Jun 13, 2022
1ff7483
Merge branch 'master' into merge-slice_and_logical_slice
wyg1997 Jun 13, 2022
919a98a
not export slice_update's inplace parameter
wyg1997 Jun 13, 2022
a8a0e91
auto format by CI
oneflow-ci-bot Jun 13, 2022
d3f18f1
recovery slice_grad_op
wyg1997 Jun 14, 2022
d5cc188
Merge branch 'master' into merge-slice_and_logical_slice
wyg1997 Jun 14, 2022
b25f966
fix slice_view bug
wyg1997 Jun 14, 2022
01430b6
Merge branch 'merge-slice_and_logical_slice' of github.com:Oneflow-In…
wyg1997 Jun 14, 2022
9d841d0
add error message and attr judgement
wyg1997 Jun 14, 2022
b9ced2e
modified old test
wyg1997 Jun 14, 2022
a3a21f9
Merge branch 'master' into merge-slice_and_logical_slice
wyg1997 Jun 14, 2022
1cec093
auto format by CI
oneflow-ci-bot Jun 14, 2022
64503e0
Merge branch 'master' into merge-slice_and_logical_slice
wyg1997 Jun 15, 2022
15808d5
update test README
wyg1997 Jun 15, 2022
7175616
update tensor_string code
wyg1997 Jun 15, 2022
d66c9a7
fix test bug
wyg1997 Jun 15, 2022
3bc956f
auto format by CI
oneflow-ci-bot Jun 15, 2022
cf43403
fix(hsplit): hsplit functor bug
wyg1997 Jun 15, 2022
b4cefae
Merge branch 'merge-slice_and_logical_slice' of github.com:Oneflow-In…
wyg1997 Jun 15, 2022
e758d44
fix vsplit doc test bug
wyg1997 Jun 15, 2022
d7278d6
Merge remote-tracking branch 'origin/master' into merge-slice_and_log…
wyg1997 Jun 15, 2022
d58174a
refine
wyg1997 Jun 16, 2022
8df2054
Merge remote-tracking branch 'origin/master' into merge-slice_and_log…
wyg1997 Jun 16, 2022
11458de
Merge remote-tracking branch 'origin/master' into merge-slice_and_log…
wyg1997 Jun 17, 2022
fc0439a
fix test
wyg1997 Jun 17, 2022
dea6f9b
fix pin_memory bug
wyg1997 Jun 17, 2022
06fc444
Merge branch 'master' into merge-slice_and_logical_slice
wyg1997 Jun 18, 2022
941ac47
Merge branch 'master' into merge-slice_and_logical_slice
wyg1997 Jun 20, 2022
7de4b23
Merge branch 'master' into merge-slice_and_logical_slice
mergify[bot] Jun 20, 2022
2 changes: 1 addition & 1 deletion docs/source/oneflow.rst
@@ -136,7 +136,7 @@ oneflow
selu,
silu,
slice,
logical_slice,
slice_update,
softsign,
sort,
softplus,
150 changes: 0 additions & 150 deletions oneflow/core/autograd/gradient_funcs/logical_slice.cpp

This file was deleted.

72 changes: 47 additions & 25 deletions oneflow/core/autograd/gradient_funcs/slice.cpp
@@ -23,42 +23,50 @@ namespace oneflow {
namespace one {

struct SliceCaptureState : public AutoGradCaptureState {
bool requires_grad;
Shape like_shape;
std::vector<int64_t> start;
std::vector<int64_t> stop;
std::vector<int64_t> step;
Symbol<NdSbp> in_sbp;
};

class Slice : public OpExprGradFunction<SliceCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "Slice op_expr is null";
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}

Maybe<void> Capture(SliceCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1);
CHECK_EQ_OR_RETURN(outputs.size(), 1);
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(inputs.size(), 1) << "Slice input size must be 1";
CHECK_EQ_OR_RETURN(outputs.size(), 1) << "Slice output size must be 1";

ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->start = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("start"));
ctx->stop = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("stop"));
ctx->step = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("step"));
ctx->like_shape = *(inputs.at(0)->shape());
ctx->like_shape = *(inputs[0]->shape());
Contributor commented:

at has an out-of-bounds check, so why change it?

Contributor Author replied:

In the current codebase we either use VectorAt or index directly with []. The dimension check above already guarantees this access cannot go out of range, so plain indexing is enough here.

if (inputs[0]->is_consistent()) { ctx->in_sbp = JUST(inputs[0]->nd_sbp()); }
return Maybe<void>::Ok();
}

Maybe<void> Apply(const SliceCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(1);
in_grads->at(0) = JUST(
functional::SliceGrad(out_grads.at(0), ctx->like_shape, ctx->start, ctx->stop, ctx->step));
std::shared_ptr<Tensor> zeros;
if (out_grads[0]->is_local()) {
zeros = JUST(functional::Constant(ctx->like_shape, 0, out_grads[0]->dtype(),
JUST(out_grads[0]->device())));
} else {
const auto& parallel_desc = JUST(out_grads[0]->parallel_desc());
zeros = JUST(functional::ConsistentConstant(ctx->like_shape, 0, out_grads[0]->dtype(),
parallel_desc, *JUST(GetSbpList(ctx->in_sbp))));
}
(*in_grads)[0] = JUST(functional::SliceUpdate(zeros, out_grads[0], ctx->start, ctx->stop,
ctx->step, /*inplace=*/false));
return Maybe<void>::Ok();
}

@@ -67,51 +75,65 @@ class Slice : public OpExprGradFunction<SliceCaptureState> {
};
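
Note: the rewritten Slice::Apply above drops the dedicated SliceGrad op. The input gradient is now built by allocating a zeros tensor with the input's shape (Constant, or ConsistentConstant with the captured sbp) and writing the output gradient into the sliced region with SliceUpdate. A minimal NumPy sketch of that equivalence (illustration only, not OneFlow code; the 1-D shape and slice bounds are made up):

import numpy as np

# Forward: y = x[1:7:2] on an 8-element input.
x_shape = (8,)
start, stop, step = 1, 7, 2

# Upstream gradient dL/dy, one value per sliced element (indices 1, 3, 5).
out_grad = np.array([1.0, 2.0, 3.0])

# Backward expressed as zeros + slice_update, mirroring the new Slice::Apply:
in_grad = np.zeros(x_shape)          # functional::Constant(like_shape, 0, ...)
in_grad[start:stop:step] = out_grad  # functional::SliceUpdate(zeros, out_grad, ...)

print(in_grad)  # [0. 1. 0. 2. 0. 3. 0. 0.] -- the same result the removed SliceGrad op produced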

struct SliceUpdateCaptureState : public AutoGradCaptureState {
bool requires_grad_x;
bool requires_grad_update;
bool requires_grad_ref = false;
bool requires_grad_value = false;
std::vector<int64_t> start;
std::vector<int64_t> stop;
std::vector<int64_t> step;
Shape value_shape; // used to calculate ref gradient
Symbol<NdSbp> value_sbp;
};

class SliceUpdate : public OpExprGradFunction<SliceUpdateCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "SliceUpdate op_expr is null";

base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}

Maybe<void> Capture(SliceUpdateCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2);
CHECK_EQ_OR_RETURN(outputs.size(), 1);
ctx->requires_grad_x = inputs.at(0)->requires_grad();
ctx->requires_grad_update = inputs.at(1)->requires_grad();
if (!ctx->requires_grad_x && !ctx->requires_grad_update) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(inputs.size(), 2) << "SliceUpdate input size must be 2";
CHECK_EQ_OR_RETURN(outputs.size(), 1) << "SliceUpdate output size must be 1";
ctx->requires_grad_ref = inputs[0]->requires_grad();
ctx->requires_grad_value = inputs[1]->requires_grad();
if (!ctx->requires_grad_ref && !ctx->requires_grad_value) { return Maybe<void>::Ok(); }

ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->start = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("start"));
ctx->stop = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("stop"));
ctx->step = JUST(composed_attrs.GetAttr<std::vector<int64_t>>("step"));

if (ctx->requires_grad_x) { ctx->SaveTensorForBackward(inputs.at(1)); }
if (ctx->requires_grad_ref) {
ctx->value_shape = *(inputs[1]->shape());
if (inputs[1]->is_consistent()) { ctx->value_sbp = JUST(inputs[1]->nd_sbp()); }
}
return Maybe<void>::Ok();
}

Maybe<void> Apply(const SliceUpdateCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->resize(2);

if (ctx->requires_grad_x) {
const auto& update = ctx->SavedTensors().at(0);
const auto& temp = JUST(functional::ZerosLike(update));
(*in_grads)[0] = JUST(functional::SliceUpdate(out_grads[0], temp, ctx->start, ctx->stop,
ctx->step, /*inplace=*/false));
if (ctx->requires_grad_ref) {
std::shared_ptr<Tensor> zeros;
if (out_grads[0]->is_local()) {
zeros = JUST(functional::Constant(ctx->value_shape, 0, out_grads[0]->dtype(),
JUST(out_grads[0]->device())));
} else {
const auto& parallel_desc = JUST(out_grads[0]->parallel_desc());
zeros =
JUST(functional::ConsistentConstant(ctx->value_shape, 0, out_grads[0]->dtype(),
parallel_desc, *JUST(GetSbpList(ctx->value_sbp))));
}
(*in_grads)[0] =
JUST(functional::SliceUpdate(JUST(functional::Identity(out_grads[0])), zeros, ctx->start,
ctx->stop, ctx->step, /*inplace=*/false));
}
if (ctx->requires_grad_update) {
if (ctx->requires_grad_value) {
(*in_grads)[1] = JUST(functional::Slice(out_grads[0], ctx->start, ctx->stop, ctx->step,
/*enable_view_slice=*/false));
}
@@ -122,8 +144,8 @@ class SliceUpdate : public OpExprGradFunction<SliceUpdateCaptureState> {
AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("slice", Slice);
REGISTER_OP_EXPR_GRAD_FUNCTION("slice_update", SliceUpdate);
REGISTER_OP_EXPR_GRAD_FUNCTION("slice", Slice);

} // namespace one
} // namespace oneflow
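
SliceUpdate::Apply follows the same pattern: the gradient w.r.t. ref is the output gradient with the updated region zeroed out (Identity of out_grad, then SliceUpdate with a zeros value), and the gradient w.r.t. value is the output gradient restricted to that region (Slice). A NumPy sketch of those semantics (illustration only, not OneFlow code; shape and slice bounds are made up):

import numpy as np

start, stop, step = 1, 7, 2
out_grad = np.arange(1.0, 9.0)       # dL/dy for y = slice_update(ref, value, ...), shape (8,)

# Grad w.r.t. ref: copy out_grad, then overwrite the updated slice with zeros.
ref_grad = out_grad.copy()
ref_grad[start:stop:step] = 0.0

# Grad w.r.t. value: the sliced view of out_grad (functional::Slice in the C++ code).
value_grad = out_grad[start:stop:step]

print(ref_grad)    # [1. 0. 3. 0. 5. 0. 7. 8.]
print(value_grad)  # [2. 4. 6.]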
4 changes: 2 additions & 2 deletions oneflow/core/boxing/symmetric_b_to_s_boxing.cpp
@@ -88,8 +88,8 @@ Maybe<one::Tensor> SymmetricB2S(const std::shared_ptr<one::Tensor>& tensor, Symb
start.emplace_back(range.begin());
stop.emplace_back(range.end());
}
local_tensor =
JUST(one::functional::Slice(local_tensor, start, stop, step, /*enable_view_slice=*/false));
local_tensor = JUST(one::functional::Slice(local_tensor, start, stop, step,
/*enable_view_slice=*/false));
}

return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(),
17 changes: 14 additions & 3 deletions oneflow/core/framework/tensor_methods.cpp
@@ -152,15 +152,26 @@ Maybe<Tensor> Slice(const std::shared_ptr<Tensor>& input, const std::vector<int6
}

auto output = JUST(BasicView(input, Shape(target_dims), Stride(target_strides), storage_offset));
Symbol<NdSbp> in_nd_sbp;
if (input->is_consistent()) { in_nd_sbp = JUST(input->nd_sbp()); }

if (autograd::GradMode::is_enabled() && input->requires_grad()) {
auto backward_fn = std::make_shared<BackwardFunction>();
backward_fn->body = [=](const TensorTuple& out_grads, TensorTuple* in_grads,
bool create_graph) -> Maybe<void> {
autograd::AutoGradMode mode(create_graph);
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
(*in_grads)[0] = JUST(functional::SliceGrad(
JUST(VectorAt(out_grads, 0)), Shape(input->shape()->dim_vec()), starts, ends, steps));
std::shared_ptr<Tensor> zeros;
if (out_grads[0]->is_local()) {
zeros = JUST(
functional::Constant(*shape, 0, out_grads[0]->dtype(), JUST(out_grads[0]->device())));
} else {
const auto& parallel_desc = JUST(out_grads[0]->parallel_desc());
zeros = JUST(functional::ConsistentConstant(*shape, 0, out_grads[0]->dtype(), parallel_desc,
*JUST(GetSbpList(in_nd_sbp))));
}
(*in_grads)[0] = JUST(
functional::SliceUpdate(zeros, out_grads[0], starts, ends, steps, /*inplace=*/false));
return Maybe<void>::Ok();
};
backward_fn->status = []() { return true; };
20 changes: 4 additions & 16 deletions oneflow/core/functional/functional_api.yaml
@@ -1297,14 +1297,6 @@
signature: "Tensor (Tensor x, Int64 start, Int64 end) => SliceView1dContiguous"
bind_python: True

- name: "slice"
signature: "Tensor (Tensor x, Int64List start, Int64List stop, Int64List step, Bool enable_view_slice=None) => Slice"
bind_python: True

- name: "slice_grad"
signature: "Tensor (Tensor dy, Shape like, Int64List start, Int64List stop, Int64List step) => SliceGrad"
bind_python: False

- name: "narrow"
signature: "Tensor (Tensor input, Int64 dim, Int64 start, Int64 length) => Narrow"
bind_python: True
@@ -1313,16 +1305,12 @@
signature: "Tensor (Tensor dy, Tensor like, Int64 dim, Int64 start, Int64 length) => NarrowGrad"
bind_python: False

- name: "slice_update"
signature: "Tensor (Tensor x, Tensor update, Int64List start, Int64List stop, Int64List step, *, Bool inplace=False) => SliceUpdate"
bind_python: True

- name: "logical_slice"
signature: "Tensor (Tensor x, Int64List start, Int64List stop, Int64List step, Bool enable_view_slice=None) => LogicalSlice"
- name: "slice"
signature: "Tensor (Tensor x, Int64List start, Int64List stop, Int64List step, Bool enable_view_slice=None) => Slice"
bind_python: True

- name: "logical_slice_assign"
signature: "Tensor (Tensor ref, Tensor value, Int64List start, Int64List stop, Int64List step) => LogicalSliceAssign"
- name: "slice_update"
signature: "Tensor (Tensor ref, Tensor value, Int64List start, Int64List stop, Int64List step, Bool inplace=False) => SliceUpdate"
bind_python: True

- name: "copy"