[Prim][PIR] binary_cross_entropy_with_logits forward decomp #61613
Conversation
Your PR has been submitted successfully. Thank you for your contribution to the open-source project!
Sorry to inform you that 4950de4's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
full<T>(common::vectorize(dims), ignore_index, label.type());
auto out = where<T>(label == ignore_index_tensor, zero, tmp_out);
if (normalize) {
  const Tensor eps1 = full<T>(common::vectorize(dims), 1e-6, x.type());
Where does the 1e-6 come from?
Suggest adding an inline code comment noting where this value comes from.
const Tensor eps1 = full<T>(common::vectorize(dims), 1e-6, x.type());
auto diff = label - ignore_index_tensor;
const Tensor tmp_norm = sum<T>(where<T>(abs<T>(diff) > eps1, one, zero));
const Tensor eps2 = full<T>(empty_shape, 1e-5, x.type());
Same question for the 1e-5?
Same as above.
check_prim_pir=True should also be added in test_check_grad.
The shapes in the unit tests are a bit small; it is recommended to make them larger.
Sorry to inform you that b72d42d's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
Sorry to inform you that 02a9896's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
  value = value * get_slice<T>(x_shape_tensor, i);
}
value = reshape<T>(value, {});
ans = sum<T>(x_cast) / cast<T>(value, DataType::FLOAT32);
Please re-review the cast logic and the dtype as well.
Done
…ddle#61613)
* sigmoid_cross_entropy_with_logits forward decomp
* mean_all forward decomp
* add the test case for binary_cross_entropy_with_logits
* create a new test file
* modify the assert method
* modify the test
* fix code style
* add prim in check_grad for the test and handle the optional tensor
* fix conflict
* do not modify the third_party package
* fix merge bug
* modify the test data and change the file name
* rollback
* fix bug
* support mean_all for dynamic shape
* modify the type
PR Category
Operator Mechanism
PR Types
Others
Description
binary_cross_entropy_with_logits forward decomp
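For reference, a minimal NumPy sketch of one standard numerically stable formulation of binary cross entropy with logits; the actual prim decomposition (including the ignore_index and normalize handling discussed above) may differ in detail:

import numpy as np

def bce_with_logits(x, label):
    # max(x, 0) - x*label + log(1 + exp(-|x|)) avoids overflow for large |x|
    return np.maximum(x, 0.0) - x * label + np.log1p(np.exp(-np.abs(x)))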
subtasks: