Commit
ZeRO3, improved parameter all-gather operation (#1188)
* remove norm(), avoid memcpy after allgather

  1) Remove the norm computation in debug printing.
  2) Change _all_gather to be a sync op in fetch_sub_module.
     Reason: the async version is not async at all, because each all_gather calls torch.cuda.synchronize() to guarantee that the previous communication op has completed.
  3) Add new function _allgather_params_split_launch.
     The existing _allgather_params does an explicit memcpy after the all-gather op. We can avoid that explicit memory copy on the Python side to improve performance.

  Known issue: `torch.distributed.all_gather` does an implicit memcpy at the end of each `ncclAllgather`.

* WIP: wrapped ncclAllgather as a customized op in DS

  A micro benchmark shows that all-gathering a transformer layer with 9834560 elements in half precision improves by about 1.1ms on an aws-p4d instance.

* WIP: integrated into partition_parameters

  Performance improvement for 5.1B bert on aws-p4d:
  fwd: 300ms -> 200ms
  bwd: 680ms -> 610ms

* Fix format
* cleaned dead code, modified unit test
* removed customized c++ extension, reverted to using the torch distributed API
* change torch.ones to torch.empty
* typo
* warn if not cuda tensor for allgather
* fix formatting
* fix: move ds_tensor to cuda device (it is strange that the ds_tensor had not already been moved to cuda)
* remove try clause on the path for fetching params

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
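
The idea behind item 3 above (gathering straight into the final buffer so that no explicit copy is needed afterwards) can be illustrated with a minimal, standard-library-only sketch. This is not the code from this commit: `allocate_flat_output` and `fake_all_gather` are hypothetical names, and `fake_all_gather` is only a single-process stand-in for a real collective such as `torch.distributed.all_gather`. The sketch assumes every rank holds a shard of equal length.

```python
from array import array


def allocate_flat_output(shard_len, world_size, typecode="f"):
    """Allocate one flat, zero-filled buffer plus one zero-copy view per rank.

    Hypothetical helper for illustration; not part of DeepSpeed.
    """
    itemsize = array(typecode).itemsize
    flat = array(typecode, bytes(shard_len * world_size * itemsize))
    whole = memoryview(flat)
    # Each slice aliases the flat buffer, so writing to a view writes to `flat`.
    views = [whole[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]
    return flat, views


def fake_all_gather(output_views, shards):
    """Single-process stand-in for a collective: shard r lands in output_views[r]."""
    for view, shard in zip(output_views, shards):
        view[:] = memoryview(shard)


world_size, shard_len = 4, 3
# Pretend rank r owns a shard filled with the value r.
shards = [array("f", [float(r)] * shard_len) for r in range(world_size)]

flat, views = allocate_flat_output(shard_len, world_size)
fake_all_gather(views, shards)

# The flat buffer is already fully assembled; no copy into it is needed here.
gathered = flat.tolist()
print(gathered)
```

Because the per-rank outputs are views into the flat buffer, the assembled parameter is available as soon as the gather finishes. As the known issue above notes, a real `torch.distributed.all_gather` call may still perform an implicit copy internally, which this sketch does not model.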
1 parent 7f5a3ad, commit c0eeb69
Showing 2 changed files with 107 additions and 15 deletions.