-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support generating code for grad_op #21066
Conversation
… for grad_op. Refine the unittest for better reuse.
单测里面生成的代码: extern "C" __global__ void fused_elementwise_0(int N, float* var0, float* var1, float* var3, float* var5, float* var2, float* var4, float* var6, float* var7, float* var8) {
for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < N;
idx += gridDim.x * blockDim.x) {
var2[idx] = var0[idx] * var1[idx];
var4[idx] = var2[idx] + var3[idx];
var6[idx] = var4[idx] - var5[idx];
var7[idx] = real_max(var6[idx], 0);
var8[idx] = 1.0 / (1.0 + real_exp(- var7[idx]));
}
}
extern "C" __global__ void fused_elementwise_grad_0(int N, float* var0, float* var1, float* var2, float* var3, float* var7, float* var4, float* var5, float* var6) {
for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < N;
idx += gridDim.x * blockDim.x) {
var6[idx] = var2[idx] > 0 ? var7[idx] : 0;
var4[idx] = var6[idx] * var1[idx];
var5[idx] = var6[idx] * var0[idx];
}
} 目前存在的一些问题:
|
std::unordered_set<std::string> res; | ||
for (auto& t : operations_) { | ||
if ((t.second.type == type) && | ||
(num_operands < 0 || t.second.num_operands == num_operands)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为啥num_operands 小于0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Find
函数用于获取OperationMap
内指定type、num_operands的op_type。当num_operands < 0
,或者可以是== 0
时,代表获取指定type的所有op_type,包括任意操作数的。
} | ||
|
||
bool IsValid() { | ||
if (!IsGradOp() && exprs.size() != 1U) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里面的expr的size为何不能为1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前的设定是,所有的Operation
,前向计算只有一个输出,可以有多个输入。所以如果是前向算子,则应该只有1个计算公式。这里判断,如果这个算是不是反向算子(是前向算子),且不是1个计算公式,则认为是不合法的。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
* Add the definition of operation in fusion_group. * Use operations in OperationMap to detect fusion_group of elementwise pattern. * Add namespace fusion_group in code_generator. * Use operations recorded in OperationMap to generate code. * Remove implementation codes to .cc file. * Refine Operation and CodeGenerator to make it easier to generate code for grad_op. Refine the unittest for better reuse. * Avoid recording the template's keyword in a array. * Support the generating of code for grad_op and add unittest. test=develop * Remove replaced_element_in_order and use use number instead. test=develop
* Add the definition of operation in fusion_group. * Use operations in OperationMap to detect fusion_group of elementwise pattern. * Add namespace fusion_group in code_generator. * Use operations recorded in OperationMap to generate code. * Remove implementation codes to .cc file. * Refine Operation and CodeGenerator to make it easier to generate code for grad_op. Refine the unittest for better reuse. * Avoid recording the template's keyword in a array. * Support the generating of code for grad_op and add unittest. test=develop * Remove replaced_element_in_order and use use number instead. test=develop
Operation
和OperationMap
的定义,其中定义支持的op类型,以及op计算公式模板。将子图匹配和代码生成中支持的op类型关联起来。