[PIR-Auto-Parallel] [cherry-pick] refactor refined recompute pass in PIR mode #70703
PR Category
Auto Parallel
PR Types
Performance
Description
Cherry-picked from #70064 and #70521.
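This PR builds on the recompute pass (#69681) to implement refined recompute, which selects certain operators inside a recompute layer and excludes them from recomputation. The switch is invoked in code as follows: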
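Within each layer segment, the pass matches the following pattern against the topology of the computation graph:
pattern = pre_ops + main_ops + suf_ops
Here pre_ops and suf_ops only help locate main_ops. The first num matched main_ops are not recomputed in the backward pass; when num = -1, all matched main_ops are excluded from recomputation. The pass also asserts on the number of segments: if recompute is enabled (strategy._recompute.enable=1) but the model code never uses recompute(layer), the recompute pass raises an error.
Other: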
A refined recompute test is added to PaddleNLP: PaddlePaddle/PaddleNLP#9679
This implementation partly follows the refined recompute implementation under the old IR: #58533
PCard-88114
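The matching rule from the Description (pattern = pre_ops + main_ops + suf_ops, with the first num matches excluded from recomputation) can be illustrated with a small standalone sketch. It is not the code of the pass: the function name find_ops_to_skip is hypothetical, the segment is simplified to a flat, topologically sorted list of op names instead of a PIR graph, and num is assumed to count non-overlapping pattern matches.

```python
from typing import List, Sequence


def find_ops_to_skip(
    segment_ops: Sequence[str],
    pre_ops: Sequence[str],
    main_ops: Sequence[str],
    suf_ops: Sequence[str],
    num: int = -1,
) -> List[int]:
    """Return indices into segment_ops of the main_ops that skip recompute.

    Hypothetical helper for illustration only. Assumes the segment is a
    flat list of op names and that matches do not overlap.
    """
    if not main_ops:
        return []

    pattern = list(pre_ops) + list(main_ops) + list(suf_ops)
    main_start = len(pre_ops)               # offset of main_ops in pattern
    main_end = main_start + len(main_ops)

    skipped: List[int] = []
    matched = 0
    i = 0
    while i + len(pattern) <= len(segment_ops):
        # num == -1 means "no limit"; otherwise stop after num matches.
        if num != -1 and matched >= num:
            break
        if list(segment_ops[i:i + len(pattern)]) == pattern:
            # Only main_ops are excluded; pre_ops and suf_ops just anchor
            # the match and are still recomputed.
            skipped.extend(range(i + main_start, i + main_end))
            matched += 1
            i += len(pattern)
        else:
            i += 1
    return skipped


segment = ["matmul", "add", "gelu", "matmul", "add", "gelu"]
all_matches = find_ops_to_skip(segment, ["matmul"], ["add"], ["gelu"], num=-1)
first_only = find_ops_to_skip(segment, ["matmul"], ["add"], ["gelu"], num=1)
```

In this toy segment, both add ops are selected when num = -1, and only the first one when num = 1; the surrounding matmul and gelu ops are used solely to pin down which add is meant.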