[oneDNN] Reshape attr_axes when going to oneDNN kernel #59641
Conversation
Your PR was submitted successfully. Thank you for your contribution to this open-source project!
This change is kind of tricky... We should find a better way to solve it.
Maybe this transformation isn't needed every time; can we apply it based on get_cur_paddle_data_layout()? It seems it only happens when get_cur_paddle_data_layout() equals kNHWC or NHWC?
// Currently there is only transformation for tensors, while attr axes still
// follows default dtype instead of oneDNN dtype, so here manually change it
// TODO(Li Xinyi): Is there a way to achieve it at earlier stage?
funcs::GetSqueezeNewAxes(tmp);
std::vector<int64_t> formated_axis = axes.GetData();
if ((x_dims.size() >= 3) &&
    (phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
         phi::DataLayout::kNHWC)) {
  int axis_size = axes.GetData().size();
  for (int i = 0; i < axis_size; i++) {
    if (axes.GetData()[i] < 0) {
      // Wrap negative axes against the input rank, not the axes count.
      formated_axis[i] = axes.GetData()[i] + x_dims.size();
    }
  }
}
std::vector<int32_t> tmp(formated_axis.begin(), formated_axis.end());
It seems we could use a method like this to solve it?
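For instance, a minimal standalone sketch of how that normalization behaves (hypothetical values, not the kernel code):

#include <cstdint>
#include <iostream>
#include <vector>

// Standalone illustration of wrapping negative squeeze axes against the
// input rank; x_rank stands in for x_dims.size() in the kernel.
int main() {
  const int64_t x_rank = 5;             // e.g. a 5D NDHWC tensor
  std::vector<int64_t> axes = {-1, 2};  // hypothetical squeeze axes
  std::vector<int64_t> formated_axis = axes;
  for (size_t i = 0; i < axes.size(); ++i) {
    if (axes[i] < 0) {
      formated_axis[i] = axes[i] + x_rank;  // -1 becomes 4
    }
  }
  for (auto a : formated_axis) {
    std::cout << a << " ";  // prints: 4 2
  }
  std::cout << std::endl;
  return 0;
}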
Thanks for your advice! Do you think it's better to add a check for when get_cur_paddle_data_layout() == phi::DataLayout::NDHWC?
https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/transfer_layout_kernel.cc#L161, maybe we still need to check for kNHWC or NHWC? (kNHWC is defined to equal NHWC)
It seems that code assumed the dims would only be 4D, while here the dims are 5D. So maybe adding one more case is also good?
Line 36 in 9d7150b:
enum class DataLayout {
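For reference, a rough sketch of the relevant part of that enum (paraphrased from phi/common/layout.h; the exact members and order may differ):

// Paraphrased sketch: the k-prefixed names are aliases kept for
// compatibility with the legacy fluid DataLayout names.
enum class DataLayout {
  UNDEFINED = 0,
  NHWC,
  NCHW,
  NCDHW,
  NDHWC,
  ONEDNN,
  // ...
  kNHWC = NHWC,  // hence "kNHWC define equals NHWC"
  kNCHW = NCHW,
  kNDHWC = NDHWC,
  kNCDHW = NCDHW,
};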
Sure, I think it's more reasonable.
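A minimal sketch of what that extra case might look like (assuming phi::DataLayout::kNDHWC exists, as in the enum above; not the merged diff):

const auto layout = phi::OneDNNContext::tls().get_cur_paddle_data_layout();
// Treat both the 4D and the 5D channels-last layouts the same way.
const bool channels_last = (layout == phi::DataLayout::kNHWC) ||
                           (layout == phi::DataLayout::kNDHWC);
if (x_dims.size() >= 3 && channels_last) {
  // ... normalize negative axes against x_dims.size() as above ...
}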
Hi @xuxinyi389, was it you who opened the CUDA check for this PR? Since I see you're the only one (the only xuxinyi) who can be at… If so, would you mind helping check this CI issue? Thanks!
So it won't block the merge? Thanks! @yuanlehome @xinyu-intel @vivienfanghuagood, would you mind helping check this PR? Thanks~
PR types
Bug fixes
PR changes
Others
Description
This PR aims to fix Conv1DTranspose in #59510.
When the Squeeze op is dispatched to oneDNN, the corresponding tensors are transformed to the oneDNN layout. But currently only the tensors are transformed; attr_axes still follows the default layout instead of the oneDNN layout, so here we change it so that Squeeze performs correctly.
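A minimal sketch of the idea (simplified; it follows the snippets discussed above and may not match the final diff exactly):

// Normalize the attr axes before handing them to the oneDNN squeeze path.
std::vector<int64_t> formated_axis = axes.GetData();
if ((x_dims.size() >= 3) &&
    (phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
         phi::DataLayout::kNHWC)) {
  for (size_t i = 0; i < formated_axis.size(); ++i) {
    if (formated_axis[i] < 0) {
      // Wrap negative axes against the tensor rank so the oneDNN
      // kernel sees the indices it expects.
      formated_axis[i] += x_dims.size();
    }
  }
}
std::vector<int32_t> tmp(formated_axis.begin(), formated_axis.end());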