
【AutoParallelism】Add "Eager1F1B" pipeline strategy #57605

Merged
heavyrain-lzy merged 5 commits into PaddlePaddle:develop on Sep 27, 2023

Conversation

@heavyrain-lzy (Contributor) commented on Sep 21, 2023

PR types

Performance optimization

PR changes

Others

Description

PCard-71568
This PR adds a subgraph scheduling strategy for pipeline parallelism (PP) that overlaps the send/recv of forward jobs between PP stages with computation, improving end-to-end training and inference performance for large models. The main work is as follows:

  1. Add the _overlap_send_recv function to upgrade _insert_sync_for_fthenb_1f1b, removing the logic that relies on c_sync_calc_stream / c_sync_comm_stream to synchronize the program. Because this function may increase peak GPU memory, a new enable_send_recv_overlap parameter controls it and defaults to off.
  2. Add the pipeline_scheduler_Eager1F1B pipeline pass, which, when enable_send_recv_overlap==True, speeds up large-model execution at the cost of higher peak GPU memory.
    • Three single-node 8-GPU tests: compared with 1F1B, the three cases gain 0.5%~1% in performance, device-0 memory rises by 2%~4%, and the timeline shows that recv and computation are overlapped; results under multi-node PP will be added later.
    • Multi-node tests: two cases on 4 nodes x 32 GPUs with MP8-PP4 gain 2% and 2.9% in performance, with stage-0 device-0 memory rising by 28.51% and 29.14%; one case on 2 nodes x 16 GPUs with MP8-PP2 gains 5.93%, with stage-0 device-0 memory rising by 12.06%.

In summary: when GPU memory is not the bottleneck, enabling Eager1F1B yields roughly a 1% performance gain on a single node and 2%~6% across multiple nodes (the gain scales with the send/recv communication cost), so the strategy can be enabled on a per-model basis.
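For readers who want to try this, here is a minimal sketch of how the strategy might be turned on through the fleet auto-parallel Strategy config. The field names schedule_mode = "Eager1F1B" and enable_send_recv_overlap are assumptions inferred from the parameter names in this description, not a verified public interface; check the Strategy and pipeline pass definitions in this PR for the exact knobs.

# Hedged sketch only: the import path and config field names below are
# assumptions based on Paddle's fleet auto-parallel API and the parameter
# names mentioned in this PR, not a verified interface.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.pipeline.enable = True
strategy.pipeline.schedule_mode = "Eager1F1B"        # assumed: selects pipeline_scheduler_Eager1F1B
strategy.pipeline.enable_send_recv_overlap = True    # assumed: turns on _overlap_send_recv
strategy.pipeline.accumulate_steps = 8
strategy.pipeline.micro_batch_size = 1

# engine = auto.Engine(model, loss_fn, optimizer, strategy=strategy)
# engine.fit(train_dataset, epochs=1, batch_size=8)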

Related paper: On Optimizing the Communication of Model Parallelism

@paddle-bot (bot) commented on Sep 21, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@zhiqiu (Contributor) left a comment


LGTM

for op in block.ops:
    if op.type == 'send_v2':
        # Use static shapes for send_v2 and issue it on the calculation
        # stream so it is ordered with compute without extra sync ops.
        op._set_attr("dynamic_shape", False)
        op._set_attr("use_calc_stream", True)
Contributor

use_calc_stream should be false?

heavyrain-lzy (Contributor, Author)

When use_calc_stream is false, send_v2 will use the streams in phi::distributed::NCCLCommContext.
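To make the stream trade-off concrete, here is a small framework-free sketch (not Paddle internals) of why placing send_v2 on the calculation stream removes the need for explicit sync ops: work submitted to a single stream runs in issue order, while a send on a separate communication stream needs an explicit wait, which is the role the removed c_sync_calc_stream / c_sync_comm_stream ops played. Stream, calc_stream, comm_stream, and produced are toy names for this illustration.

# Toy illustration: FIFO worker threads stand in for CUDA streams.
import queue
import threading

class Stream:
    """Executes submitted callables in the order they were issued."""
    def __init__(self):
        self._q = queue.Queue()
        threading.Thread(target=self._run, daemon=True).start()

    def _run(self):
        while True:
            fn = self._q.get()
            fn()
            self._q.task_done()

    def submit(self, fn):
        self._q.put(fn)

    def synchronize(self):
        self._q.join()

calc_stream = Stream()
comm_stream = Stream()
produced = threading.Event()

# use_calc_stream=True: compute and send share one stream, ordering is implicit.
calc_stream.submit(lambda: print("compute micro-batch output"))
calc_stream.submit(lambda: print("send on calc stream (already ordered)"))

# use_calc_stream=False: the comm stream must wait on an explicit event,
# analogous to inserting c_sync_calc_stream / c_sync_comm_stream ops.
calc_stream.submit(lambda: (print("compute micro-batch output"), produced.set()))
comm_stream.submit(lambda: (produced.wait(), print("send on comm stream after sync")))

calc_stream.synchronize()
comm_stream.synchronize()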

@heavyrain-lzy merged commit cb3c681 into PaddlePaddle:develop on Sep 27, 2023
Frida-a pushed a commit to Frida-a/Paddle that referenced this pull request Oct 14, 2023
* delete sync operation in pipeline strategy

* add Eager1F1B pipeline strategy
jiahy0825 pushed a commit to jiahy0825/Paddle that referenced this pull request Oct 16, 2023
* delete sync operation in pipeline strategy

* add Eager1F1B pipeline strategy
danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023
* delete sync operation in pipeline strategy

* add Eager1F1B pipeline strategy