Enable distributed LoRA training #821
Conversation
Is it possible to do distributed inference as well?
Possible, yes, but getting a nice speedup is more challenging. That's something we're looking at, but don't have an ETA on right now.
@awni feel free to review and then we can merge. I split the launcher to a different branch.
loss=mock_default_loss,
iterate_batches=mock_iterate_batches,
)
with swapped_with_identity(mx.distributed, "all_sum"):
Just curious, why do we need this for the test to work?
Nvm, I see the MagicMock is messing up the all_sum.
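For reference, a helper along these lines is one way to do this (a minimal sketch, assuming `swapped_with_identity` simply swaps a module attribute for an identity function; the actual test utility in the PR may differ). It keeps the mocked training loop from routing through a MagicMock-ed `all_sum`:

```python
# Hypothetical sketch of a swapped_with_identity helper: temporarily replace
# module.<name> (e.g. mx.distributed.all_sum) with an identity function so a
# MagicMock elsewhere in the test doesn't break the reduction call.
from contextlib import contextmanager

@contextmanager
def swapped_with_identity(module, name):
    original = getattr(module, name)
    setattr(module, name, lambda x, *args, **kwargs: x)  # identity: pass input through
    try:
        yield
    finally:
        setattr(module, name, original)  # restore the real implementation
```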
f"Val loss {val_loss:.3f}, " | ||
f"Val took {val_time:.3f}s", | ||
flush=True, | ||
) | ||
|
||
if training_callback is not None: |
Probably that should go under the rank==0 condition as well.
Yeah ok, it makes sense. I was thinking that in general callbacks should always run, and if a callback is about reporting it can choose to only run on rank=0. But our callbacks here are only about reporting, so it makes sense to just run them on node 0.
It's a good point actually. It's more flexible that way. I'm on board leaving it to the user to specify the rank.
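For context, the rank-gating pattern being discussed looks roughly like this (a sketch with illustrative names, not the exact trainer code):

```python
# Sketch: only rank 0 reports to avoid duplicated output across nodes,
# while every rank still participates in training. Names are illustrative.
import mlx.core as mx

world = mx.distributed.init()
rank = world.rank()

def maybe_report(training_callback, info):
    # The suggestion above: gate the reporting callback on rank == 0,
    # or alternatively leave the rank check to the callback itself.
    if training_callback is not None and rank == 0:
        training_callback.on_val_loss_report(info)
```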
f"Trained Tokens {trained_tokens}, " | ||
f"Peak mem {peak_mem:.3f} GB", | ||
flush=True, | ||
) | ||
|
||
if training_callback is not None: |
Same there.
Looks great!! Let's 🚢
This works perfectly! Great job 👏
The updates to LORA.md are missing, but TL;DR: we can now do … to train across two nodes (or more; really nothing needs to change).