[PEFT] Fix save_pretrained to make sure adapter weights are also saved on TPU #29388
Conversation
Fix for saving adapter weights when using PEFT
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks 🤗
src/transformers/trainer.py (Outdated)
@@ -3035,9 +3035,10 @@ def _save_tpu(self, output_dir: Optional[str] = None):
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
I think anything inheriting from PushToHubMixin falls in the classes that support save_pretrained and from_pretrained, so we could also use that here, since both should inherit from it.
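A rough sketch of that suggestion (not the exact upstream diff): PreTrainedModel and PeftModel both inherit from PushToHubMixin, so checking against the mixin covers either class without an explicit is_peft_available() branch.

from transformers.utils import PushToHubMixin

# Both PreTrainedModel and PeftModel subclass PushToHubMixin, so a single
# isinstance check replaces the PEFT-availability branch.
supported_classes = (PushToHubMixin,)

def supports_save_pretrained(model) -> bool:
    # Anything exposing save_pretrained()/from_pretrained() through the mixin.
    return isinstance(model, supported_classes)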
My fix was inspired by the code here:
transformers/src/transformers/trainer.py, line 3073 in 0cb946d:
supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
Great idea @ArthurZucker, pushed it.
Hello there, what is the state of this PR? Is there a timeline for when it will be merged and when a new release of transformers will be out?
Just waiting for @shub-kris to come back (he is off), and the transformers release will be in around 2 weeks.
I ran some tests on a GKE cluster with TPU v4 with 4 nodes: https://gist.github.com/moficodes/1492228c80a3c08747a973b519cc7cda
This run fails with:
Traceback (most recent call last):
File "/usr/local/lib/python3.10/site-packages/safetensors/torch.py", line 13, in storage_ptr
return tensor.untyped_storage().data_ptr()
RuntimeError: Attempted to access the data pointer on an invalid python storage.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "//fsdp.py", line 112, in <module>
model.save_pretrained(new_model_id)
File "/usr/local/lib/python3.10/site-packages/transformers/modeling_utils.py", line 2448, in save_pretrained
safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
File "/usr/local/lib/python3.10/site-packages/safetensors/torch.py", line 281, in save_file
serialize_file(_flatten(tensors), filename, metadata=metadata)
File "/usr/local/lib/python3.10/site-packages/safetensors/torch.py", line 470, in _flatten
shared_pointers = _find_shared_tensors(tensors)
File "/usr/local/lib/python3.10/site-packages/safetensors/torch.py", line 72, in _find_shared_tensors
if v.device != torch.device("meta") and storage_ptr(v) != 0 and storage_size(v) != 0:
File "/usr/local/lib/python3.10/site-packages/safetensors/torch.py", line 17, in storage_ptr
return tensor.storage().data_ptr()
File "/usr/local/lib/python3.10/site-packages/torch/storage.py", line 956, in data_ptr
return self._data_ptr()
File "/usr/local/lib/python3.10/site-packages/torch/storage.py", line 960, in _data_ptr
return self._untyped_storage.data_ptr()
RuntimeError: Attempted to access the data pointer on an invalid python storage.
That looks like the original error, so I am not certain the cause of the error was resolved.
Hi @moficodes, thanks for flagging this error, but at an initial glance it doesn't look like the problem that this PR addresses. This PR aims to save the adapter weights, which were not being saved before. So, if you had used the Trainer with this change, it would save the adapter weights too.
Earlier:
export XLA_USE_BF16=1 PJRT_DEVICE=TPU XLA_USE_SPMD=1 HF_TOKEN=<your-HF-TOKEN>
python save-gemma.py
So, it might be that the error you are encountering is unrelated to what this PR tries to fix.
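For illustration, a minimal sketch of saving a PEFT-wrapped model through the Trainer (hypothetical model name and output paths): with this change, the adapter files (adapter_config.json and the adapter weights) are written alongside the checkpoint, not only the base model's config.

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

base = AutoModelForCausalLM.from_pretrained("gpt2")  # small stand-in model
peft_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))

trainer = Trainer(model=peft_model, args=TrainingArguments(output_dir="out"))
trainer.save_model("out/checkpoint")  # on TPU this path goes through _save_tpu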
I see. Will open a separate issue for it then. Thank you!
The error happens on the same line though.
@LysandreJik can we merge this if it looks good to you? @ArthurZucker is on holiday, and I made the changes he asked for and have tested them locally.
@moficodes answered it here: #29608 (comment)
@shub-kris Based on reviews and code, we can merge. There's currently a failing test which needs to be resolved first. Could you try rebasing on main to make sure you have all the latest updates, and trigger a fresh CI run?
@amyeroberts Thank you for looking into the PR. I have rebased, but some checks are still failing, I guess because of this: https://github.com/huggingface/transformers/runs/22627188153
@shub-kris Yep - a fix has just been merged into main. Apologies for the disruption. Could you try rebasing again?
Thanks a lot @amyeroberts
… saved on TPU (#29388)
* Fix for saving adapter weights when using PEFT
* Change supported classes to PushToHubMixin
Bug Fix for saving adapter weights when using PEFT
What does this PR do?
This PR fixes saving adapter weights when using PEFT on TPUs. Currently only the model weights are being saved and not the adapter weights.
I tested this change locally on this script, and it now also saves the adapter weight files whenever checkpointing.
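A simplified sketch of the saving logic this PR adjusts (not the exact upstream Trainer._save_tpu, and the helper name is made up): treat PeftModel like PreTrainedModel so checkpoints are written with save_pretrained(), which includes the adapter weights, instead of falling back to a raw state-dict dump.

import os
import torch
from transformers import PreTrainedModel
from transformers.utils import is_peft_available

def save_tpu_checkpoint(model, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    supported_classes = (PreTrainedModel,)
    if is_peft_available():
        from peft import PeftModel
        supported_classes = (PreTrainedModel, PeftModel)

    if isinstance(model, supported_classes):
        # PeftModel.save_pretrained() writes adapter_config.json plus the
        # adapter weights; PreTrainedModel writes the full model weights.
        model.save_pretrained(output_dir)
    else:
        # Fall back to a plain state-dict dump for unsupported wrappers.
        torch.save(model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))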
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Earlier discussed here
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.