Add view weight norm where zeropad2d global test #7886
@@ -20,23 +20,89 @@ namespace oneflow {

namespace {

Maybe<void> CheckBroadcastable(const Shape& a_shape, const Shape& b_shape) {
  Shape broadcast_shape = Shape::Ones(std::max(a_shape.NumAxes(), b_shape.NumAxes()));
  Shape a_extend_shape = CreateLeftExtendedShape(ShapeView(a_shape), broadcast_shape.NumAxes());
  Shape b_extend_shape = CreateLeftExtendedShape(ShapeView(b_shape), broadcast_shape.NumAxes());
  FOR_RANGE(int64_t, i, 0, broadcast_shape.NumAxes()) {
    CHECK_OR_RETURN(a_extend_shape.At(i) == 1 || b_extend_shape.At(i) == 1
                    || a_extend_shape.At(i) == b_extend_shape.At(i))
        << Error::RuntimeError() << "The size of tensor a (" << a_extend_shape.At(i)
        << ") must match the size of tensor b (" << b_extend_shape.At(i)
        << ") at non-singleton dimension " << i;
  }
  return Maybe<void>::Ok();
}

Maybe<Shape> GetBroadcastShape(const Shape& a_shape, const Shape& b_shape) {
  Shape broadcast_shape = Shape::Ones(std::max(a_shape.NumAxes(), b_shape.NumAxes()));
  Shape a_extend_shape = CreateLeftExtendedShape(ShapeView(a_shape), broadcast_shape.NumAxes());
  Shape b_extend_shape = CreateLeftExtendedShape(ShapeView(b_shape), broadcast_shape.NumAxes());
  FOR_RANGE(int64_t, i, 0, broadcast_shape.NumAxes()) {
    CHECK_OR_RETURN(a_extend_shape.At(i) == 1 || b_extend_shape.At(i) == 1
                    || a_extend_shape.At(i) == b_extend_shape.At(i))
        << Error::RuntimeError() << "The size of tensor a (" << a_extend_shape.At(i)
        << ") must match the size of tensor b (" << b_extend_shape.At(i)
        << ") at non-singleton dimension " << i;
    broadcast_shape.Set(i, std::max(a_extend_shape.At(i), b_extend_shape.At(i)));
  }
  return broadcast_shape;
}
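For reference, these two helpers implement the usual NumPy/PyTorch broadcasting rule: shapes are right-aligned, each axis pair must either match or contain a 1, and the output takes the larger extent. A minimal Python sketch of the same rule (hypothetical helper, not part of this PR):

def broadcast_shape(a, b):
    # Left-extend the shorter shape with singleton axes, then merge axis-wise.
    n = max(len(a), len(b))
    a = (1,) * (n - len(a)) + tuple(a)
    b = (1,) * (n - len(b)) + tuple(b)
    out = []
    for i, (da, db) in enumerate(zip(a, b)):
        if not (da == 1 or db == 1 or da == db):
            raise RuntimeError(
                f"The size of tensor a ({da}) must match the size of tensor b "
                f"({db}) at non-singleton dimension {i}"
            )
        out.append(max(da, db))
    return tuple(out)

assert broadcast_shape((3, 1, 5), (4, 5)) == (3, 4, 5)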
Maybe<std::vector<std::tuple<int64_t, int64_t, int64_t, int64_t>>> CalValidSplitDims(
    const Shape& a_shape, const Shape& b_shape, const Shape& c_shape) {

wyg1997 marked this conversation as resolved.

  JUST(CheckBroadcastable(a_shape, b_shape));
  JUST(CheckBroadcastable(a_shape, c_shape));
  JUST(CheckBroadcastable(b_shape, c_shape));

Review comment: The checks here are not really necessary, because in …
Reply: Indeed, they are unnecessary; removed.
  std::shared_ptr<std::vector<std::tuple<int64_t, int64_t, int64_t, int64_t>>> vaild_split_dims =
      std::make_shared<std::vector<std::tuple<int64_t, int64_t, int64_t, int64_t>>>();
  int32_t max_num_axes =
      std::max(a_shape.NumAxes(), std::max(b_shape.NumAxes(), c_shape.NumAxes()));
  Shape broadcast_shape = Shape::Ones(std::max(a_shape.NumAxes(), b_shape.NumAxes()));
  Shape a_extend_shape = CreateLeftExtendedShape(ShapeView(a_shape), broadcast_shape.NumAxes());
  Shape b_extend_shape = CreateLeftExtendedShape(ShapeView(b_shape), broadcast_shape.NumAxes());
  Shape c_extend_shape = CreateLeftExtendedShape(ShapeView(c_shape), broadcast_shape.NumAxes());
  int64_t a_dim_offset = max_num_axes - a_shape.NumAxes();
  int64_t b_dim_offset = max_num_axes - b_shape.NumAxes();
  int64_t c_dim_offset = max_num_axes - c_shape.NumAxes();
  FOR_RANGE(int64_t, i, 0, max_num_axes) {
    if (a_extend_shape.At(i) != 1 && a_extend_shape.At(i) == b_extend_shape.At(i)
        && a_extend_shape.At(i) == c_extend_shape.At(i)) {
      vaild_split_dims->emplace_back(
          std::make_tuple(i - a_dim_offset, i - b_dim_offset, i - c_dim_offset, i));
    }
  }
  return vaild_split_dims;
}
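CalValidSplitDims records every broadcast axis on which all operands carry the same non-singleton extent; those are exactly the axes that can be split without crossing a broadcast, and each tuple maps the axis back into every operand's own (un-extended) indexing plus the output axis. A Python sketch under the same left-extension convention (hypothetical helper, not part of this PR):

def valid_split_dims(a, b, c):
    # An axis is splittable only if every operand has the same extent != 1 there.
    n = max(len(a), len(b), len(c))
    ext = lambda s: (1,) * (n - len(s)) + tuple(s)
    ea, eb, ec = ext(a), ext(b), ext(c)
    dims = []
    for i in range(n):
        if ea[i] != 1 and ea[i] == eb[i] == ec[i]:
            # Offset-correct each index into the operand's own rank.
            dims.append((i - (n - len(a)), i - (n - len(b)), i - (n - len(c)), i))
    return dims

# cond (4, 5), x (3, 4, 5), y (5,): only the last axis (extent 5) is shared
# non-singleton by all three, so it is the single valid split dimension.
assert valid_split_dims((4, 5), (3, 4, 5), (5,)) == [(1, 2, 0, 2)]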
Maybe<std::vector<std::tuple<int64_t, int64_t, int64_t>>> CalValidSplitDims(const Shape& a_shape,
                                                                            const Shape& b_shape) {
  JUST(CheckBroadcastable(a_shape, b_shape));
  std::shared_ptr<std::vector<std::tuple<int64_t, int64_t, int64_t>>> vaild_split_dims =
      std::make_shared<std::vector<std::tuple<int64_t, int64_t, int64_t>>>();
  int32_t max_num_axes = std::max(a_shape.NumAxes(), b_shape.NumAxes());
  Shape broadcast_shape = Shape::Ones(std::max(a_shape.NumAxes(), b_shape.NumAxes()));
  Shape a_extend_shape = CreateLeftExtendedShape(ShapeView(a_shape), broadcast_shape.NumAxes());
  Shape b_extend_shape = CreateLeftExtendedShape(ShapeView(b_shape), broadcast_shape.NumAxes());
  int64_t a_dim_offset = max_num_axes - a_shape.NumAxes();
  int64_t b_dim_offset = max_num_axes - b_shape.NumAxes();
  FOR_RANGE(int64_t, i, 0, max_num_axes) {
    if (a_extend_shape.At(i) != 1 && a_extend_shape.At(i) == b_extend_shape.At(i)) {
      vaild_split_dims->emplace_back(std::make_tuple(i - a_dim_offset, i - b_dim_offset, i));
    }
  }
  return vaild_split_dims;
}
Maybe<void> InferWhereTensorDesc(user_op::InferContext* ctx) {
  const Shape& cond_shape = ctx->InputShape("condition", 0);
  const Shape& x_shape = ctx->InputShape("x", 0);
  const Shape& y_shape = ctx->InputShape("y", 0);
  if (x_shape == y_shape && y_shape == cond_shape) {
    *ctx->OutputShape("out", 0) = cond_shape;
  } else {
-    Shape max_shape =
-        Shape::Ones(std::max(x_shape.NumAxes(), std::max(y_shape.NumAxes(), cond_shape.NumAxes())));
-    const Shape& x_extend_shape = CreateLeftExtendedShape(ShapeView(x_shape), max_shape.NumAxes());
-    const Shape& y_extend_shape = CreateLeftExtendedShape(ShapeView(y_shape), max_shape.NumAxes());
-    const Shape& cond_extend_shape =
-        CreateLeftExtendedShape(ShapeView(cond_shape), max_shape.NumAxes());
-    FOR_RANGE(int64_t, i, 0, max_shape.NumAxes()) {
-      max_shape.Set(i, std::max(x_extend_shape.At(i),
-                                std::max(y_extend_shape.At(i), cond_extend_shape.At(i))));
-    }
+    Shape max_shape = *JUST(GetBroadcastShape(cond_shape, x_shape));
+    max_shape = *JUST(GetBroadcastShape(max_shape, y_shape));
    *ctx->OutputShape("out", 0) = max_shape;
  }
  return Maybe<void>::Ok();
@@ -48,13 +114,7 @@ Maybe<void> InferWhereXScalarTensorDesc(user_op::InferContext* ctx) {
  if (cond_shape == y_shape) {
    *ctx->OutputShape("out", 0) = cond_shape;
  } else {
-    Shape max_shape = Shape::Ones(std::max(y_shape.NumAxes(), cond_shape.NumAxes()));
-    const Shape& y_extend_shape = CreateLeftExtendedShape(ShapeView(y_shape), max_shape.NumAxes());
-    const Shape& cond_extend_shape =
-        CreateLeftExtendedShape(ShapeView(cond_shape), max_shape.NumAxes());
-    FOR_RANGE(int64_t, i, 0, max_shape.NumAxes()) {
-      max_shape.Set(i, std::max(y_extend_shape.At(i), cond_extend_shape.At(i)));
-    }
+    Shape max_shape = *JUST(GetBroadcastShape(cond_shape, y_shape));
    *ctx->OutputShape("out", 0) = max_shape;
  }
  return Maybe<void>::Ok();
@@ -66,13 +126,7 @@ Maybe<void> InferWhereYScalarTensorDesc(user_op::InferContext* ctx) {
  if (cond_shape == x_shape) {
    *ctx->OutputShape("out", 0) = cond_shape;
  } else {
-    Shape max_shape = Shape::Ones(std::max(x_shape.NumAxes(), cond_shape.NumAxes()));
-    const Shape& x_extend_shape = CreateLeftExtendedShape(ShapeView(x_shape), max_shape.NumAxes());
-    const Shape& cond_extend_shape =
-        CreateLeftExtendedShape(ShapeView(cond_shape), max_shape.NumAxes());
-    FOR_RANGE(int64_t, i, 0, max_shape.NumAxes()) {
-      max_shape.Set(i, std::max(x_extend_shape.At(i), cond_extend_shape.At(i)));
-    }
+    Shape max_shape = *JUST(GetBroadcastShape(cond_shape, x_shape));
    *ctx->OutputShape("out", 0) = max_shape;
  }
  return Maybe<void>::Ok();
@@ -84,14 +138,16 @@ Maybe<void> InferWhereXYScalarTensorDesc(user_op::InferContext* ctx) {
}

Maybe<void> GetWhereSbpSignatures(user_op::SbpContext* ctx) {

Review comment: Because condition, x, and y of this op are all broadcastable, the SBP derivation here may still not be entirely correct?
Reply: It is indeed incomplete; I will refine it.
Review comment: Yes, the derivation logic here gets fairly complex; let's see whether it can be simplified. For example: support splitting all inputs only when the shapes of condition, x, and y are identical; if two are the same and one differs, follow the logic used in broadcast_add; and if all three differ, support only broadcast (B).

-  const user_op::TensorDesc& condition_tensor =
-      ctx->LogicalTensorDesc4InputArgNameAndIndex("condition", 0);
-  FOR_RANGE(int64_t, i, 0, condition_tensor.shape().NumAxes()) {
+  const Shape& cond_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("condition", 0).shape();
+  const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape();
+  const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0).shape();
+  const auto& vaild_split_dims = JUST(CalValidSplitDims(cond_shape, x_shape, y_shape));
+  for (const auto& vaild_split_dim : *vaild_split_dims) {
    ctx->NewBuilder()
-        .Split(user_op::OpArg("condition", 0), i)
-        .Split(user_op::OpArg("x", 0), i)
-        .Split(user_op::OpArg("y", 0), i)
-        .Split(user_op::OpArg("out", 0), i)
+        .Split(user_op::OpArg("condition", 0), std::get<0>(vaild_split_dim))
+        .Split(user_op::OpArg("x", 0), std::get<1>(vaild_split_dim))
+        .Split(user_op::OpArg("y", 0), std::get<2>(vaild_split_dim))
+        .Split(user_op::OpArg("out", 0), std::get<3>(vaild_split_dim))
        .Build();
  }
  ctx->NewBuilder()
@@ -106,57 +162,13 @@ Maybe<void> GetWhereSbpSignatures(user_op::SbpContext* ctx) {
Maybe<void> GetWhereXScalarSbpSignatures(user_op::SbpContext* ctx) {
  const Shape& cond_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("condition", 0).shape();
  const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0).shape();
-  if (cond_shape.NumAxes() < y_shape.NumAxes()) {
-    FOR_RANGE(int64_t, i, 0, y_shape.NumAxes() - cond_shape.NumAxes()) {
-      ctx->NewBuilder()
-          .Broadcast(user_op::OpArg("condition", 0))
-          .Split(user_op::OpArg("y", 0), i)
-          .Split(user_op::OpArg("out", 0), i)
-          .Build();
-    }
-    FOR_RANGE(int64_t, i, 0, cond_shape.NumAxes()) {
-      ctx->NewBuilder()
-          .Split(user_op::OpArg("condition", 0), cond_shape.NumAxes() - 1 - i)
-          .Split(user_op::OpArg("y", 0), y_shape.NumAxes() - 1 - i)
-          .Split(ctx->outputs(), y_shape.NumAxes() - 1 - i)
-          .Build();
-    }
-  } else if (cond_shape.NumAxes() > y_shape.NumAxes()) {
-    FOR_RANGE(int64_t, i, 0, cond_shape.NumAxes() - y_shape.NumAxes()) {
-      ctx->NewBuilder()
-          .Split(user_op::OpArg("condition", 0), i)
-          .Broadcast(user_op::OpArg("y", 0))
-          .Split(user_op::OpArg("out", 0), i)
-          .Build();
-    }
-    FOR_RANGE(int64_t, i, 0, y_shape.NumAxes()) {
-      ctx->NewBuilder()
-          .Split(user_op::OpArg("condition", 0), cond_shape.NumAxes() - 1 - i)
-          .Split(user_op::OpArg("y", 0), y_shape.NumAxes() - 1 - i)
-          .Split(ctx->outputs(), cond_shape.NumAxes() - 1 - i)
-          .Build();
-    }
-  } else {
-    FOR_RANGE(int64_t, i, 0, cond_shape.NumAxes()) {
-      if (cond_shape.At(i) == 1 && y_shape.At(i) == 1) { continue; }
-      if (cond_shape.At(i) == y_shape.At(i)) {
-        ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
-      } else if (cond_shape.At(i) == 1) {
-        ctx->NewBuilder()
-            .Broadcast(user_op::OpArg("condition", 0))
-            .Split(user_op::OpArg("y", 0), i)
-            .Split(ctx->outputs(), i)
-            .Build();
-      } else if (y_shape.At(i) == 1) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("condition", 0), i)
-            .Broadcast(user_op::OpArg("y", 0))
-            .Split(ctx->outputs(), i)
-            .Build();
-      } else {
-        UNIMPLEMENTED();
-      }
-    }
+  const auto& vaild_split_dims = JUST(CalValidSplitDims(cond_shape, y_shape));
+  for (const auto& vaild_split_dim : *vaild_split_dims) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("condition", 0), std::get<0>(vaild_split_dim))
+        .Split(user_op::OpArg("y", 0), std::get<1>(vaild_split_dim))
+        .Split(user_op::OpArg("out", 0), std::get<2>(vaild_split_dim))
+        .Build();
  }
  ctx->NewBuilder()
      .Broadcast(user_op::OpArg("condition", 0))
@@ -169,57 +181,13 @@ Maybe<void> GetWhereXScalarSbpSignatures(user_op::SbpContext* ctx) {
Maybe<void> GetWhereYScalarSbpSignatures(user_op::SbpContext* ctx) {
  const Shape& cond_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("condition", 0).shape();
  const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape();
-  if (cond_shape.NumAxes() < x_shape.NumAxes()) {
-    FOR_RANGE(int64_t, i, 0, x_shape.NumAxes() - cond_shape.NumAxes()) {
-      ctx->NewBuilder()
-          .Broadcast(user_op::OpArg("condition", 0))
-          .Split(user_op::OpArg("x", 0), i)
-          .Split(user_op::OpArg("out", 0), i)
-          .Build();
-    }
-    FOR_RANGE(int64_t, i, 0, cond_shape.NumAxes()) {
-      ctx->NewBuilder()
-          .Split(user_op::OpArg("condition", 0), cond_shape.NumAxes() - 1 - i)
-          .Split(user_op::OpArg("x", 0), x_shape.NumAxes() - 1 - i)
-          .Split(ctx->outputs(), x_shape.NumAxes() - 1 - i)
-          .Build();
-    }
-  } else if (cond_shape.NumAxes() > x_shape.NumAxes()) {
-    FOR_RANGE(int64_t, i, 0, cond_shape.NumAxes() - x_shape.NumAxes()) {
-      ctx->NewBuilder()
-          .Split(user_op::OpArg("condition", 0), i)
-          .Broadcast(user_op::OpArg("x", 0))
-          .Split(user_op::OpArg("out", 0), i)
-          .Build();
-    }
-    FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) {
-      ctx->NewBuilder()
-          .Split(user_op::OpArg("condition", 0), cond_shape.NumAxes() - 1 - i)
-          .Split(user_op::OpArg("x", 0), x_shape.NumAxes() - 1 - i)
-          .Split(ctx->outputs(), cond_shape.NumAxes() - 1 - i)
-          .Build();
-    }
-  } else {
-    FOR_RANGE(int64_t, i, 0, cond_shape.NumAxes()) {
-      if (cond_shape.At(i) == 1 && x_shape.At(i) == 1) { continue; }
-      if (cond_shape.At(i) == x_shape.At(i)) {
-        ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
-      } else if (cond_shape.At(i) == 1) {
-        ctx->NewBuilder()
-            .Broadcast(user_op::OpArg("condition", 0))
-            .Split(user_op::OpArg("x", 0), i)
-            .Split(ctx->outputs(), i)
-            .Build();
-      } else if (x_shape.At(i) == 1) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("condition", 0), i)
-            .Broadcast(user_op::OpArg("x", 0))
-            .Split(ctx->outputs(), i)
-            .Build();
-      } else {
-        UNIMPLEMENTED();
-      }
-    }
+  const auto& vaild_split_dims = JUST(CalValidSplitDims(cond_shape, x_shape));
+  for (const auto& vaild_split_dim : *vaild_split_dims) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("condition", 0), std::get<0>(vaild_split_dim))
+        .Split(user_op::OpArg("x", 0), std::get<1>(vaild_split_dim))
+        .Split(user_op::OpArg("out", 0), std::get<2>(vaild_split_dim))
+        .Build();
  }
  ctx->NewBuilder()
      .Broadcast(user_op::OpArg("condition", 0))
@@ -0,0 +1,55 @@
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import unittest | ||
|
||
from oneflow.test_utils.automated_test_util import * | ||
|
||
import oneflow as flow | ||
import oneflow.unittest | ||
|
||
|
||
@autotest(n=1, check_graph=False) | ||
def _test_global_view(test_case, placement, sbp): | ||
x = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp) | ||
y = x.view(2, 2, 2, -1) | ||
return y | ||
|
||
|
||
@autotest(n=1, check_graph=False) | ||
def _test_global_view_size(test_case, placement, sbp): | ||
x = random_tensor(ndim=2, dim0=8, dim1=16).to_global(placement, sbp) | ||
shape = torch.Size([2, 2, 2, -1]) | ||
y = x.view(shape) | ||
return y | ||
|
||
|
||
class TestGlobalView(flow.unittest.TestCase): | ||
@globaltest | ||
def test_global_view(test_case): | ||
for placement in all_placement(): | ||
for sbp in all_sbp(placement, max_dim=2): | ||
_test_global_view(test_case, placement, sbp) | ||
|
||
@globaltest | ||
def test_global_view_size(test_case): | ||
for placement in all_placement(): | ||
for sbp in all_sbp(placement, max_dim=2): | ||
_test_global_view_size(test_case, placement, sbp) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
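Both tests view an 8x16 global tensor as (2, 2, 2, -1); the -1 axis is inferred from the remaining element count, 8 * 16 / (2 * 2 * 2) = 16. A local NumPy stand-in for the shape arithmetic (illustration only, not the global-tensor path under test):

import numpy as np

x = np.arange(8 * 16).reshape(8, 16)
y = x.reshape(2, 2, 2, -1)  # the -1 axis resolves to 128 // 8 = 16
assert y.shape == (2, 2, 2, 16)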
@@ -0,0 +1,42 @@
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import unittest | ||
from collections import OrderedDict | ||
|
||
from oneflow.test_utils.test_util import GenArgList | ||
from oneflow.test_utils.automated_test_util import * | ||
import oneflow as flow | ||
import oneflow.unittest | ||
|
||
|
||
@autotest(n=1, check_graph=False) | ||
def _test_global_weight_norm_with_random_data(test_case, placement, sbp): | ||
dim = random(-2, 2).to(int).value() | ||
liner_model_torch = torch.nn.Linear(8, 16).to_global(placement, sbp) | ||
m = torch.nn.utils.weight_norm(liner_model_torch, name="weight", dim=dim) | ||
return m.weight_g, m.weight_v | ||
|
||
|
||
class TestGlobalWeightNorm(flow.unittest.TestCase): | ||
@globaltest | ||
def test_global_weight_norm_with_random_data(test_case): | ||
for placement in all_placement(): | ||
for sbp in all_sbp(placement, max_dim=1): | ||
_test_global_weight_norm_with_random_data(test_case, placement, sbp) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
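For context, weight_norm reparameterizes a weight as w = g * v / ||v||, where the norm is taken over every axis except dim and g is initialized to ||v||, so the initial w equals the original weight; the test compares the resulting weight_g and weight_v factors against PyTorch under global placement. A local NumPy sketch of the dim=0 case (illustration only, assuming PyTorch's initialization scheme):

import numpy as np

v = np.random.randn(16, 8)                       # weight_v: direction term
norm = np.linalg.norm(v, axis=1, keepdims=True)  # per-row norm (all axes but dim=0)
g = norm                                         # weight_g: initialized to ||v||
w = g * v / norm                                 # reparameterized weight
assert np.allclose(w, v)  # with g = ||v||, w reproduces the original weight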
Review comment: The exception type here has been aligned with PyTorch.