ZeRO3, improved parameter all-gather operation #1188
Conversation
1) Removing the norm computation in debug printing.
2) Changing `_all_gather` to be a sync op in `fetch_sub_module`. Reason: the async version is not async at all, because each `all_gather` calls `torch.cuda.synchronize()` to guarantee the previous communication op has completed.
3) Adding a new function `_allgather_params_split_launch`: the existing `_allgather_params` has an explicit memcpy after the all-gather op. We can avoid the explicit memory copy on the Python side to improve performance (see the sketch below).
Known issue: `torch.distributed.all_gather` will do an implicit memcpy at the end of each `ncclAllgather`.
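The `_allgather_params_split_launch` change in item 3 boils down to gathering straight into the buffer the full parameter will use, instead of gathering into temporaries and copying afterwards. A minimal sketch of that idea, assuming ZeRO-3-style attributes (`ds_tensor`, `ds_numel`, `ds_shape`) on the parameter; this is illustrative only, not the actual DeepSpeed implementation:

```python
import torch
import torch.distributed as dist

def allgather_param(param, group=None):
    """Gather a partitioned parameter without a Python-side copy.

    `param.ds_tensor`, `param.ds_numel`, and `param.ds_shape` are placeholder
    names for the local shard and the full parameter's size/shape.
    """
    world_size = dist.get_world_size(group=group)
    shard = param.ds_tensor
    # Allocate the full (padded) parameter once ...
    flat = torch.empty(world_size * shard.numel(),
                       dtype=shard.dtype, device=shard.device)
    # ... and hand all_gather views into that buffer, so the gathered data
    # already lives where the parameter expects it.
    dist.all_gather(list(flat.chunk(world_size)), shard, group=group)
    # Trim padding and restore the original shape; no explicit memcpy here,
    # although torch.distributed.all_gather may still copy internally after
    # each ncclAllgather (the "known issue" noted above).
    param.data = flat.narrow(0, 0, param.ds_numel).view(param.ds_shape)
    return param
```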
A micro benchmark shows the improvement from all-gathering a transformer layer with 9,834,560 elements in half precision is about 1.1ms on an aws-p4d instance.
Performance improvement of 5.1B bert on aws-p4d: fwd: 300ms -> 200ms, bwd: 680ms -> 610ms.
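For context, here is a hedged sketch of the kind of micro-benchmark described above: timing a list-based `all_gather` of a ~9.8M-element half-precision shard. The exact setup (iteration count, warm-up, launch method) is not stated in the PR, so treat those knobs as assumptions:

```python
import time
import torch
import torch.distributed as dist

def time_allgather(shard_numel=9_834_560, iters=50):
    # Launch with torchrun so the NCCL process group can initialize.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world = dist.get_world_size()
    torch.cuda.set_device(rank % torch.cuda.device_count())
    shard = torch.randn(shard_numel, dtype=torch.half, device="cuda")
    outputs = [torch.empty_like(shard) for _ in range(world)]
    for _ in range(5):                      # warm-up iterations
        dist.all_gather(outputs, shard)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        dist.all_gather(outputs, shard)
    torch.cuda.synchronize()
    if rank == 0:
        print(f"avg all_gather: {(time.time() - start) / iters * 1e3:.2f} ms")

if __name__ == "__main__":
    time_allgather()
```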
@jfc4050 @tjruwase The micro-benchmark shows about a 1.1ms time reduction for all-gathering a transformer layer with 9.8M params (half precision) per partition on a p4d.24xl instance. In end-to-end training, the forward time could be further reduced to 200ms for the 5.1B bing-bert model (previously around 280ms-300ms). Looking for suggestions.
Hey, did you find the reason for the failure at #1170? I saw that PR has passed the tests. I plan to work on a fix this Thursday; it would be nice if you could provide some insights about your fix. Thanks!
@zarzen, thanks for following up and sorry that I forgot to update you. Yes, I was able to fix the issue. The problem was that 2 zero context objects were constructed along the way, and a parameter that was gathered by one context was partitioned in the other. The fix is to avoid multiple zero contexts and instead reuse the existing context to register any newly discovered parameter. The core of the fix is here. We are fortunate that the HF unit tests were able to expose this issue. I will take a closer look at your unit test failures as well.
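A conceptual sketch of the bug and fix pattern described above; `PartitionContext` and its methods are hypothetical names, not DeepSpeed's API:

```python
class PartitionContext:
    """Hypothetical stand-in for a ZeRO partitioning context.

    Bug pattern: two contexts each track their own parameters, so a parameter
    gathered under one context can be partitioned by the other.
    Fix pattern: reuse the existing context for newly discovered parameters
    instead of constructing a second one.
    """
    _active = None

    @classmethod
    def get_or_create(cls):
        # Reuse the already-constructed context when one exists.
        if cls._active is None:
            cls._active = cls()
        return cls._active

    def __init__(self):
        self.registered = []

    def register_external_parameter(self, param):
        # Newly discovered parameters join the existing context, so gather
        # and partition bookkeeping always refer to the same object.
        self.registered.append(param)
```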
Are you referring to this commit, a75e46, for fixing the multi-context issue? Does that mean I can wait for #1170 to get merged first?
The runtime error is thrown from check_gpu_tensors at ProcessGroupNCCL.cpp.
Update: I'm able to reproduce the failure on my side and am currently working on a fix.
Hi @tjruwase, but it is strange that the ds_tensor hasn't been moved to cuda. Does this imply other potential bugs maybe?
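For reference, NCCL collectives only accept CUDA tensors, so a shard that was never moved to the GPU fails exactly in that check. A minimal reproduction sketch (assuming an initialized NCCL process group; the shapes are arbitrary):

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl") has already run on every rank.
world = dist.get_world_size()
cpu_shard = torch.zeros(1024, dtype=torch.half)            # never moved to CUDA
outputs = [torch.empty_like(cpu_shard) for _ in range(world)]
try:
    dist.all_gather(outputs, cpu_shard)                     # NCCL needs CUDA tensors
except RuntimeError as err:
    # ProcessGroupNCCL's tensor check rejects CPU inputs; moving the shard
    # to the GPU before the collective avoids the error.
    print(f"all_gather failed: {err}")
```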
Yes, this is quite concerning actually. It will require further investigation. But I think this is not blocking the merge of this PR, correct?
I think so.
For the bing_bert model with the following configuration (about 5.1B params), forward computation improved from ~470ms to ~270ms and backward time improved from ~900ms to ~670ms (hardware setup: 1x ec2-p4d.24xlarge instance).

- Changing `_all_gather` to be a sync op in `fetch_sub_module`. Reason: the async version is not async at all, because each `all_gather` calls `torch.cuda.synchronize()` to guarantee the previous communication op has completed.
- `_allgather_params_coalesced`: the existing `_allgather_params` has an explicit memcpy after the all-gather op. We can avoid the explicit memory copy on the Python side to improve performance.
- `_partition_param`: allocation with the `torch.empty` function.

Notes:
Using the most recent updates on PyTorch, the `_all_gather_base` function could give a further performance boost, as `_all_gather_base` avoids the redundant memory copy. Refer to pytorch/pytorch#56315 (see the sketch below).
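To illustrate the note above: `_all_gather_base` takes a single pre-allocated output tensor, which removes the per-rank output list and its trailing copy. A hedged sketch; `_all_gather_base` was a private API at the time (newer releases expose the same operation as `all_gather_into_tensor`), so availability depends on the PyTorch build:

```python
import torch
import torch.distributed as dist

def gather_full_param(shard, group=None):
    """`shard` is this rank's 1-D partition of a parameter, already on the GPU."""
    world = dist.get_world_size(group=group)
    full = torch.empty(world * shard.numel(),
                       dtype=shard.dtype, device=shard.device)
    if hasattr(dist, "_all_gather_base"):
        # Gathers every rank's shard directly into `full`: no per-rank output
        # list and no trailing copy (see pytorch/pytorch#56315).
        dist._all_gather_base(full, shard, group=group)
    else:
        # Fallback for older PyTorch builds: list-based all_gather into views.
        dist.all_gather(list(full.chunk(world)), shard, group=group)
    return full
```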