
[TPU] XLA changes for finetuning #110

Merged: 17 commits merged on Jun 11, 2023
Conversation

@gkroiz (Contributor) commented on Jun 5, 2023

This PR adds changes specific to finetuning on TPUs.
I used a TPU v4-8 with

log_interval = 1
devices = 4
batch_size = 64 / devices
micro_batch_size = 4
gradient_accumulation_steps = 4
num_epochs = 5

for stablelm-base-alpha-3b; the resulting finetuning time was ~4500 s.
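For reference, a minimal sketch of how these settings relate (the variable names follow the list above; the derived values are illustrative arithmetic, not the exact code in this PR):

```python
# Illustrative only; mirrors the settings listed above.
log_interval = 1
devices = 4                        # TPU v4-8: 4 chips visible to the host
batch_size = 64 // devices         # 64 // 4 = 16
micro_batch_size = 4
gradient_accumulation_steps = batch_size // micro_batch_size  # 16 // 4 = 4
num_epochs = 5
```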

@gkroiz (Contributor, Author) commented on Jun 5, 2023

Would it be simpler to have a separate file for TPU finetuning (for example, adapter_tpu.py)?

@gkroiz (Contributor, Author) commented on Jun 5, 2023

cc @Liyang90 for review

@gkroiz (Contributor, Author) commented on Jun 5, 2023

TODO: if we decide to keep these changes in adapter.py, we also need to make the same adjustments in adapter_v2.py

@carmocca (Contributor) left a review:

Some early questions!

3 resolved inline comments on finetune/adapter.py (outdated)
@gkroiz (Contributor, Author) commented on Jun 5, 2023

The current code in adapter.py won't work on TPUs with more than 8 cores (4 chips).

For example, say we want to run finetuning on a v4-64 (64 cores, 32 chips). We would need two device counts: one for the local devices within each of the 8 workers, and one for the total device count. The reasoning is that the XLA Fabric strategy needs devices=4 (the local device count), but batch_size, max_iters, and warmup_iters should be defined based on the total device count, which would be 32. For my local tests I made this distinction by defining local_devices and total_devices, but I'm not sure whether this leads to unnecessary confusion when using non-XLA strategies. What are your thoughts @carmocca?
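A minimal sketch of the distinction, assuming Lightning Fabric's TPU/XLA support; the names local_devices and total_devices mirror the comment above and are hypothetical, not the exact code in this PR:

```python
from lightning.fabric import Fabric


def main(fabric: Fabric) -> None:
    # fabric.world_size counts processes across *all* workers, so on a
    # v4-64 pod (8 workers x 4 local devices each) it would be 32.
    total_devices = fabric.world_size
    batch_size = 64 // total_devices
    # max_iters and warmup_iters would likewise be derived from total_devices.


if __name__ == "__main__":
    local_devices = 4  # devices visible to a single TPU worker
    fabric = Fabric(accelerator="tpu", devices=local_devices)
    fabric.launch(main)
```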

1 resolved inline comment on finetune/adapter.py (outdated)
@carmocca (Contributor) left a review:

I think making these distinctions is fine, because they still hold in the general multi-node case, not just with XLA.

2 resolved inline comments on finetune/adapter.py (outdated)
gkroiz and others added 2 commits June 5, 2023 20:09
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
@gkroiz (Contributor, Author) commented on Jun 7, 2023

@carmocca I noticed this PR is now closed. Do we no longer need the XLA changes?

@carmocca (Contributor) commented on Jun 7, 2023

Sorry, it was an accident!

@carmocca carmocca reopened this Jun 7, 2023
@gkroiz gkroiz marked this pull request as ready for review June 7, 2023 21:03
@gkroiz gkroiz requested a review from lantiga as a code owner June 7, 2023 21:03
@carmocca (Contributor) left a review:

Minor comment; the rest looks good.

1 resolved inline comment on finetune/adapter.py (outdated)
@carmocca carmocca merged commit 7236f51 into Lightning-AI:main Jun 11, 2023
@gkroiz gkroiz deleted the finetuning_tpu branch June 12, 2023 05:09
rasbt added a commit that referenced this pull request Jun 12, 2023