LogicalSliceAssign support full slice sbp (#8344)
* feat(SliceOp): slice ops support 2d sbp

* fix(SliceOp): fix [B, P] 2d sbp bug

* refine error message

* fix bug in parallel_num == 1

* add comment

* add warning and format

* add NOLINT for boxing check

* feat(LogicalSliceOps): support all nd_sbp

* feat(LogicalSlice): support nd_sbp

* add error message

* fix(AutoTest): fix auto_test bug in module.parameter pass

* auto format by CI

* fix(LogicalSliceAssign): skip test when 1n1d

* fix SliceParams memset error

* remove memset

* add CHECK_JUST

* fix(*): make sure split_axis >= 0 or equal to SPLIT_AXIS_FOR_NON_SPLIT

* remove memset

* fix split_info.axis bug

* feat(LogicalSliceOps): support grad

* add logical_slice gradient_funcs

* feat(LogicalSliceAssign): LogicalSliceAssign support full slice sbp

* auto format by CI

* test(LogicalSlice): fix logical_slice dims

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Houjiang Chen <chenhoujiangcug@gmail.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
4 people authored Jun 9, 2022
1 parent 2e17cc3 commit 469f72d
Showing 4 changed files with 116 additions and 68 deletions.
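In practical terms, the change lets a global-tensor slice assignment keep the value operand split instead of forcing it to broadcast, whenever the sliced region covers the whole split axis. A minimal usage sketch, not taken from the repository; it assumes a two-GPU placement, illustrative shapes, and that indexed assignment on a global tensor dispatches to logical_slice_assign in this OneFlow version:

import oneflow as flow

placement = flow.placement("cuda", ranks=[0, 1])
split0 = flow.sbp.split(0)

ref = flow.zeros(8, 16).to_global(placement=placement, sbp=split0)
value = flow.ones(8, 8).to_global(placement=placement, sbp=split0)

# Every row is kept, so the slice is "full" along the split axis and both
# ref and value can stay split(0); value no longer has to be broadcast.
ref[:, :8] = value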
84 changes: 51 additions & 33 deletions oneflow/user/kernels/slice_kernel.cpp
@@ -329,30 +329,6 @@ DEFINE_STATIC_SWITCH_FUNC(
));
#undef MAKE_WRITE_SLICE_SWITCH_ENTRY

- std::shared_ptr<user_op::OpKernelCache> CreateSliceCache(user_op::KernelCacheContext* ctx,
- const std::string& large_tensor_name) {
- SliceContext slice_ctx;
- if (ctx->parallel_ctx().parallel_num() == 1) {
- // split_axis == SPLIT_AXIS_FOR_NON_SPLIT means the sbp attribute is not 'split'
- CHECK_JUST(slice_ctx.PushSplitInfo(SPLIT_AXIS_FOR_NON_SPLIT, 0, 0, 0));
- } else {
- const NdSbp& in_nd_sbp = ctx->NdSbp4ArgNameAndIndex(large_tensor_name, 0);
- const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();
- const Shape& logical_shape =
- ctx->LogicalTensorDesc4ArgNameAndIndex(large_tensor_name, 0)->shape();
- const int64_t parallel_id = ctx->parallel_ctx().parallel_id();
- const TensorSliceView& slice_view =
- GetTensorSliceView4ParallelId(parallel_hierarchy, in_nd_sbp, logical_shape, parallel_id);
- for (int i = 0; i < logical_shape.NumAxes(); ++i) {
- const Range& range = slice_view.At(i);
- if (range.begin() != 0 || range.end() != logical_shape.At(i)) {
- CHECK_JUST(slice_ctx.PushSplitInfo(i, range.begin(), range.end(), logical_shape.At(i)));
- }
- }
- }
- return std::make_shared<OpKernelCacheWrapper<SliceContext>>(slice_ctx);
- }

template<typename T>
class LogicalSliceKernel final : public user_op::OpKernel {
public:
@@ -361,7 +337,25 @@ class LogicalSliceKernel final : public user_op::OpKernel {

std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(
user_op::KernelCacheContext* ctx) const override {
- return CreateSliceCache(ctx, "x");
+ SliceContext slice_ctx;
+ if (ctx->parallel_ctx().parallel_num() == 1) {
+ // split_axis == SPLIT_AXIS_FOR_NON_SPLIT means the sbp attribute is not 'split'
+ CHECK_JUST(slice_ctx.PushSplitInfo(SPLIT_AXIS_FOR_NON_SPLIT, 0, 0, 0));
+ } else {
+ const NdSbp& in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("x", 0);
+ const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();
+ const Shape& logical_shape = ctx->LogicalTensorDesc4ArgNameAndIndex("x", 0)->shape();
+ const int64_t parallel_id = ctx->parallel_ctx().parallel_id();
+ const TensorSliceView& slice_view =
+ GetTensorSliceView4ParallelId(parallel_hierarchy, in_nd_sbp, logical_shape, parallel_id);
+ for (int i = 0; i < logical_shape.NumAxes(); ++i) {
+ const Range& range = slice_view.At(i);
+ if (range.begin() != 0 || range.end() != logical_shape.At(i)) {
+ CHECK_JUST(slice_ctx.PushSplitInfo(i, range.begin(), range.end(), logical_shape.At(i)));
+ }
+ }
+ }
+ return std::make_shared<OpKernelCacheWrapper<SliceContext>>(slice_ctx);
}

private:
@@ -388,15 +382,39 @@ class LogicalSliceAssignKernel final : public user_op::OpKernel {

std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(
user_op::KernelCacheContext* ctx) const override {
- if (ctx->parallel_ctx().parallel_num() > 1) {
- const NdSbp& value_nd_sbp = ctx->NdSbp4ArgNameAndIndex("value", 0);
- CHECK(std::all_of(value_nd_sbp.sbp_parallel().begin(), value_nd_sbp.sbp_parallel().end(),
- [](const SbpParallel& sbp) {
- return sbp.has_partial_sum_parallel() || sbp.has_broadcast_parallel();
- }))
- << "value's sbp must be broadcast or partial_sum";
+ SliceContext slice_ctx;
+ if (ctx->parallel_ctx().parallel_num() == 1) {
+ // split_axis == SPLIT_AXIS_FOR_NON_SPLIT means the sbp attribute is not 'split'
+ CHECK_JUST(slice_ctx.PushSplitInfo(SPLIT_AXIS_FOR_NON_SPLIT, 0, 0, 0));
+ } else {
+ const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();
+ NdSbp ref_nd_sbp = ctx->NdSbp4ArgNameAndIndex("ref", 0);
+ {
+ const NdSbp value_nd_sbp = ctx->NdSbp4ArgNameAndIndex("value", 0);
+ // If ref and value are both split along the same axis (a full slice),
+ // the physical tensor can be treated as broadcast along this axis.
+ for (int i = 0; i < parallel_hierarchy.NumAxes(); ++i) {
+ const SbpParallel& ref_sbp = ref_nd_sbp.sbp_parallel(i);
+ const SbpParallel& value_sbp = value_nd_sbp.sbp_parallel(i);
+ if (ref_sbp.has_split_parallel() && value_sbp.has_split_parallel()) {
+ CHECK_EQ(ref_sbp.split_parallel().axis(), value_sbp.split_parallel().axis());
+ ref_nd_sbp.mutable_sbp_parallel(i)->clear_split_parallel();
+ ref_nd_sbp.mutable_sbp_parallel(i)->mutable_broadcast_parallel();
+ }
+ }
+ }
+ const Shape& logical_shape = ctx->LogicalTensorDesc4ArgNameAndIndex("ref", 0)->shape();
+ const int64_t parallel_id = ctx->parallel_ctx().parallel_id();
+ const TensorSliceView& slice_view =
+ GetTensorSliceView4ParallelId(parallel_hierarchy, ref_nd_sbp, logical_shape, parallel_id);
+ for (int i = 0; i < logical_shape.NumAxes(); ++i) {
+ const Range& range = slice_view.At(i);
+ if (range.begin() != 0 || range.end() != logical_shape.At(i)) {
+ CHECK_JUST(slice_ctx.PushSplitInfo(i, range.begin(), range.end(), logical_shape.At(i)));
+ }
+ }
+ }
- return CreateSliceCache(ctx, "ref");
+ return std::make_shared<OpKernelCacheWrapper<SliceContext>>(slice_ctx);
}

private:
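The core of the kernel change above is the nd-sbp adjustment guarded by the full-slice comment. A hypothetical pure-Python sketch of that adjustment; the function name and the string encoding of sbp values ("S(axis)", "B", "P") are illustrative, not OneFlow's API:

def relax_full_slice_dims(ref_nd_sbp, value_nd_sbp):
    # When ref and value are both split on a hierarchy dimension, the assigned
    # region is a full slice there, so ref can be treated as broadcast on that
    # dimension before the per-rank slice view is computed.
    relaxed = list(ref_nd_sbp)
    for i, (ref_sbp, value_sbp) in enumerate(zip(ref_nd_sbp, value_nd_sbp)):
        if ref_sbp.startswith("S") and value_sbp.startswith("S"):
            # Mirrors the CHECK_EQ on split_parallel().axis() above.
            assert ref_sbp == value_sbp
            relaxed[i] = "B"
    return relaxed

# Example on a 2x2 hierarchy: ref (S(0), S(1)) with value (S(0), B) relaxes
# only the first hierarchy dimension.
assert relax_full_slice_dims(["S(0)", "S(1)"], ["S(0)", "B"]) == ["B", "S(1)"]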
13 changes: 11 additions & 2 deletions oneflow/user/ops/slice_op.cpp
@@ -154,13 +154,21 @@ bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) {
}

/*static*/ Maybe<void> LogicalSliceAssignOp::GetSbp(user_op::SbpContext* ctx) {
- const user_op::TensorDesc& ref_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("ref", 0);
- FOR_RANGE(int64_t, axis, 0, ref_desc.shape().NumAxes()) {
+ const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("ref", 0).shape();
+ const int64_t ndim = x_shape.NumAxes();
+ const auto& start_vec = ctx->Attr<std::vector<int64_t>>("start");
+ const auto& stop_vec = ctx->Attr<std::vector<int64_t>>("stop");
+ const auto& step_vec = ctx->Attr<std::vector<int64_t>>("step");
+ FOR_RANGE(int64_t, axis, 0, ndim) {
ctx->NewBuilder()
.Split(user_op::OpArg("ref", 0), axis)
.Broadcast(user_op::OpArg("value", 0))
.Split(user_op::OpArg("y", 0), axis)
.Build();
+ // FullSlice support S+S->S
+ if (IsFullSlice(start_vec[axis], stop_vec[axis], step_vec[axis], x_shape.At(axis))) {
+ ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build();
+ }
}
ctx->NewBuilder()
.PartialSum(user_op::OpArg("ref", 0))
@@ -260,6 +268,7 @@ bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) {
ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
return Maybe<void>::Ok();
}

/*static*/ Maybe<void> SliceUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
const auto& x_desc = ctx->InputTensorDesc("x", 0);
const int64_t ndim = x_desc.shape().NumAxes();
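The new S+S->S signature above hinges on the IsFullSlice helper, whose body is collapsed in this view. A Python sketch of the predicate, under the assumption that it matches the C++ helper (start at 0, unit step, stop exactly at the axis size):

def is_full_slice(start: int, stop: int, step: int, size: int) -> bool:
    # An axis is fully covered only when the slice keeps every one of its elements.
    return step == 1 and start == 0 and stop == size

# For fully covered axes the op can additionally register
# Split(ref) + Split(value) -> Split(y); partially sliced axes keep only the
# Split(ref) + Broadcast(value) -> Split(y) form.
assert is_full_slice(0, 8, 1, 8) is True
assert is_full_slice(0, 4, 1, 8) is False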
12 changes: 6 additions & 6 deletions python/oneflow/test/modules/test_consistent_slice.py
@@ -117,18 +117,18 @@ def _test_logical_slice_with_bool(test_case, placement, sbp):


def _test_logical_slice_with_grad(test_case, placement, sbp):
- x = random_tensor(2, 4, 4, requires_grad=True).oneflow
+ x = random_tensor(2, 8, 16, requires_grad=True).oneflow
x_numpy = x.detach().cpu().numpy()

class LogicalSliceWithGrad(flow.nn.Module):
def __init__(self):
super().__init__()
- self.input_grad = flow.nn.Parameter(flow.zeros(4, 4))
+ self.input_grad = flow.nn.Parameter(flow.zeros(8, 16))

def forward(self, input):
x = input + self.input_grad
x = x.to_global(placement, sbp)
- return x[:, :2]
+ return x[:, :8]

logical_slice_with_grad = LogicalSliceWithGrad().to_global(
placement, [flow.sbp.broadcast,] * len(sbp)
@@ -154,10 +154,10 @@ def build(self, x):
y = graph(input)

# output
- test_case.assertTrue(np.array_equal(y.numpy(), x_numpy[:, :2]))
+ test_case.assertTrue(np.array_equal(y.numpy(), x_numpy[:, :8]))
# input_grad
- x_grad_np = np.zeros((4, 4))
- x_grad_np[:, :2] = 1
+ x_grad_np = np.zeros((8, 16))
+ x_grad_np[:, :8] = 1
test_case.assertTrue(
np.array_equal(-graph.module.input_grad.origin.numpy(), x_grad_np)
)
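A note on why these graph tests compare against the negated parameter: each parameter starts at zero, and one SGD step turns it into -lr * grad, so negating the updated parameter recovers the gradient that reached it, provided lr is 1.0 (an assumption here; the of_sgd setup is collapsed in this view). A minimal NumPy sketch:

import numpy as np

param = np.zeros((8, 16))     # matches flow.nn.Parameter(flow.zeros(8, 16))
grad = np.zeros((8, 16))
grad[:, :8] = 1               # only the sliced columns receive gradient
updated = param - 1.0 * grad  # one SGD step with the assumed lr = 1.0
assert np.array_equal(-updated, grad)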
75 changes: 48 additions & 27 deletions python/oneflow/test/modules/test_consistent_slice_assign.py
@@ -23,39 +23,49 @@


def _test_logical_slice_assign(test_case, placement, sbp):
- input = random_tensor(2, 4, 4, requires_grad=True).oneflow
- x_numpy = input.detach().cpu().numpy()

+ input = random_tensor(2, 8, 16, requires_grad=True).oneflow
+ value = random_tensor(2, 8, 8, requires_grad=True).oneflow
x = (input + 0).to_global(
placement=placement, sbp=sbp
) # add 0 to change to non-leaf tensor
- x[:, :2] = 3
+ y = value.to_global(placement, sbp=sbp)
+ x[:, :8] = y

+ ref_np = input.detach().cpu().numpy()
+ value_np = value.detach().cpu().numpy()

# forward
- x_numpy[:, :2] = 3
+ ref_np[:, :8] = value_np
test_case.assertTrue(x.sbp == sbp)
- test_case.assertTrue(np.array_equal(x.numpy(), x_numpy))
+ test_case.assertTrue(np.array_equal(x.numpy(), ref_np))

# backward
x.sum().backward()
- input_grad_np = np.ones((4, 4))
- input_grad_np[:, :2] = 0
- test_case.assertTrue(np.array_equal(input.grad.numpy(), input_grad_np))
+ # ref grad
+ ref_grad_np = np.ones((8, 16))
+ ref_grad_np[:, :8] = 0
+ test_case.assertTrue(np.array_equal(input.grad.numpy(), ref_grad_np))
+ # value grad
+ value_grad_np = np.ones((8, 8))
+ test_case.assertTrue(np.array_equal(value.grad.numpy(), value_grad_np))


def _test_graph_logical_slice_assign(test_case, placement, sbp):
- x = random_tensor(2, 4, 4, requires_grad=True).oneflow
- x_numpy = x.detach().cpu().numpy()
+ ref = random_tensor(2, 8, 16, requires_grad=True).oneflow
+ value = random_tensor(2, 8, 8, requires_grad=True).oneflow

class LogicalSliceAssignWithGrad(flow.nn.Module):
def __init__(self):
super().__init__()
- self.input_grad = flow.nn.Parameter(flow.zeros(4, 4))
+ self.ref_grad = flow.nn.Parameter(flow.zeros(8, 16))
+ self.value_grad = flow.nn.Parameter(flow.zeros(8, 8))

- def forward(self, input):
- x = input + self.input_grad
+ def forward(self, ref, value):
+ x = ref + self.ref_grad
+ y = value + self.value_grad
x = x.to_global(placement, sbp)
- x[:, :2] = 3
+ y = y.to_global(placement, sbp)
+ x[:, :8] = y
return x

logical_slice_assign_with_grad = LogicalSliceAssignWithGrad().to_global(
@@ -72,27 +82,38 @@ def __init__(self):
self.module = logical_slice_assign_with_grad
self.add_optimizer(of_sgd)

- def build(self, x):
- out = self.module(x)
+ def build(self, x, y):
+ out = self.module(x, y)
z = out.sum()
z.backward()
return out

graph = LogicalSliceAssignTrainGraph()

- input = x.to_global(placement=placement, sbp=sbp)
- y = graph(input)
+ x = ref.to_global(placement=placement, sbp=sbp)
+ y = value.to_global(placement=placement, sbp=sbp)
+ z = graph(x, y)

+ test_case.assertTrue(z.sbp == sbp)

+ ref_np = ref.detach().cpu().numpy()
+ value_np = value.detach().cpu().numpy()

- test_case.assertTrue(y.sbp == sbp)
+ # forward
+ ref_np[:, :8] = value_np
+ test_case.assertTrue(np.array_equal(z.numpy(), ref_np))

- # output
- x_numpy[:, :2] = 3
- test_case.assertTrue(np.array_equal(y.numpy(), x_numpy))
- # input_grad
- x_grad_np = np.ones((4, 4))
- x_grad_np[:, :2] = 0
+ # backward
+ # ref grad
+ ref_grad = np.ones((8, 16))
+ ref_grad[:, :8] = 0
+ test_case.assertTrue(
+ np.array_equal(-graph.module.ref_grad.origin.numpy(), ref_grad)
+ )
+ # value grad
+ value_grad = np.ones((8, 8))
test_case.assertTrue(
- np.array_equal(-graph.module.input_grad.origin.numpy(), x_grad_np)
+ np.array_equal(-graph.module.value_grad.origin.numpy(), value_grad)
)

