-
Notifications
You must be signed in to change notification settings - Fork 212
Add finetuning strategies for DeepSpeed #1377
Add finetuning strategies for DeepSpeed #1377
Conversation
Codecov Report
@@ Coverage Diff @@
## master #1377 +/- ##
==========================================
- Coverage 92.90% 92.88% -0.02%
==========================================
Files 286 286
Lines 12874 12891 +17
==========================================
+ Hits 11960 11974 +14
- Misses 914 917 +3
Flags with carried forward coverage won't be shown. Click here to find out more.
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. |
I don't know how to fix |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, @ar90n - Thank you so much for working on this PR. I really appreciate that you added the documentation and relevant tests. 🎉
LGTM (just added a couple of minor suggestions)!
Regarding the Example
failures, please don't worry about them. Though the CI has been fixed now, so it should be all good.
@pytest.mark.parametrize( | ||
"strategy_key, strategy_metadata", | ||
[ | ||
("no_freeze", None), | ||
("freeze", None), | ||
("freeze_unfreeze", 2), | ||
("unfreeze_milestones", ((5, 10), 15)), | ||
], | ||
) | ||
def test_deepspeed_finetuning_strategy_key(strategy_key, strategy_metadata): | ||
deepspeed_strategy_key = f"{strategy_key}_deepspeed" | ||
|
||
strategy = _FINETUNING_STRATEGIES_REGISTRY.get(key=strategy_key)(strategy_metadata=strategy_metadata).strategy | ||
deepspeed_strategy = _FINETUNING_STRATEGIES_REGISTRY.get(key=deepspeed_strategy_key)( | ||
strategy_metadata=strategy_metadata | ||
).strategy | ||
assert strategy == deepspeed_strategy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really nice tests! Thanks for adding them.
Co-authored-by: Kushashwa Ravi Shrimali <kushashwaravishrimali@gmail.com>
Co-authored-by: Kushashwa Ravi Shrimali <kushashwaravishrimali@gmail.com>
for more information, see https://pre-commit.ci
Hi @krshrimali |
I checked the reason for
I'm not familiar with them. It seems that the test process was terminated by its timeout. |
Hi, @ar90n - Please don't worry about it. I'll have to check once on my personal GPU, sometimes these failures can be flaky (because of resources not being available or anything else). I'll merge it once I'm done testing, but this is good to go! 🎉 |
The test passes locally on my GPU, let's merge this and monitor the CI. In case an alarm is raised, I'll attempt to fix it. Thanks, @ar90n for your hard-work and patience with this PR. 🎉 |
Just added the CHANGELOG entry, let's wait for the CI, and push it ASAP. <3 |
@ar90n - FYI, it took us some time to fix the CI, sorry for that. @ethanwharris is currently OOO for this week, so whenever he is back, he'll help merge this. 🎉 Thank you for your contribution, and patience. |
What does this PR do?
This PR provides some workarounds to use DeepSpeed in finetuning. In fact, DeepSpeed cannot work with pytorch-lightning completely because its parameter loading and storing don't work. So this PR added some fine-tuning strategies whose parameter loading and storing are omitted.
Fixes #1249
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃