Add view weight norm where zeropad2d global test #7886

Merged Apr 12, 2022 · 65 commits (diff shown for changes from 10 commits)

Commits:
9a40937 fix_unfold_tensor_sbp_and_add_global_test (clackhan, Mar 23, 2022)
8ff6ec6 refine (clackhan, Mar 23, 2022)
d38ba32 add_var_upsample_global_test (clackhan, Mar 23, 2022)
8ebbb96 add_view_weight_norm_where_zeropad2d_global_test (clackhan, Mar 24, 2022)
ef41138 Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into … (clackhan, Mar 24, 2022)
984e07f del code not in this branch (clackhan, Mar 24, 2022)
6a45f0b Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (clackhan, Mar 24, 2022)
0a5e7e2 Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into … (clackhan, Mar 28, 2022)
8240592 Merge branch 'add_view_weight_norm_where_zeropad2d_global_test' of ht… (clackhan, Mar 28, 2022)
91e03c7 fix where infer shape and sbp bug (clackhan, Mar 28, 2022)
b19fda5 del CheckBroadcastable (clackhan, Mar 28, 2022)
ad91472 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (clackhan, Mar 29, 2022)
9415e21 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 29, 2022)
2f0f699 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 29, 2022)
07b49b0 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 29, 2022)
a639e5e Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (clackhan, Mar 29, 2022)
9d065ba Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 29, 2022)
3cd4215 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 29, 2022)
f4d975b Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 29, 2022)
eaedcbb Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 30, 2022)
380edff Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 30, 2022)
2a629d7 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 30, 2022)
4615260 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 30, 2022)
20900e2 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 30, 2022)
b984a24 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 30, 2022)
04bce37 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 31, 2022)
4c2d731 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 31, 2022)
6380dd8 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 31, 2022)
03f7e48 Update test_consistent_view.py (clackhan, Mar 31, 2022)
427ad76 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 31, 2022)
41624d9 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 31, 2022)
f61f8ab Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 31, 2022)
4338124 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Mar 31, 2022)
bcc446c Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (clackhan, Apr 1, 2022)
f4e6029 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 1, 2022)
9a4f396 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 1, 2022)
b2a53d0 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 1, 2022)
bc88cbf Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 2, 2022)
9b91072 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 2, 2022)
0c58fa2 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 2, 2022)
771b0a7 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 2, 2022)
c1c8db1 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 2, 2022)
51df93a Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 2, 2022)
3c5f66e fix where test error (clackhan, Apr 6, 2022)
0450ab4 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (clackhan, Apr 6, 2022)
3231d36 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 6, 2022)
f3d3e41 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 6, 2022)
efcbfda Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 6, 2022)
6dc72e8 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 7, 2022)
407d4b0 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 7, 2022)
94d0193 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 7, 2022)
bff4b02 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 7, 2022)
b4e0b98 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 7, 2022)
ac5627b Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 8, 2022)
82fbe6e Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 8, 2022)
49b1410 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 8, 2022)
3110802 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 8, 2022)
b825200 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 9, 2022)
d12c01b Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 9, 2022)
98f6f2a Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 9, 2022)
7cbb7d5 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 9, 2022)
faa6fb1 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (clackhan, Apr 11, 2022)
a2bdd5a Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 11, 2022)
158e94f Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 11, 2022)
2764db4 Merge branch 'master' into add_view_weight_norm_where_zeropad2d_globa… (mergify[bot], Apr 11, 2022)
Files changed

oneflow/user/ops/where_op.cpp (234 changes: 101 additions, 133 deletions)
@@ -20,23 +20,89 @@ namespace oneflow {

namespace {

Maybe<void> CheckBroadcastable(const Shape& a_shape, const Shape& b_shape) {
Shape broadcast_shape = Shape::Ones(std::max(a_shape.NumAxes(), b_shape.NumAxes()));
Shape a_extend_shape = CreateLeftExtendedShape(ShapeView(a_shape), broadcast_shape.NumAxes());
Shape b_extend_shape = CreateLeftExtendedShape(ShapeView(b_shape), broadcast_shape.NumAxes());
FOR_RANGE(int64_t, i, 0, broadcast_shape.NumAxes()) {
CHECK_OR_RETURN(a_extend_shape.At(i) == 1 || b_extend_shape.At(i) == 1
|| a_extend_shape.At(i) == b_extend_shape.At(i))
<< Error::RuntimeError() << "The size of tensor a (" << a_extend_shape.At(i)
<< ") must match the size of tensor b (" << b_extend_shape.At(i)
<< ") at non-singleton dimension " << i;
Review comment (clackhan, author): The exception type here is aligned with PyTorch.
}
return Maybe<void>::Ok();
}

Maybe<Shape> GetBroadcastShape(const Shape& a_shape, const Shape& b_shape) {
Shape broadcast_shape = Shape::Ones(std::max(a_shape.NumAxes(), b_shape.NumAxes()));
Shape a_extend_shape = CreateLeftExtendedShape(ShapeView(a_shape), broadcast_shape.NumAxes());
Shape b_extend_shape = CreateLeftExtendedShape(ShapeView(b_shape), broadcast_shape.NumAxes());
FOR_RANGE(int64_t, i, 0, broadcast_shape.NumAxes()) {
CHECK_OR_RETURN(a_extend_shape.At(i) == 1 || b_extend_shape.At(i) == 1
|| a_extend_shape.At(i) == b_extend_shape.At(i))
<< Error::RuntimeError() << "The size of tensor a (" << a_extend_shape.At(i)
<< ") must match the size of tensor b (" << b_extend_shape.At(i)
<< ") at non-singleton dimension " << i;
broadcast_shape.Set(i, std::max(a_extend_shape.At(i), b_extend_shape.At(i)));
}
return broadcast_shape;
}
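
For reference, the broadcast rule that GetBroadcastShape implements (left-extend the shorter shape with singleton dimensions, then take the per-axis maximum, raising on any mismatched non-singleton sizes) can be sketched in plain Python as follows. This is an illustrative sketch only, not code from this PR:

def get_broadcast_shape(a, b):
    n = max(len(a), len(b))
    a = (1,) * (n - len(a)) + tuple(a)  # analogous to CreateLeftExtendedShape
    b = (1,) * (n - len(b)) + tuple(b)
    out = []
    for i, (da, db) in enumerate(zip(a, b)):
        if da != 1 and db != 1 and da != db:
            raise RuntimeError(f"The size of tensor a ({da}) must match the size "
                               f"of tensor b ({db}) at non-singleton dimension {i}")
        out.append(max(da, db))
    return tuple(out)

assert get_broadcast_shape((8, 1, 4), (16, 1)) == (8, 16, 4)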

Maybe<std::vector<std::tuple<int64_t, int64_t, int64_t, int64_t>>> CalValidSplitDims(
const Shape& a_shape, const Shape& b_shape, const Shape& c_shape) {
wyg1997 marked this conversation as resolved.
JUST(CheckBroadcastable(a_shape, b_shape));
JUST(CheckBroadcastable(a_shape, c_shape));
JUST(CheckBroadcastable(b_shape, c_shape));
Review comment (Contributor): This check is not really necessary, since InferWhereTensorDesc has in fact already performed it.
Reply (clackhan, author): Indeed, it is unnecessary; the CheckBroadcastable logic has been removed.

std::shared_ptr<std::vector<std::tuple<int64_t, int64_t, int64_t, int64_t>>> vaild_split_dims =
std::make_shared<std::vector<std::tuple<int64_t, int64_t, int64_t, int64_t>>>();
int32_t max_num_axes =
std::max(a_shape.NumAxes(), std::max(b_shape.NumAxes(), c_shape.NumAxes()));
Shape broadcast_shape = Shape::Ones(std::max(a_shape.NumAxes(), b_shape.NumAxes()));
Shape a_extend_shape = CreateLeftExtendedShape(ShapeView(a_shape), broadcast_shape.NumAxes());
Shape b_extend_shape = CreateLeftExtendedShape(ShapeView(b_shape), broadcast_shape.NumAxes());
Shape c_extend_shape = CreateLeftExtendedShape(ShapeView(c_shape), broadcast_shape.NumAxes());
int64_t a_dim_offset = max_num_axes - a_shape.NumAxes();
int64_t b_dim_offset = max_num_axes - b_shape.NumAxes();
int64_t c_dim_offset = max_num_axes - c_shape.NumAxes();
FOR_RANGE(int64_t, i, 0, max_num_axes) {
if (a_extend_shape.At(i) != 1 && a_extend_shape.At(i) == b_extend_shape.At(i)
&& a_extend_shape.At(i) == c_extend_shape.At(i)) {
vaild_split_dims->emplace_back(
std::make_tuple(i - a_dim_offset, i - b_dim_offset, i - c_dim_offset, i));
}
}
return vaild_split_dims;
}

Maybe<std::vector<std::tuple<int64_t, int64_t, int64_t>>> CalValidSplitDims(const Shape& a_shape,
const Shape& b_shape) {
JUST(CheckBroadcastable(a_shape, b_shape));
std::shared_ptr<std::vector<std::tuple<int64_t, int64_t, int64_t>>> vaild_split_dims =
std::make_shared<std::vector<std::tuple<int64_t, int64_t, int64_t>>>();
int32_t max_num_axes = std::max(a_shape.NumAxes(), b_shape.NumAxes());
Shape broadcast_shape = Shape::Ones(std::max(a_shape.NumAxes(), b_shape.NumAxes()));
Shape a_extend_shape = CreateLeftExtendedShape(ShapeView(a_shape), broadcast_shape.NumAxes());
Shape b_extend_shape = CreateLeftExtendedShape(ShapeView(b_shape), broadcast_shape.NumAxes());
int64_t a_dim_offset = max_num_axes - a_shape.NumAxes();
int64_t b_dim_offset = max_num_axes - b_shape.NumAxes();
FOR_RANGE(int64_t, i, 0, max_num_axes) {
if (a_extend_shape.At(i) != 1 && a_extend_shape.At(i) == b_extend_shape.At(i)) {
vaild_split_dims->emplace_back(std::make_tuple(i - a_dim_offset, i - b_dim_offset, i));
}
}
return vaild_split_dims;
}

Maybe<void> InferWhereTensorDesc(user_op::InferContext* ctx) {
const Shape& cond_shape = ctx->InputShape("condition", 0);
const Shape& x_shape = ctx->InputShape("x", 0);
const Shape& y_shape = ctx->InputShape("y", 0);
if (x_shape == y_shape && y_shape == cond_shape) {
*ctx->OutputShape("out", 0) = cond_shape;
} else {
Shape max_shape =
Shape::Ones(std::max(x_shape.NumAxes(), std::max(y_shape.NumAxes(), cond_shape.NumAxes())));
const Shape& x_extend_shape = CreateLeftExtendedShape(ShapeView(x_shape), max_shape.NumAxes());
const Shape& y_extend_shape = CreateLeftExtendedShape(ShapeView(y_shape), max_shape.NumAxes());
const Shape& cond_extend_shape =
CreateLeftExtendedShape(ShapeView(cond_shape), max_shape.NumAxes());
FOR_RANGE(int64_t, i, 0, max_shape.NumAxes()) {
max_shape.Set(i, std::max(x_extend_shape.At(i),
std::max(y_extend_shape.At(i), cond_extend_shape.At(i))));
}
Shape max_shape = *JUST(GetBroadcastShape(cond_shape, x_shape));
max_shape = *JUST(GetBroadcastShape(max_shape, y_shape));
*ctx->OutputShape("out", 0) = max_shape;
}
return Maybe<void>::Ok();
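
With this change the output shape is the pairwise composition broadcast(broadcast(condition, x), y). As a usage sketch (hypothetical shapes, not a test from this PR), eager mode would be expected to behave as follows:

import oneflow as flow

cond = flow.randn(8, 1, 4) > 0  # bool condition of shape (8, 1, 4)
x = flow.randn(16, 1)           # left-extended to (1, 16, 1)
y = flow.randn(8, 1, 1)
out = flow.where(cond, x, y)    # (8,1,4) with (16,1) -> (8,16,4); then with (8,1,1) -> (8,16,4)
print(out.shape)                # (8, 16, 4)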
@@ -48,13 +114,7 @@ Maybe<void> InferWhereXScalarTensorDesc(user_op::InferContext* ctx) {
if (cond_shape == y_shape) {
*ctx->OutputShape("out", 0) = cond_shape;
} else {
Shape max_shape = Shape::Ones(std::max(y_shape.NumAxes(), cond_shape.NumAxes()));
const Shape& y_extend_shape = CreateLeftExtendedShape(ShapeView(y_shape), max_shape.NumAxes());
const Shape& cond_extend_shape =
CreateLeftExtendedShape(ShapeView(cond_shape), max_shape.NumAxes());
FOR_RANGE(int64_t, i, 0, max_shape.NumAxes()) {
max_shape.Set(i, std::max(y_extend_shape.At(i), cond_extend_shape.At(i)));
}
Shape max_shape = *JUST(GetBroadcastShape(cond_shape, y_shape));
*ctx->OutputShape("out", 0) = max_shape;
}
return Maybe<void>::Ok();
@@ -66,13 +126,7 @@ Maybe<void> InferWhereYScalarTensorDesc(user_op::InferContext* ctx) {
if (cond_shape == x_shape) {
*ctx->OutputShape("out", 0) = cond_shape;
} else {
Shape max_shape = Shape::Ones(std::max(x_shape.NumAxes(), cond_shape.NumAxes()));
const Shape& x_extend_shape = CreateLeftExtendedShape(ShapeView(x_shape), max_shape.NumAxes());
const Shape& cond_extend_shape =
CreateLeftExtendedShape(ShapeView(cond_shape), max_shape.NumAxes());
FOR_RANGE(int64_t, i, 0, max_shape.NumAxes()) {
max_shape.Set(i, std::max(x_extend_shape.At(i), cond_extend_shape.At(i)));
}
Shape max_shape = *JUST(GetBroadcastShape(cond_shape, x_shape));
*ctx->OutputShape("out", 0) = max_shape;
}
return Maybe<void>::Ok();
@@ -84,14 +138,16 @@ Maybe<void> InferWhereXYScalarTensorDesc(user_op::InferContext* ctx) {
}

Maybe<void> GetWhereSbpSignatures(user_op::SbpContext* ctx) {
Review comment (Contributor): Since condition, x, and y are all broadcastable for this op, the SBP inference here may not be entirely correct?
Reply (clackhan, author): Indeed, it is incomplete; I will refine it.
Review comment (Contributor): Yes, the inference logic here gets fairly complex; see whether it can be simplified. For example, support splitting all inputs only when condition, x, and y have the same shape; if two shapes match and the third differs, follow the logic in broadcast_add; if all three differ, support only broadcast (B).
Reply (clackhan, author): It is not too complex after all; it is now complete and should cover all cases. @wyg1997 @hjchen2 please review this part when you have time. A worked example follows.
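
The rule the final implementation settles on is the one encoded in CalValidSplitDims above: an axis admits a Split signature only when every input carries that axis with the same non-singleton size; everything else falls back to Broadcast. A pure-Python sketch of the three-input case (illustrative only, not code from this PR):

def cal_valid_split_dims(a, b, c):
    n = max(len(a), len(b), len(c))
    ext = lambda s: (1,) * (n - len(s)) + tuple(s)
    ea, eb, ec = ext(a), ext(b), ext(c)
    dims = []
    for i in range(n):
        # splittable only if no input is broadcast (size 1) along this axis
        if ea[i] != 1 and ea[i] == eb[i] == ec[i]:
            dims.append((i - (n - len(a)), i - (n - len(b)), i - (n - len(c)), i))
    return dims

# cond=(8,16), x=(8,16), y=(16,): only the last axis is shared by all three,
# so the op can be split as condition: S(1), x: S(1), y: S(0), out: S(1).
assert cal_valid_split_dims((8, 16), (8, 16), (16,)) == [(1, 1, 0, 1)]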

const user_op::TensorDesc& condition_tensor =
ctx->LogicalTensorDesc4InputArgNameAndIndex("condition", 0);
FOR_RANGE(int64_t, i, 0, condition_tensor.shape().NumAxes()) {
const Shape& cond_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("condition", 0).shape();
const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape();
const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0).shape();
const auto& vaild_split_dims = JUST(CalValidSplitDims(cond_shape, x_shape, y_shape));
for (const auto& vaild_split_dim : *vaild_split_dims) {
ctx->NewBuilder()
.Split(user_op::OpArg("condition", 0), i)
.Split(user_op::OpArg("x", 0), i)
.Split(user_op::OpArg("y", 0), i)
.Split(user_op::OpArg("out", 0), i)
.Split(user_op::OpArg("condition", 0), std::get<0>(vaild_split_dim))
.Split(user_op::OpArg("x", 0), std::get<1>(vaild_split_dim))
.Split(user_op::OpArg("y", 0), std::get<2>(vaild_split_dim))
.Split(user_op::OpArg("out", 0), std::get<3>(vaild_split_dim))
.Build();
}
ctx->NewBuilder()
@@ -106,57 +162,13 @@ Maybe<void> GetWhereSbpSignatures(user_op::SbpContext* ctx) {
Maybe<void> GetWhereXScalarSbpSignatures(user_op::SbpContext* ctx) {
const Shape& cond_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("condition", 0).shape();
const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0).shape();
if (cond_shape.NumAxes() < y_shape.NumAxes()) {
FOR_RANGE(int64_t, i, 0, y_shape.NumAxes() - cond_shape.NumAxes()) {
ctx->NewBuilder()
.Broadcast(user_op::OpArg("condition", 0))
.Split(user_op::OpArg("y", 0), i)
.Split(user_op::OpArg("out", 0), i)
.Build();
}
FOR_RANGE(int64_t, i, 0, cond_shape.NumAxes()) {
ctx->NewBuilder()
.Split(user_op::OpArg("condition", 0), cond_shape.NumAxes() - 1 - i)
.Split(user_op::OpArg("y", 0), y_shape.NumAxes() - 1 - i)
.Split(ctx->outputs(), y_shape.NumAxes() - 1 - i)
.Build();
}
} else if (cond_shape.NumAxes() > y_shape.NumAxes()) {
FOR_RANGE(int64_t, i, 0, cond_shape.NumAxes() - y_shape.NumAxes()) {
ctx->NewBuilder()
.Split(user_op::OpArg("condition", 0), i)
.Broadcast(user_op::OpArg("y", 0))
.Split(user_op::OpArg("out", 0), i)
.Build();
}
FOR_RANGE(int64_t, i, 0, y_shape.NumAxes()) {
ctx->NewBuilder()
.Split(user_op::OpArg("condition", 0), cond_shape.NumAxes() - 1 - i)
.Split(user_op::OpArg("y", 0), y_shape.NumAxes() - 1 - i)
.Split(ctx->outputs(), cond_shape.NumAxes() - 1 - i)
.Build();
}
} else {
FOR_RANGE(int64_t, i, 0, cond_shape.NumAxes()) {
if (cond_shape.At(i) == 1 && y_shape.At(i) == 1) { continue; }
if (cond_shape.At(i) == y_shape.At(i)) {
ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
} else if (cond_shape.At(i) == 1) {
ctx->NewBuilder()
.Broadcast(user_op::OpArg("condition", 0))
.Split(user_op::OpArg("y", 0), i)
.Split(ctx->outputs(), i)
.Build();
} else if (y_shape.At(i) == 1) {
ctx->NewBuilder()
.Split(user_op::OpArg("condition", 0), i)
.Broadcast(user_op::OpArg("y", 0))
.Split(ctx->outputs(), i)
.Build();
} else {
UNIMPLEMENTED();
}
}
const auto& vaild_split_dims = JUST(CalValidSplitDims(cond_shape, y_shape));
for (const auto& vaild_split_dim : *vaild_split_dims) {
ctx->NewBuilder()
.Split(user_op::OpArg("condition", 0), std::get<0>(vaild_split_dim))
.Split(user_op::OpArg("y", 0), std::get<1>(vaild_split_dim))
.Split(user_op::OpArg("out", 0), std::get<2>(vaild_split_dim))
.Build();
}
ctx->NewBuilder()
.Broadcast(user_op::OpArg("condition", 0))
@@ -169,57 +181,13 @@ Maybe<void> GetWhereXScalarSbpSignatures(user_op::SbpContext* ctx) {
Maybe<void> GetWhereYScalarSbpSignatures(user_op::SbpContext* ctx) {
const Shape& cond_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("condition", 0).shape();
const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape();
if (cond_shape.NumAxes() < x_shape.NumAxes()) {
FOR_RANGE(int64_t, i, 0, x_shape.NumAxes() - cond_shape.NumAxes()) {
ctx->NewBuilder()
.Broadcast(user_op::OpArg("condition", 0))
.Split(user_op::OpArg("x", 0), i)
.Split(user_op::OpArg("out", 0), i)
.Build();
}
FOR_RANGE(int64_t, i, 0, cond_shape.NumAxes()) {
ctx->NewBuilder()
.Split(user_op::OpArg("condition", 0), cond_shape.NumAxes() - 1 - i)
.Split(user_op::OpArg("x", 0), x_shape.NumAxes() - 1 - i)
.Split(ctx->outputs(), x_shape.NumAxes() - 1 - i)
.Build();
}
} else if (cond_shape.NumAxes() > x_shape.NumAxes()) {
FOR_RANGE(int64_t, i, 0, cond_shape.NumAxes() - x_shape.NumAxes()) {
ctx->NewBuilder()
.Split(user_op::OpArg("condition", 0), i)
.Broadcast(user_op::OpArg("x", 0))
.Split(user_op::OpArg("out", 0), i)
.Build();
}
FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) {
ctx->NewBuilder()
.Split(user_op::OpArg("condition", 0), cond_shape.NumAxes() - 1 - i)
.Split(user_op::OpArg("x", 0), x_shape.NumAxes() - 1 - i)
.Split(ctx->outputs(), cond_shape.NumAxes() - 1 - i)
.Build();
}
} else {
FOR_RANGE(int64_t, i, 0, cond_shape.NumAxes()) {
if (cond_shape.At(i) == 1 && x_shape.At(i) == 1) { continue; }
if (cond_shape.At(i) == x_shape.At(i)) {
ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
} else if (cond_shape.At(i) == 1) {
ctx->NewBuilder()
.Broadcast(user_op::OpArg("condition", 0))
.Split(user_op::OpArg("x", 0), i)
.Split(ctx->outputs(), i)
.Build();
} else if (x_shape.At(i) == 1) {
ctx->NewBuilder()
.Split(user_op::OpArg("condition", 0), i)
.Broadcast(user_op::OpArg("x", 0))
.Split(ctx->outputs(), i)
.Build();
} else {
UNIMPLEMENTED();
}
}
const auto& vaild_split_dims = JUST(CalValidSplitDims(cond_shape, x_shape));
for (const auto& vaild_split_dim : *vaild_split_dims) {
ctx->NewBuilder()
.Split(user_op::OpArg("condition", 0), std::get<0>(vaild_split_dim))
.Split(user_op::OpArg("x", 0), std::get<1>(vaild_split_dim))
.Split(user_op::OpArg("out", 0), std::get<2>(vaild_split_dim))
.Build();
}
ctx->NewBuilder()
.Broadcast(user_op::OpArg("condition", 0))
python/oneflow/test/modules/test_consistent_view.py (new file, 55 additions)
@@ -0,0 +1,55 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

from oneflow.test_utils.automated_test_util import *

import oneflow as flow
import oneflow.unittest


@autotest(n=1, check_graph=False)
def _test_global_view(test_case, placement, sbp):
x = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp)
y = x.view(2, 2, 2, -1)
return y


@autotest(n=1, check_graph=False)
def _test_global_view_size(test_case, placement, sbp):
x = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp)
shape = torch.Size([2, 2, 2, -1])
y = x.view(shape)
return y


class TestGlobalView(flow.unittest.TestCase):
@globaltest
def test_global_view(test_case):
for placement in all_placement():
for sbp in all_sbp(placement, max_dim=2):
_test_global_view(test_case, placement, sbp)

@globaltest
def test_global_view_size(test_case):
for placement in all_placement():
for sbp in all_sbp(placement, max_dim=2):
_test_global_view_size(test_case, placement, sbp)


if __name__ == "__main__":
unittest.main()
python/oneflow/test/modules/test_consistent_weight_norm.py (new file, 42 additions)
@@ -0,0 +1,42 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict

from oneflow.test_utils.test_util import GenArgList
from oneflow.test_utils.automated_test_util import *
import oneflow as flow
import oneflow.unittest


@autotest(n=1, check_graph=False)
def _test_global_weight_norm_with_random_data(test_case, placement, sbp):
dim = random(-2, 2).to(int).value()
liner_model_torch = torch.nn.Linear(8, 16).to_global(placement, sbp)
m = torch.nn.utils.weight_norm(liner_model_torch, name="weight", dim=dim)
return m.weight_g, m.weight_v


class TestGlobalWeightNorm(flow.unittest.TestCase):
@globaltest
def test_global_weight_norm_with_random_data(test_case):
for placement in all_placement():
for sbp in all_sbp(placement, max_dim=1):
_test_global_weight_norm_with_random_data(test_case, placement, sbp)


if __name__ == "__main__":
unittest.main()
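
For context on what the test compares: weight_norm reparameterizes the weight as weight = weight_g * weight_v / ||weight_v||, with the norm taken over every dimension except dim, and exposes the magnitude and direction as weight_g and weight_v. A minimal eager sanity check along these lines (a sketch assuming OneFlow mirrors the PyTorch API for dim=0 on a Linear layer; not part of this PR):

import oneflow as flow

m = flow.nn.utils.weight_norm(flow.nn.Linear(8, 16), name="weight", dim=0)
# weight_v has the shape of weight, (16, 8); weight_g has shape (16, 1)
w = m.weight_g * m.weight_v / m.weight_v.norm(dim=1, keepdim=True)
print(flow.allclose(w, m.weight))  # True, up to floating-point tolerance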