-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[metal] fix_elementwise #7467
[metal] fix_elementwise #7467
Conversation
auto op_type = KernelBase::op_type(); | ||
op_ = op_type; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接取值, ele_type_ = KernelBase::op_type()
std::shared_ptr<MetalBuffer> params_buffer_; | ||
DDim last_input_dims_{}; | ||
|
||
std::string op_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改成ele_type_
id<MTLComputePipelineState> pipline_; | ||
std::string function_name_; | ||
MetalContext* metal_context_; | ||
|
||
int op_num; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
使用准确的名字,operation_type_
bool InputsValid(const MetalImage* input_x_, const MetalImage* input_y_) { | ||
auto x_dims = input_x_->dim_; | ||
auto y_dims = input_y_->dim_; | ||
|
||
// check data layout | ||
if (input_x_->transpose_ != input_y_->transpose_) return false; | ||
// check data dims equal | ||
if (x_dims == y_dims) return true; | ||
|
||
if (x_dims[0] == y_dims[0] && x_dims[3] == y_dims[3]) { | ||
//[1 32 1 3] | ||
if (x_dims[1] == y_dims[1] && (x_dims[2] == 1 || y_dims[2] == 1)) return true; | ||
//[1 1 32 3] | ||
if (x_dims[2] == y_dims[2] && (x_dims[1] == 1 || y_dims[1] == 1)) return true; | ||
//[1 1 1 3] | ||
if ((x_dims[1] == 1 && x_dims[2] == 1) || (y_dims[1] == 1 && y_dims[2] == 1)) return true; | ||
} | ||
return false; | ||
} | ||
|
||
void ElementwiseImageCompute::PrepareForRun() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
代码格式化一下
|
||
params_buffer_ = | ||
std::make_shared<MetalBuffer>(metal_context_, sizeof(element_params), &element_params); | ||
|
||
function_name_ = fuse_flag_ ? "elementwise_add_relu" : "elementwise_add"; | ||
|
||
function_name_ = fuse_flag_ ? "elementwise_relu" : "elementwise_"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
elementwise_改成elementwise
* fix_elementwise * fix_elementwise
针对metalelementwise系列代码包括进行了重写和整合,其中包括:
add、sub、mul、div全面支持mps框架,支持broadcast
对elementwise代码进行了整合,目前通过ElementwiseImageCompute单个类的调用即可