Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature] Add support for OffloadModel to enable training large models on 1 GPU. #432

Merged
merged 57 commits into from
Feb 26, 2021

Conversation

anj-s
Copy link
Contributor

@anj-s anj-s commented Feb 24, 2021

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
  • [ X] Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • [X ] Did you write any new necessary tests?

What does this PR do?

Add experimental support for using the OffloadModel API which enables training large models on a single GPU. OffloadModel chunks the given model into a list of modules and copies a given chunk from CPU->GPU during the FW pass. After FW computation the chunk is copied back to the CPU. The process is repeated for the BW pass. The current implementations supports:

  • Specifying number of slices that you want to chunk your model into.
  • Support for activation checkpointing.
  • Support for running multiple microbatches at one time to offset latency due to multiple param copies fro CPU<->GPU.

Caveats:

  • This initial implementation only supports nn.Sequential models.
  • The throughput of the model is smaller than when running without Offload. We will be continue to work on improving performance and suggest configurations that will enable the highest throughput.

References:

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

blefaudeux and others added 30 commits December 29, 2020 16:57
…s 2 --optim_type oss_offload_ddp --batch_size=32 --model vit_large_patch16_224
* initial fwd/bwd commit

* checkpoint work

* modify shard loop

* activation offloading and test to start with

* fix lint errors

* update comments

* fix lint

* remove unused var

* remove commented out lines

* modify name

* remove break

* remove profiler comments

* avoid saving inputs

* fix lint errors

Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>
* initial fwd/bwd commit

* checkpoint work

* modify shard loop

* activation offloading and test to start with

* fix lint errors

* update comments

* fix lint

* remove unused var

* remove commented out lines

* modify name

* remove break

* remove profiler comments

* add support for fp16

* add unit tests

* fix lint errors

* fix test failure

Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>
)

* initial fwd/bwd commit

* checkpoint work

* modify shard loop

* activation offloading and test to start with

* fix lint errors

* update comments

* fix lint

* remove unused var

* remove commented out lines

* modify name

* remove break

* remove profiler comments

* add support for fp16

* add unit tests

* fix lint errors

* fix test failure

* cp work, incorrect output dimensions still need to be fixed

* fixed activation outputs

* intermediate cp of work

* add tests

* fix lint errors

Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 24, 2021
@min-xu-ai
Copy link
Contributor

quick question on the test file location. Should it be

tests/nn/experimental/test_offload.py

or

tests/experimental/nn/test_offload.py?

I think we mirror the dirs. File names can be shorten, like we have test_fsdp*.py but all in the same mirrored dir. That seems like a good convention?

@min-xu-ai
Copy link
Contributor

also, see this comment: Lightning-AI/pytorch-lightning#6152 (comment)

@anj-s
Copy link
Contributor Author

anj-s commented Feb 25, 2021

quick question on the test file location. Should it be

tests/nn/experimental/test_offload.py

or

tests/experimental/nn/test_offload.py?

I think we mirror the dirs. File names can be shorten, like we have test_fsdp*.py but all in the same mirrored dir. That seems like a good convention?

I agree. I want it to be in experimental/ just like I moved tests for ampnet.


def __init__(
self,
model_cpu: nn.Sequential, # hard pre-requisite for now, easier model slicing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

discussing elsewhere, but I think that the FSDP way (wrap submodules) could apply here, and why not keeping the two options (either one monolithic nn.Sequential call, or a per-module wrap) open ? I think that it adds a lot of flexibility and could be good enough in practice

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

practically speaking this means that https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=forward%20hook#torch.nn.Module.register_forward_pre_hook can be used, but the latency will be pretty terrible if used "naively" (wait for the FW wavefront to touch, pull in the module), so it's not really a silver bullet

@blefaudeux
Copy link
Contributor

Thanks for the great PR @anj-s, it's super comprehensive ! I think that we can try to make it more generic over time, it does not have to be perfect right now and it's a very solid basis I believe. Minor nits if you don't mind and curious to have @min-xu-ai eyes on that

Copy link
Contributor

@min-xu-ai min-xu-ai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for late to the party. I agree with Ben that this gives us a good start. Lots of interesting things we can do potentially with this.

benchmarks/experimental/offload.py Show resolved Hide resolved
fairscale/experimental/nn/offload.py Show resolved Hide resolved
fairscale/experimental/nn/offload.py Outdated Show resolved Hide resolved
fairscale/experimental/nn/offload.py Show resolved Hide resolved
@anj-s anj-s merged commit f7813d6 into master Feb 26, 2021
@anj-s anj-s deleted the offload_experimental branch February 26, 2021 01:09
@ibro45 ibro45 mentioned this pull request Mar 28, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants