Update unwrap from accelerate #29933
Conversation
src/transformers/modeling_utils.py
Outdated
@@ -2306,7 +2306,7 @@ def save_pretrained(
         files_timestamps = self._get_files_timestamps(save_directory)

         # Only save the model itself if we are using distributed training
-        model_to_save = unwrap_model(self)
+        model_to_save = unwrap_model(self) if is_accelerate_available() else Accelerator().unwrap_model(self)
Not sure what the best option is here. I don't think we want to force users to install accelerate just to save a model. If they are saving after training through Trainer or accelerate, they will already have accelerate installed.
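To make the concern concrete, here is a minimal sketch (not code from this PR; `get_model_to_save` is a made-up helper standing in for the line inside `save_pretrained`, with `model` playing the role of `self`) of a save path that only touches accelerate when it is actually installed:

```python
# Hypothetical sketch: prefer accelerate's unwrap when available, but never
# require accelerate just to save a model.
from torch import nn

from transformers.modeling_utils import unwrap_model
from transformers.utils import is_accelerate_available


def get_model_to_save(model: nn.Module) -> nn.Module:
    if is_accelerate_available():
        from accelerate import Accelerator

        return Accelerator().unwrap_model(model)  # accelerate code path
    # Plain recursive .module unwrap; no accelerate needed to save.
    return unwrap_model(model)
```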
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
My main concern is #29780, which looks to expand it.

That being said, everything here revolves around torch, and transformers with PyTorch requires Accelerate, so I think it's fair to assume accelerate is available in the env. @ArthurZucker let me know if you disagree with this.

If we agree, then I propose fully removing the current implementation and relying solely on the one in Accelerate. What's in the mentioned PR can then also be offloaded there, since the behaviors of the two already differ, and with that PR they will differ even more.

The other alternative is to add a test in test_trainer that verifies the same behavior between Accelerate's and transformers' model unwraps, so we can flag when they drift apart (see the sketch below).
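A minimal sketch of what such a parity test could look like (this is not the test that ended up in the repo; `test_unwrap_matches_accelerate` is a made-up name, and the wrapper used is simply `nn.DataParallel`):

```python
import torch.nn as nn

from accelerate.utils import extract_model_from_parallel
from transformers.modeling_utils import unwrap_model


def test_unwrap_matches_accelerate():
    base = nn.Linear(4, 4)
    wrapped = nn.DataParallel(base)

    # Both unwraps should recover the exact same underlying module.
    assert unwrap_model(wrapped) is base
    assert extract_model_from_parallel(wrapped) is unwrap_model(wrapped)
```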
src/transformers/modeling_utils.py
Outdated
@@ -105,7 +105,7 @@
 XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()

 if is_accelerate_available():
-    from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
+    from accelerate import Accelerator, dispatch_model, infer_auto_device_map, init_empty_weights
Let's just use extract_model_from_parallel instead of going through the Accelerator, since that's all it's calling.
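For reference, a hedged sketch of that suggestion (the actual diff may have ended up different): call the utility directly instead of building an `Accelerator` just to reach its `unwrap_model`.

```python
import torch.nn as nn

from accelerate.utils import extract_model_from_parallel

# Accelerator().unwrap_model(...) is a thin wrapper around this utility,
# so it can be called directly on the wrapped model.
wrapped = nn.DataParallel(nn.Linear(4, 4))
model_to_save = extract_model_from_parallel(wrapped)
```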
Yeah, I thought about that too. Maybe we can put that in unwrap_model so that it is easier to understand. However, we need to add a test to make sure that we have the same behavior.
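A rough sketch of what "putting it in unwrap_model" could look like (illustrative only; the merged implementation may differ, and the fallback branch is the existing recursive `.module` unwrap):

```python
from torch import nn

from transformers.utils import is_accelerate_available


def unwrap_model(model: nn.Module) -> nn.Module:
    """Recursively unwrap a model from distributed / parallel containers."""
    if is_accelerate_available():
        from accelerate.utils import extract_model_from_parallel

        # Delegate to accelerate so both libraries unwrap the same way.
        return extract_model_from_parallel(model)
    # Fallback without accelerate: peel off .module wrappers recursively.
    if hasattr(model, "module"):
        return unwrap_model(model.module)
    return model
```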
Got word from tf-boi that generally… @amyeroberts @ArthurZucker, what do you say? :)
Current state looks good. Yes, torch requires accelerate, but let's keep unwrap anyway; that sounds simpler.
Hi @muellerzr @SunMarc, will this PR be merged? We need this.
I will finish this PR ASAP @zorrofox! @ArthurZucker, do you want to switch back to…
Co-authored-by: Zach Mueller <muellerzr@gmail.com>
This looks great to me! (after quality 😉)
Thanks! LGTM
* Use unwrap with the one in accelerate
* oups
* update unwrap
* fix
* wording
* raise error instead
* comment
* doc
* Update src/transformers/modeling_utils.py

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* style
* put else

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
What does this PR do?
This PR updates the unwrap function to use the one from accelerate instead.
Fixes an issue reported by @abhishekkrthakur.
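As a quick end-to-end sanity check of the behavior this PR keeps (model name and output path are placeholders, and the snippet assumes both torch and accelerate are installed):

```python
import torch.nn as nn

from transformers import AutoModel
from transformers.modeling_utils import unwrap_model

model = AutoModel.from_pretrained("bert-base-uncased")
wrapped = nn.DataParallel(model)

# unwrap_model (backed by accelerate's extract_model_from_parallel when
# accelerate is installed) recovers the plain model, which is what gets saved.
assert unwrap_model(wrapped) is model
unwrap_model(wrapped).save_pretrained("/tmp/unwrapped-checkpoint")
```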