Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support generating code for grad_op #21066

Merged
merged 10 commits into from
Nov 11, 2019

Conversation

Xreki
Copy link
Contributor

@Xreki Xreki commented Nov 7, 2019

  • 增加OperationOperationMap的定义,其中定义支持的op类型,以及op计算公式模板。将子图匹配和代码生成中支持的op类型关联起来。
  • 调整了一下代码生成的代码,可以支持生成grad_op的计算代码。
  • 调整代码生成单测的代码,添加生成grad_op代码的单测。

@Xreki
Copy link
Contributor Author

Xreki commented Nov 7, 2019

单测里面生成的代码:

extern "C" __global__ void fused_elementwise_0(int N, float* var0, float* var1, float* var3, float* var5, float* var2, float* var4, float* var6, float* var7, float* var8) {
  for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
        idx < N;
     idx += gridDim.x * blockDim.x) {
     var2[idx] = var0[idx] * var1[idx];
     var4[idx] = var2[idx] + var3[idx];
     var6[idx] = var4[idx] - var5[idx];
     var7[idx] = real_max(var6[idx], 0);
     var8[idx] = 1.0 / (1.0 + real_exp(- var7[idx]));
   }
}

extern "C" __global__ void fused_elementwise_grad_0(int N, float* var0, float* var1, float* var2, float* var3, float* var7, float* var4, float* var5, float* var6) {
   for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
       idx < N;
       idx += gridDim.x * blockDim.x) {
     var6[idx] = var2[idx] > 0 ? var7[idx] : 0;
     var4[idx] = var6[idx] * var1[idx];
     var5[idx] = var6[idx] * var0[idx];
   }
}

目前存在的一些问题:

  1. fused_elementwise_0中,一些中间计算结果,需要在后面使用到,应该用临时变量保存起来,避免先写入显存,后面再从显存中读取出来,比如var2[idx]var4[idx]var6[idx]var7[idx]
  2. 实际融合的计算中,一些中间结果可能是不需要写入显存的,即我们只需要拿到var8[idx],因此var2[idx]var4[idx]var6[idx]var7[idx]这些都是中间计算结果,都不需要写入显存。
  3. 在梯度计算中,有些grad_op计算需要用到输入xyout,而有些不需要用到。在定义每个operation的计算模板的时候,为了明确每个参数的函数,所有输入都考虑进去了。在生成代码时,需要能够过滤掉实际不需要访问到的变量,比如fused_elementwise_grad_0中的float* var3

std::unordered_set<std::string> res;
for (auto& t : operations_) {
if ((t.second.type == type) &&
(num_operands < 0 || t.second.num_operands == num_operands)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里为啥num_operands 小于0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Find函数用于获取OperationMap内指定type、num_operands的op_type。当num_operands < 0,或者可以是== 0时,代表获取指定type的所有op_type,包括任意操作数的。

}

bool IsValid() {
if (!IsGradOp() && exprs.size() != 1U) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里面的expr的size为何不能为1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前的设定是,所有的Operation,前向计算只有一个输出,可以有多个输入。所以如果是前向算子,则应该只有1个计算公式。这里判断,如果这个算是不是反向算子(是前向算子),且不是1个计算公式,则认为是不合法的。

Copy link
Contributor

@wangchaochaohu wangchaochaohu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Xreki Xreki merged commit 9091f8c into PaddlePaddle:develop Nov 11, 2019
@Xreki Xreki deleted the pass_generate_grad_code branch November 18, 2019 08:52
seiriosPlus pushed a commit to seiriosPlus/Paddle that referenced this pull request Dec 9, 2019
* Add the definition of operation in fusion_group.

* Use operations in OperationMap to detect fusion_group of elementwise pattern.

* Add namespace fusion_group in code_generator.

* Use operations recorded in OperationMap to generate code.

* Remove implementation codes to .cc file.

* Refine Operation and CodeGenerator to make it easier to generate code for grad_op.
Refine the unittest for better reuse.

* Avoid recording the template's keyword in a array.

* Support the generating of code for grad_op and add unittest.
test=develop

* Remove replaced_element_in_order and use use number instead.
test=develop
seiriosPlus pushed a commit to seiriosPlus/Paddle that referenced this pull request Dec 9, 2019
* Add the definition of operation in fusion_group.

* Use operations in OperationMap to detect fusion_group of elementwise pattern.

* Add namespace fusion_group in code_generator.

* Use operations recorded in OperationMap to generate code.

* Remove implementation codes to .cc file.

* Refine Operation and CodeGenerator to make it easier to generate code for grad_op.
Refine the unittest for better reuse.

* Avoid recording the template's keyword in a array.

* Support the generating of code for grad_op and add unittest.
test=develop

* Remove replaced_element_in_order and use use number instead.
test=develop
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants