Basic Llama2 Tuning #39
Conversation
This will then be tweaked for TPU/XLA. The original transformers version is 4.40.0, commit 745bbfe.
This should reduce memory consumption with minimal performance loss.
Imported as-is, from version 4.39.3.
For RowParallelLinear and ColumnParallelLinear, use Linear instead of the dedicated classes, to avoid issues with the backward step. A minimal sketch of the substitution follows (the helper name and projection attributes are illustrative, not the PR's actual code):
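```python
import torch.nn as nn

def make_linear(in_features: int, out_features: int, bias: bool = False) -> nn.Linear:
    # Plain nn.Linear in place of RowParallelLinear / ColumnParallelLinear:
    # sharding is handled by the SPMD annotations applied elsewhere, so the
    # backward step goes through the standard autograd path.
    return nn.Linear(in_features, out_features, bias=bias)

# Illustrative usage inside an attention block:
# self.q_proj = make_linear(hidden_size, num_heads * head_dim)
# self.o_proj = make_linear(num_heads * head_dim, hidden_size)
```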
For now it does not seem to work; the issue appears to be that the shapes of the key states are not compatible with the attention calculation after the update. We can investigate the reason and propose a better solution later.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Is there any reason why you chose to use 2D sharding instead of FSDPv2? The latter is integrated into the transformers Trainer: https://huggingface.co/blog/gemma-peft#accelerate-with-fsdp-via-spmd-on-tpu
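For reference, a minimal sketch of the FSDPv2-via-SPMD path described in the linked post, with Llama's decoder layer swapped in for Gemma's; the config keys and values below are assumptions and may differ across transformers versions:

```python
from transformers import Trainer, TrainingArguments

# FSDP-via-SPMD configuration, following the linked blog post; keys are
# assumed to match the transformers version in use.
fsdp_config = {
    "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
    "xla": True,
    "xla_fsdp_v2": True,
    "xla_fsdp_grad_ckpt": True,
}

args = TrainingArguments(
    output_dir="./llama-fsdpv2",     # illustrative path
    per_device_train_batch_size=8,   # illustrative value
    fsdp="full_shard",
    fsdp_config=fsdp_config,
)

# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
# trainer.train()
```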
@alanwaketan yes, the reason is that it did not work out of the box: the trainer was apparently trying to use an XLA API that was experimental and is no longer available, so the code raised an exception. I guess we can fix FSDP in transformers too and compare both to see which one performs better.
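For context, a rough sketch of a 2D (data, model) SPMD mesh with torch_xla; the mesh shape, axis names, and sharding spec below are illustrative and not necessarily what this PR uses:

```python
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

num_devices = xr.global_runtime_device_count()
model_axis = 4                                        # illustrative split
mesh_shape = (num_devices // model_axis, model_axis)
mesh = xs.Mesh(np.arange(num_devices), mesh_shape, ("data", "model"))

# Shard a linear layer's weight along the 'model' axis (illustrative):
# xs.mark_sharding(layer.weight, mesh, ("model", None))
```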
Can you point me to the error? All the APIs the trainer uses should be available in 2.3.
FSDPv2's intention is to relieve regular users of the need for all these complicated shardings, unless you know you need 2D sharding instead of 1D.
@alanwaketan Sure, it makes total sense. I will get back to you with the error.
@alanwaketan Ah, by the way, I was still using torch_xla 2.2; we haven't moved to 2.3 yet. I will do that so I can re-test FSDP.
@alanwaketan you can check the errors I see with FSDP here.
This part uses the multihost environment, tested on v5e litepod16.
I raised an issue on the
2.2 is not expected to work. The 2.3 issue seems more generic than FSDPv2.
LGTM! 🔥
What does this PR do?
Add support for Llama model tuning on TPU v5e, with a related example.
Before submitting