-
Notifications
You must be signed in to change notification settings - Fork 685
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add view weight norm where zeropad2d global test #7886
Add view weight norm where zeropad2d global test #7886
Conversation
…add_view_weight_norm_where_zeropad2d_global_test
@@ -86,7 +86,10 @@ Maybe<void> InferWhereXYScalarTensorDesc(user_op::InferContext* ctx) { | |||
Maybe<void> GetWhereSbpSignatures(user_op::SbpContext* ctx) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个 op 因为 condition、x、y 都是 broadcastable 的,这里的 sbp 推导可能还不完全正确?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个 op 因为 condition、x、y 都是 broadcastable 的,这里的 sbp 推导可能还不完全正确?
确实不完善,我在refine一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,这里的推导逻辑会比较复杂,看是不是可以简化一下。比如只有condition、x、y shape相同时才支持所有输入的split,如果有两个相同,另一个不同,那就参考broadcast_add那里的逻辑,如果三个都不同,那就只支持B
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
…add_view_weight_norm_where_zeropad2d_global_test
…tps://github.com/Oneflow-Inc/oneflow into add_view_weight_norm_where_zeropad2d_global_test
oneflow/user/ops/where_op.cpp
Outdated
CHECK_OR_RETURN(a_extend_shape.At(i) == 1 || b_extend_shape.At(i) == 1 | ||
|| a_extend_shape.At(i) == b_extend_shape.At(i)) | ||
<< Error::RuntimeError() << "The size of tensor a (" << a_extend_shape.At(i) | ||
<< ") must match the size of tensor b (" << b_extend_shape.At(i) | ||
<< ") at non-singleton dimension " << i; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里异常类型对齐了pytorch
oneflow/user/ops/where_op.cpp
Outdated
JUST(CheckBroadcastable(a_shape, b_shape)); | ||
JUST(CheckBroadcastable(a_shape, c_shape)); | ||
JUST(CheckBroadcastable(b_shape, c_shape)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的check没太大必要,因为在InferWhereTensorDesc
里实际上已经check过一遍了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的check没太大必要,因为在
InferWhereTensorDesc
里实际上已经check过一遍了
确实没有必要,已删除CheckBroadcastable
相关逻辑
Speed stats:
|
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/7886/ |
CI failed when running job: cpu-misc. PR label automerge has been removed |
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/7886/ |
Speed stats:
|
No description provided.