
Various ZeRO Stage3 Optimizations + Improvements (including bfloat16 support) #1453

Merged: 91 commits, Jan 21, 2022

Conversation

@jfc4050 (Contributor) commented Oct 12, 2021

the addition of bfloat16 support depends on another change that is still in review (#1398). for now that commit is included here; will rebase once that PR goes through

Results

results on 8 p4d cluster (64 A100 GPUs)
Roberta-10B: 54.86 -> 92.07 TFLOPS/GPU
Roberta-50B: 67.61 -> 116.06 TFLOPS/GPU

results on 32 p4d cluster (256 A100 GPUs)
Roberta-50B: baseline TBD (pending hardware availability) -> 108 TFLOPS/GPU

note that these TFLOPS/GPU calculations ignore gradient checkpointing (which we do layer-wise); to account for it you can multiply by 4/3. These numbers are roughly at parity with ZeRO stage 2, but we unfortunately haven't quite gotten to linear weak scaling on our (worse than yours) hardware.
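
For concreteness, here is a quick back-of-the-envelope adjustment of the 64-GPU numbers above (my own arithmetic, not from the PR):

# Reported TFLOPS/GPU on 64 A100s, computed ignoring gradient checkpointing.
# Layer-wise checkpointing recomputes the forward pass, hence the 4/3 factor.
reported = {"Roberta-10B": 92.07, "Roberta-50B": 116.06}
adjusted = {name: round(tflops * 4 / 3, 1) for name, tflops in reported.items()}
print(adjusted)  # {'Roberta-10B': 122.8, 'Roberta-50B': 154.7}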

some notes about the conditions used to get the above numbers:

  • p4d nodes w/ 8 A100 GPUs each and 400Gbps inter-node bandwidth
  • batch size/GPU of 16
  • sequence length 512
  • 52B parameter Roberta model directly from huggingface, not using any custom fused kernels. vocab size ~53k
  • activations are checkpointed layer-wise
  • using bfloat16, we've observed similar throughput to fp16
  • ZeRO Stage3, not using any CPU/disk offload, and contiguous gradients enabled.

Testing

  • 3-day training run using a BERT-based model to verify an equivalent loss curve to an unpartitioned model
  • added several unit tests to verify activations/gradients

Changelist

Optimizations:

  • batching allgather calls to amortize fixed overhead and improve
    bandwidth utilization (credit: @zarzen)
  • batching reduce_scatter calls to amortize fixed overhead and
    improve bandwidth utilization
  • using *_base variants of allgather and reduce scatter to reduce memory
    allocations and data movement (credit: @zarzen) - see the sketch after this list
  • finer-grained synchronization for communication, which allows
    blocking on less work
  • asynchronous parameter deallocations (remains correct due to record_stream calls)
  • precomputation of fetching code/caching of fetching+partitioning decisions - using a fetch queue rather than deciding what to (pre)fetch at each iteration
  • prefetch hook fix (credit: @tjruwase)
  • more aggressive initialization-time defragmentation of parameter partitions
  • optimization of grad-norm computation - moving some math operations out of for loops and delaying synchronization
  • limiting queued collectives to reduce memory pressure on the pytorch cuda caching allocator (not an elegant solution; the real fix is perhaps best handled in pytorch itself)
  • removing debug logs that included the results of tensor operations; these were causing host-device transfers and synchronization
  • [for optimizer CPU offload] made some host-device tensor copies async
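
As referenced in the *_base bullet above, the sketch below is a minimal, hypothetical illustration of the batching + *_base idea and of record_stream-based asynchronous deallocation - it is not the PR's actual implementation. The helper names (gather_params_flat, free_gathered_async) and the assumption that every rank holds equally sized 1-D partitions are mine; reassembling each full parameter from the rank-major output is left out for brevity.

import torch
import torch.distributed as dist

def gather_params_flat(partitions, group=None):
    """Gather a batch of equally partitioned parameters with one collective.

    `partitions` is a list of 1-D tensors, each holding this rank's shard of a
    parameter. Instead of issuing one all_gather per parameter, the shards are
    concatenated into a single flat buffer and gathered with one
    _all_gather_base call, amortizing the per-collective launch overhead.
    """
    world_size = dist.get_world_size(group)
    flat_in = torch.cat([p.view(-1) for p in partitions])   # this rank's shards
    flat_out = torch.empty(world_size * flat_in.numel(),
                           dtype=flat_in.dtype, device=flat_in.device)
    handle = dist._all_gather_base(flat_out, flat_in, group=group, async_op=True)
    # flat_out is laid out rank-major: [rank0 shards | rank1 shards | ...];
    # slicing each full parameter back out is bookkeeping omitted here.
    return flat_out, handle                                  # caller waits on handle

def free_gathered_async(gathered, compute_stream):
    """Let a gathered buffer be released without a host-device synchronization.

    record_stream marks the tensor as in use by `compute_stream`, so the caching
    allocator will not recycle its memory until the queued kernels finish; the
    caller can then simply drop its reference instead of synchronizing first.
    """
    gathered.record_stream(compute_stream)

The point is that one large collective replaces many small ones; the same batching idea applies to reduce_scatter on the gradient side.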

New features:

  • bfloat16 support for ZeRO 3 (based off changes in Bfloat16 zero2 #1398)
  • some features that were available for allreduce were missing for reduce_scatter; added them so the two are at parity

QOL improvements:

  • speed up model initialization by moving weights to GPU before weight
    initialization. This previously took ~20 minutes for large models; now it takes a minute or two.
  • relative unit test imports so that unit tests can be run from any directory
  • change performance logging to include memory consumption
  • add logging w/ model size when done partitioning model
  • added tracking for pytorch CUDA caching allocator flushes (bad for performance) and a warning log for when this happens
  • added NVTX ranges to high impact operations for easier profiling - see the sketch after this list for both of these
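
As a rough illustration of the last two bullets, here is a hedged sketch of how allocator flushes can be spotted via torch.cuda.memory_stats() and how an NVTX range can wrap a hot operation. This is my own sketch, not the PR's code; the function names and the warning wording are made up.

import torch
from torch.cuda import nvtx

_last_alloc_retries = 0

def warn_on_allocator_flush():
    """Warn when the CUDA caching allocator has had to flush and retry.

    "num_alloc_retries" counts cudaMalloc retries performed after the allocator
    freed its cached blocks; a rising count usually means memory pressure is
    forcing expensive, synchronizing cache flushes.
    """
    global _last_alloc_retries
    retries = torch.cuda.memory_stats().get("num_alloc_retries", 0)
    if retries > _last_alloc_retries:
        print(f"warning: caching allocator flushed {retries - _last_alloc_retries} "
              "time(s) since the last check - this hurts performance")
        _last_alloc_retries = retries

def with_nvtx_range(name, fn, *args, **kwargs):
    """Run fn inside an NVTX range so it shows up clearly in Nsight traces."""
    nvtx.range_push(name)
    try:
        return fn(*args, **kwargs)
    finally:
        nvtx.range_pop()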

Bug fixes:

  • fix the init context manager when parent modules modify child weights. Previously, the parent module's modifications would not take effect on the child module's weights because those weights were already partitioned, causing a discrepancy in results when the context manager was used
  • noticed that on master our model was producing different loss curves than the unpartitioned model (several unexpected gradient overflows); verified that this has been fixed
  • parameter prefetching fix (stolen from @tjruwase)

raamjad and others added 2 commits October 12, 2021 14:11
optimizations for stage3:
- prefetching improvements
- batching allgather calls to amortize fixed overhead and improve
  bandwidth utilization
- batching reduce_scatter calls to amortize fixed overhead and
  improve bandwidth utilization
- using *_base variants of allgather and reduce scatter to reduce memory
  allocations and data movement
- more fine grained synchronization for communication that allows
  blocking on less work
- precomputation of fetching code - using a fetch queue rather than
  deciding what to (pre)fetch at each iteration
- limiting queued coalesced communication ops to reduce memory pressure
  on pytorch cuda caching allocator (not elegant solution)

optimizations for stage3-offload:
- made some host-device tensor copies async to improve performance

bug fixes and qol improvements:
- fix init context method when parent modules modify child weights
- speed up model initialization by moving model to GPU before weight
  initialization
- fixed unit test imports so that unit tests can be run from any
  directory
- change performance logging to include memory consumption
- add logging w/ model size when done partitioning model

new features
- bfloat16 support for ZeRO 3
@ghost commented Oct 12, 2021

CLA assistant check
All CLA requirements met.

@tjruwase (Contributor)

@jfc4050, Wow!!! This is a massive accomplishment. Thanks for doing this.

Can you please address the formatting check failure by following these instructions?

@tjruwase (Contributor)

@jfc4050, I am also seeing this strange unit test failure (screenshot omitted).

@tjruwase (Contributor)

@samyam, @stas00, and @jeffra FYI

@jfc4050 (Contributor, author) commented Oct 12, 2021

> @jfc4050, Wow!!! This is a massive accomplishment. Thanks for doing this.
>
> Can you please address the formatting check failure by following these instructions?

thanks! and will do

> @jfc4050, I am also seeing this strange unit test failure (screenshot omitted).

oh, i think it needs to be "from .common import distributed_test" - i made the imports relative so the tests can be run from any directory. will make that fix

@tjruwase (Contributor)

@jfc4050, is it possible to separate the bfloat support from this PR? To my understanding, these are quite independent changes, right?

"enabled": true
}
'''
BFLOAT16 = "bfloat16"
@stas00 (Collaborator) commented Oct 12, 2021

Re: bfloat16 - I proposed some months back that we don't create too many config entries, but instead switch to using a new dtype block, where the user can flip between bf16, fp16 and fp32.

Especially since they are mutually exclusive.

But that discussion wasn't concluded. Now is a good time to bring it back.

Contributor Author

that sounds reasonable. note that this change will be published as part of a separate PR; i just made my changes on top of it, so it ended up in here. will rebase once that PR makes it in

Collaborator

please note that I'm just a contributor, so my suggestions are just that - suggestions. Therefore in order not to waste your time, please first secure an agreement from the Deepspeed team when it comes to changing APIs.

In particular this one as it'd require back-compat code to support the current config.

Contributor

@raamjad, in the context of #1398, if you could add this dtype block, perhaps as a follow-up PR, that would be great. Thanks!

Contributor

Since the current coverage of bf16 is very limited, I thought it may be better to do this config refactoring later, once there is higher coverage.
Do you prefer that this be done now?
@stas00 Can you point me to the comments where you suggested the dtype block, so I know what shape of config changes was suggested?

@stas00 (Collaborator) commented Oct 27, 2021

One of my proposals was very simple: add a new top-level dtype config key, drop enabled from the fp16 config, and add a bf16 block.

{
    "dtype": "bf16",
    "fp16": {
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        [...]
    }
}

This approach allows users to keep nuanced settings for each dtype in the same config, but dtype will enable one over the others, so one can easily switch the settings in only one place.

It hasn't been approved by the Deepspeed team (as in no decision has been made about it).

@tjruwase, if you have access to the Teams log from Apr-30, that is when we discussed this. But it won't show it to me - the search only shows a snippet. Search for 'dtype mockup'.


The other proposal is to have a single dtype block that takes over from the fp16 block. This would be useful if many of the config options of bf16 and fp16 overlap. So:

{
    "dtype": {
        "enabled": "fp16",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
}

and for bf16:

{
    "dtype": {
        "enabled": "bf16",
        [...]
    }
}

@stas00 (Collaborator) commented Oct 27, 2021

As a maintainer of the DS integration in HF Transformers, I find it easiest when users can download ready-to-use config files, so in my experience having all the config sections already predefined in the config file makes things easier for the user. The first approach would therefore be preferable for that particular use case.

But of course there can be other ways...

Contributor

@stas00, thanks for reviving your awesome proposals. Sorry, mere mortals such as we can barely keep up with your genius :).

@stas00, @raamjad, is it okay if we continue this chat on #1398? I will add a link to this thread and also post the referenced Teams chat. Thanks.

Collaborator

You're too kind, @tjruwase - it's easy to come up with ideas, it's far from easy to make them a reality. So that's where your genius comes in ;)

Thank you for finding that old discussion and repasting it here, as MSFT Teams won't let me access it.

@stas00 (Collaborator) commented Oct 12, 2021

Amazing!

@jfc4050 (Contributor, author) commented Oct 12, 2021

> @jfc4050, is it possible to separate the bfloat support from this PR? To my understanding, these are quite independent changes, right?

sure, it will take some effort to disentangle from the partitioning initializer fixes, but i think we can do that. Note that once PR #1398 goes in and we rebase, the bfloat16-specific changes in this PR will be quite small

seems like all the test failures are due to _all_gather_base and torch.cuda.nvtx.range not being present in the older pytorch version used in the tests. can we consider a pytorch version bump, or should we go back and make them backwards compatible? there will be some performance penalty associated with using all_gather over _all_gather_base if users decide to use an older pytorch version that doesn't support it

@tjruwase (Contributor)

> @jfc4050, is it possible to separate the bfloat support from this PR? To my understanding, these are quite independent changes, right?
>
> sure, it will take some effort to disentangle from the partitioning initializer fixes, but i think we can do that. Note that once PR #1398 goes in and we rebase, the bfloat16-specific changes in this PR will be quite small
>
> seems like all the test failures are due to _all_gather_base and torch.cuda.nvtx.range not being present in the older pytorch version used in the tests. can we consider a pytorch version bump, or should we go back and make them backwards compatible? there will be some performance penalty associated with using all_gather over _all_gather_base if users decide to use an older pytorch version that doesn't support it

Got it. Is it possible to support both here? Even if we bumped the version on our CI, it would still break backwards compatibility for existing users. We could document the perf benefits of the newer version. It might also be worth further cluttering the output by issuing a (one-time) warning for this. What do you think?

@jfc4050 (Contributor, author) commented Oct 13, 2021

> @jfc4050, is it possible to separate the bfloat support from this PR? To my understanding, these are quite independent changes, right?
>
> sure, it will take some effort to disentangle from the partitioning initializer fixes, but i think we can do that. Note that once PR #1398 goes in and we rebase, the bfloat16-specific changes in this PR will be quite small
>
> seems like all the test failures are due to _all_gather_base and torch.cuda.nvtx.range not being present in the older pytorch version used in the tests. can we consider a pytorch version bump, or should we go back and make them backwards compatible? there will be some performance penalty associated with using all_gather over _all_gather_base if users decide to use an older pytorch version that doesn't support it
>
> Got it. Is it possible to support both here? Even if we bumped the version on our CI, it would still break backwards compatibility for existing users. We could document the perf benefits of the newer version. It might also be worth further cluttering the output by issuing a (one-time) warning for this. What do you think?

yep that sounds reasonable. will make that change soon
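
For illustration, a backwards-compatible wrapper along the lines discussed above might look like the sketch below. This is a hedged sketch of the approach, not the code that landed; the function name and warning text are hypothetical.

import warnings
import torch
import torch.distributed as dist

_warned_legacy_allgather = False

def all_gather_into_flat(flat_out, flat_in, group=None):
    """Gather flat_in from every rank into the preallocated flat_out buffer.

    Uses _all_gather_base when the installed torch provides it; otherwise falls
    back to plain all_gather into chunked views of flat_out, emitting a one-time
    warning about the slower legacy path.
    """
    global _warned_legacy_allgather
    if hasattr(dist, "_all_gather_base"):
        return dist._all_gather_base(flat_out, flat_in, group=group)

    if not _warned_legacy_allgather:
        warnings.warn("torch.distributed._all_gather_base is unavailable in this "
                      "torch version; falling back to all_gather, which is slower")
        _warned_legacy_allgather = True

    world_size = dist.get_world_size(group)
    # all_gather writes into a list of per-rank tensors; chunking the flat
    # output buffer gives contiguous views, so no extra copy is needed afterwards.
    output_chunks = list(torch.chunk(flat_out, world_size))
    return dist.all_gather(output_chunks, flat_in, group=group)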

@jfc4050 (Contributor, author) commented Jan 12, 2022

> hey @tjruwase, any progress on reviews?
>
> Happy New Year! Apologies for the delay, but the review is in progress. I think I should finish today. I am leaving comments/questions as I go along stage3.py, which is the last file.
>
> In the meantime, can you please address the comments from @stas00: #1453 (comment)?
>
> Thanks!

Oops that was rude of me :) happy new year! thanks for all the help so far, i know this is a large PR.

i'll start working my way through yours and @stas00's comments

@jfc4050 (Contributor, author) commented Jan 12, 2022

> I will repeat my admiration of this incredible engineering work and contribution. Thanks so much.
>
> Overall, the code looks good to me. I left a few minor comments and there are the two comments by @stas00 to take care of for the final approval.

Thanks! that's very nice of you to say. DeepSpeed has been hugely useful for us, glad to be able to give back.

Working my way through the comments now. as we discussed earlier, after that will do a final performance test to make sure we haven't lost any gains during the PR.

@jfc4050 (Contributor, author) commented Jan 19, 2022

@tjruwase @stas00 i think the outstanding comments have been addressed now, lmk if there are any other changes we need. otherwise, I'll start trying to get some hardware to rerun the performance tests

@stas00 (Collaborator) commented Jan 19, 2022

@jfc4050, all my requests have been addressed - thank you! I have already updated the HF/DS integration (huggingface/transformers#14569) and all the many tests pass. All my tests are functional though; none are qualitative, as I rely on deepspeed to have those.

@tjruwase (Contributor)

@jfc4050, I will merge Friday at the latest, just to give some space to the PR-MoE release.

@jfc4050 (Contributor, author) commented Jan 20, 2022

@tjruwase here are the results from the rerun. they actually got a little better; i'm not sure if it's because i've upgraded pytorch/HF or something else - i forgot to record the exact library versions i used to get the performance results reported originally.

on 8 p4d cluster (64 A100 GPUs)
Roberta-10B: 59.23 -> 111.31 TFLOPS/GPU
Roberta-50B: 71.38 -> 127.85 TFLOPS/GPU


also i took a look at the failing tests.

FAILED unit/test_zero_tiled.py::test_tiled_backward[23-29-2-2-False]
FAILED unit/test_zero_tiled.py::test_tiled_returnbias_backward[23-29-2-2-False]
FAILED unit/test_zero_tiled.py::test_tiled_forward[23-29-2-2-False]

looks like they are for the TiledLinear module, and the outputs are off by a small amount - maybe a numerical issue/flaky test. in any case i don't think they are related to any changes made here

@tjruwase (Contributor)

@jfc4050, thanks for sharing the updated perf results. They look great.

Yes, those unit tests are incredibly flaky. Not to worry about them. Everything looks good for merge Friday.

@tjruwase tjruwase merged commit 4912e0a into microsoft:master Jan 21, 2022
@stas00 (Collaborator) commented Jan 21, 2022

Wow! Amazing work all who participated!

Please let me know when a new DS release is made and we will roll out the HF/DS integration of bf16 in HF Transformers.

@@ -116,7 +116,8 @@

# gathers params for saving a model - inefficient but is required in certain situations

Could someone kindly explain in which situations this (inefficient) 16-bit parameter gathering feature is required, given that the zero_to_fp32.py script can also save the parameters? Thanks.

Collaborator

This is for those who can't be bothered with running zero_to_fp32.py and want the 16-bit model extracted on the fly - which is fine for tiny to small models but very slow for large models.

It's also the default in the HF Trainer integration of Deepspeed to make it easy for users to start and have things work transparently. But the documentation explains how to improve upon this default.
https://huggingface.co/docs/transformers/main/main_classes/deepspeed#getting-the-model-weights-out
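
For readers landing here later, a minimal sketch of the offline alternative being referenced, assuming the helper exposed by deepspeed.utils.zero_to_fp32 (the checkpoint path below is a placeholder):

import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

checkpoint_dir = "output/checkpoint-100"  # placeholder: a ZeRO-3 checkpoint directory
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir)
torch.save(state_dict, "pytorch_model_fp32.bin")

# Roughly equivalent CLI, run from inside the checkpoint directory:
#   python zero_to_fp32.py . pytorch_model_fp32.bin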
