
ZeRO3, improved parameter all-gather operation #1188

Merged: 39 commits into microsoft:master, Oct 31, 2021

Conversation

zarzen (Contributor) commented Jun 25, 2021

For the bing_bert model with the following configuration (about 5.1B params), forward computation time improved from ~470ms to ~270ms and backward time improved from ~900ms to ~670ms (hardware setup: 1x EC2 p4d.24xlarge instance).

    "bert_model_config": {
        "vocab_size_or_config_json_file": 32003,
        "hidden_size": 2560,
        "num_hidden_layers": 64,
        "num_attention_heads": 40,
        "intermediate_size": 10240,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
        "max_position_embeddings": 512,
        "initializer_range": 0.02
    },
  1. Removed the norm computation in debug printing.
  2. Changed _all_gather to be a sync op in fetch_sub_module.
    Reason: the async version is not actually async, because each
    all_gather calls torch.cuda.synchronize() to guarantee that the
    previous communication op has completed.
  3. Added a new function, _allgather_params_coalesced.
    The existing _allgather_params does an explicit memcpy after the
    all-gather op; avoiding that explicit memory copy on the Python
    side improves performance.
  4. Changed _partition_param to use torch.empty.

Notes:
Using the recent PyTorch _all_gather_base function could give a further performance boost, since _all_gather_base avoids the redundant memory copy. Refer to pytorch/pytorch#56315. A sketch of this flat-buffer approach is shown below.
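
Purely as an illustration (this is not the PR's code, and the function and parameter names are made up), a parameter whose shards live in ds_tensor on every rank can be gathered into one pre-allocated flat buffer with _all_gather_base, so no per-rank copy is needed on the Python side:

    import torch
    import torch.distributed as dist

    def allgather_param_flat(ds_tensor, param_numel, group=None):
        """Gather one parameter's shards (one equal-sized, possibly padded
        shard per rank) into a single flat buffer with one collective call."""
        world_size = dist.get_world_size(group=group)
        flat = torch.empty(world_size * ds_tensor.numel(),
                           dtype=ds_tensor.dtype,
                           device=ds_tensor.device)
        # _all_gather_base writes every rank's shard directly into `flat`,
        # whereas the list form of dist.all_gather copies from an internal
        # flattened buffer into each output tensor afterwards.
        dist._all_gather_base(flat, ds_tensor.contiguous(), group=group)
        # Drop any padding that was added when the parameter was partitioned.
        return flat.narrow(0, 0, param_numel)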

1) Removed the norm computation in debug printing.
2) Changed _all_gather to be a sync op in fetch_sub_module.
    Reason: the async version is not actually async, because each
    all_gather calls torch.cuda.synchronize() to guarantee that the
    previous communication op has completed.
3) Added a new function, _allgather_params_split_launch.
    The existing _allgather_params does an explicit memcpy after the
    all-gather op; avoiding that explicit memory copy on the Python
    side improves performance.

Known issue:
    `torch.distributed.all_gather` does an implicit memcpy
    at the end of each `ncclAllGather`.
ghost commented Jun 25, 2021

CLA assistant check
All CLA requirements met.

zarzen marked this pull request as draft June 25, 2021 17:33
tjruwase mentioned this pull request Jun 30, 2021
A micro-benchmark shows the improvement from all-gathering a
transformer layer with 9,834,560 elements in half precision is about
1.1ms on an AWS p4d instance.
Performance improvement of the 5.1B BERT model on aws-p4d:
fwd: 300ms -> 200ms
bwd: 680ms -> 610ms
zarzen (Contributor, Author) commented Jul 1, 2021

@jfc4050 @tjruwase
I have pushed the customized all_gather operation.
The op uses the CUDA stream specified by torch.cuda.stream, and it returns a handle with a wrapped CUDA event, so you can query/synchronize/wait on the communication on the given stream.
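
As a rough sketch only (not the PR's custom op; the class and function names here are hypothetical), such a handle can wrap a CUDA event recorded right after the collective is launched, so callers can later query, wait, or synchronize:

    import torch
    import torch.distributed as dist

    class AllGatherHandle:
        """Wraps a CUDA event recorded after the all-gather launch."""
        def __init__(self, event):
            self._event = event

        def query(self):
            # True once the recorded GPU work has finished.
            return self._event.query()

        def wait(self):
            # Make the caller's current stream wait for the all-gather,
            # without blocking the host.
            torch.cuda.current_stream().wait_event(self._event)

        def synchronize(self):
            # Block the host until the all-gather has completed.
            self._event.synchronize()

    def launch_allgather(flat_out, shard, comm_stream, group=None):
        # Illustrative only: with the stock NCCL process group the collective
        # actually runs on the group's internal stream, so this merely
        # approximates launching it on `comm_stream` as the custom op does.
        with torch.cuda.stream(comm_stream):
            dist._all_gather_base(flat_out, shard, group=group)
            event = torch.cuda.Event()
            event.record(comm_stream)
        return AllGatherHandle(event)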

The micro-benchmark shows about a 1.1ms time reduction for all-gathering a transformer layer with 9.8M params (half precision) per partition on a p4d.24xl instance.

In end-to-end training, the forward time could be further reduced to 200ms for the 5.1B bing-bert model (previously the forward time was around 280ms-300ms).

Looking for suggestions.

zarzen marked this pull request as ready for review July 2, 2021 19:11
zarzen (Contributor, Author) commented Oct 12, 2021

> @zarzen, thanks for your question. We just added the HF unit tests and it is causing failures on this PR and #1170. I am currently investigating the failure on #1170 and will get to this one afterwards. However, if you have bandwidth you can also look into this. The steps to run the HF tests and repro can be found here.

Hey, did you find the reason for the failure on #1170? I saw that PR has passed the tests. I plan to work on a fix this Thursday. It would be nice if you could provide some insights about your fix. Thanks!

tjruwase (Contributor) commented:

@zarzen, thanks for following up, and sorry that I forgot to update you. Yes, I was able to fix the issue. The problem was that two ZeRO context objects were constructed along the way, and a parameter that was gathered by one context was partitioned in the other. The fix is to avoid multiple ZeRO contexts and instead reuse the existing context to register any newly discovered parameter. The core of the fix is here.
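
A minimal sketch of that reuse pattern, purely for illustration (the class and function names below are hypothetical stand-ins, not DeepSpeed's actual API):

    class ZeroParamContext:
        """Hypothetical stand-in for the ZeRO-3 parameter context."""
        def __init__(self):
            self.registered_params = []

        def register_param(self, param):
            self.registered_params.append(param)

    _active_zero_context = None

    def get_zero_context():
        # Reuse the one existing context instead of constructing another;
        # with two live contexts, a parameter gathered by one can end up
        # partitioned by the other, which is the bug described above.
        global _active_zero_context
        if _active_zero_context is None:
            _active_zero_context = ZeroParamContext()
        return _active_zero_context

    def register_new_parameter(param):
        get_zero_context().register_param(param)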

We are fortunate that the HF unit tests were able to expose this issue. I will take a closer look at your unit test failures as well.

zarzen (Contributor, Author) commented Oct 14, 2021

> @zarzen, thanks for following up, and sorry that I forgot to update you. Yes, I was able to fix the issue. The problem was that two ZeRO context objects were constructed along the way, and a parameter that was gathered by one context was partitioned in the other. The fix is to avoid multiple ZeRO contexts and instead reuse the existing context to register any newly discovered parameter. The core of the fix is here.
>
> We are fortunate that the HF unit tests were able to expose this issue. I will take a closer look at your unit test failures as well.

Are you referring to this commit, a75e46, for fixing the multi-context issue?

Does that mean I can wait for #1170 to get merged first?

zarzen (Contributor, Author) commented Oct 15, 2021

The runtime error is thrown from check_gpu_tensors in ProcessGroupNCCL.cpp.
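
For context, that check rejects any tensor handed to an NCCL collective that is not on a GPU; a minimal repro (assuming a process group initialized with the NCCL backend) would look roughly like this:

    import torch
    import torch.distributed as dist

    # Assumes dist.init_process_group("nccl", ...) has already been called.
    world_size = dist.get_world_size()
    cpu_shard = torch.ones(8)  # still on the CPU -- this is the problem
    outputs = [torch.empty(8, device="cuda") for _ in range(world_size)]

    # ProcessGroupNCCL validates the inputs in check_gpu_tensors and raises
    # a RuntimeError because cpu_shard is not a CUDA tensor.
    dist.all_gather(outputs, cpu_shard)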

zarzen (Contributor, Author) commented Oct 22, 2021

Update: I'm able to reproduce the failure on my side; currently working on a fix.

tjruwase and others added 2 commits October 22, 2021 12:23
but it is strange that the ds_tensor hasn't been moved to CUDA
zarzen (Contributor, Author) commented Oct 22, 2021

Hi @tjruwase
I found the test failure is due to the device of ds_tensor: it is on the CPU rather than a CUDA device, which is unexpected. I thought the ds_tensor was guaranteed to be on CUDA when we call allgather_param.
The current fix is ad hoc; I just move the ds_tensor to CUDA here:

local_tensors.append(param.ds_tensor.cuda())

Does this imply other potential bugs, maybe?
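
In the same context as the line above (an illustrative variant only, reusing the same `param` and `local_tensors`), a slightly more defensive version would copy the shard only when it is unexpectedly off the GPU:

    ds_tensor = param.ds_tensor
    if not ds_tensor.is_cuda:
        # Unexpected: the shard should already live on a CUDA device here.
        ds_tensor = ds_tensor.cuda()
    local_tensors.append(ds_tensor)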

tjruwase (Contributor) commented:

> Hi @tjruwase I found the test failure is due to the device of ds_tensor: it is on the CPU rather than a CUDA device, which is unexpected. I thought the ds_tensor was guaranteed to be on CUDA when we call allgather_param. The current fix is ad hoc; I just move the ds_tensor to CUDA here:
>
> local_tensors.append(param.ds_tensor.cuda())
>
> Does this imply other potential bugs, maybe?

Yes, this is actually quite concerning and will require further investigation. But I think it is not a blocker for merging this PR, correct?

zarzen (Contributor, Author) commented Oct 27, 2021

> Yes, this is actually quite concerning and will require further investigation. But I think it is not a blocker for merging this PR, correct?

I think so.

tjruwase enabled auto-merge (squash) October 31, 2021 05:59
tjruwase merged commit c0eeb69 into microsoft:master Oct 31, 2021