MPS support #1706
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1706
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 5d45d3c with merge base 4efd7fd.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -325,6 +321,11 @@ def generate(

    # keeps track at a high level if we've already hit a stop token in a sequence so we can early stop
    stop_token_reached = torch.zeros(bsz, dtype=torch.bool, device=prompt.device)
    stop_tokens = (
Any reason?
see below
@@ -325,6 +321,11 @@ def generate(

    # keeps track at a high level if we've already hit a stop token in a sequence so we can early stop
    stop_token_reached = torch.zeros(bsz, dtype=torch.bool, device=prompt.device)
    stop_tokens = (
        torch.tensor(stop_tokens, device=prompt.device, dtype=tokens.dtype)
Not sure why, but outputs from the model on my device were int32. The prompt is int64, and `torch.isin` was complaining when `stop_tokens.dtype != tokens.dtype` later on. I wasn't ready for further investigation.
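For concreteness, here's a minimal sketch of the workaround in the diff above (the shapes and token ids are made up for illustration): building `stop_tokens` with the same dtype as the generated tokens avoids the mismatch that `torch.isin` rejects on MPS.

```python
import torch

# Pretend the model's sampled tokens came back int32 (as observed on MPS),
# while the stop-token ids would otherwise default to int64.
tokens = torch.randint(0, 32000, (1, 8), dtype=torch.int32)

# Cast the stop tokens to tokens.dtype, as the diff does, so torch.isin
# sees matching dtypes on every backend.
stop_ids = [2, 31999]
stop_tokens = torch.tensor(stop_ids, device=tokens.device, dtype=tokens.dtype)
stop_token_reached = torch.isin(tokens, stop_tokens).any(dim=-1)
print(stop_token_reached)  # bool tensor of shape (bsz,)
```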
carry on
```python
import torch

for dvc in ["cpu", "cuda"]:
    print(dvc)
    device = torch.device(dvc)
    x = torch.arange(10, device=device, dtype=torch.int64)
    y = torch.arange(10, device=device, dtype=torch.int32)
    # mixed int64/int32 inputs are accepted on CPU and CUDA
    print(torch.isin(x, y))
```
So this is fine on CPU/CUDA, BUT the mismatched-dtype error seems like expected behaviour on MPS? There's actually a test for this in core:
https://github.com/pytorch/pytorch/blob/e5a57932f06746a7f6779a97b4e11394b614fa81/test/test_mps.py#L8512
Yeah, there's a check for it in the MPS implementation: https://github.com/pytorch/pytorch/blob/e5a57932f06746a7f6779a97b4e11394b614fa81/aten/src/ATen/native/mps/operations/TensorCompare.mm#L297
Maybe because the underlying MPS call doesn't support different datatypes?
https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph/equal(_:_:name:)?language=objc
This is probably it, and there's NO WAY I'm loading up Xcode to try it out. I hope you are contented.
https://xkcd.com/356/
CONTRIBUTING.md (Outdated)
> [!NOTE]
> For contributors who are developing on MPS, you may find that several of our tests fail due to precision differences falling outside of numerical tolerances. As of 27/09/2024, the following tests will fail for this reason:

<details>
<summary>Failed Tests</summary>
Commented in Discord, but why not just (a) add a new method in test_utils.py (`is_mps` or something like that), (b) decorate each of these test cases with it, and (c) point to an issue tracking all of them in case someone wants to help debug any individual one? For (a) we can also just put it in _device.py, no strong preference here.
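A minimal sketch of that suggestion, assuming pytest and that the helper lives next to the other test utilities (`is_mps`, `skip_if_mps`, and the test below are placeholder names, not torchtune's actual API):

```python
import pytest
import torch

def is_mps() -> bool:
    """Return True when tests would run on an Apple MPS backend."""
    return torch.backends.mps.is_available()

# One reusable marker, pointing at a single tracking issue instead of
# listing every affected test in CONTRIBUTING.md.
skip_if_mps = pytest.mark.skipif(
    is_mps(),
    reason="Known precision differences on MPS, see the tracking issue",
)

@skip_if_mps
def test_example_precision_sensitive():
    x = torch.ones(4, 4)
    assert torch.allclose(x.softmax(dim=-1), torch.full((4, 4), 0.25))
```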
k_out[:, :, cache_pos] = k_val
v_out[:, :, cache_pos] = v_val
Does this create an extra copy or anything? (Not a blocker, just curious)
I'm not 100% sure. I drew inspiration from here https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/cache_utils.py#L1168
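For what it's worth, a small self-contained comparison of the two in-place update styles in question (shapes made up for illustration). Both write into `k_out` without reallocating it; the linked transformers code falls back to the indexing form because `index_copy_` has reportedly been unimplemented on MPS.

```python
import torch

bsz, n_heads, max_seq_len, head_dim = 1, 4, 16, 8
k_out = torch.zeros(bsz, n_heads, max_seq_len, head_dim)
k_val = torch.randn(bsz, n_heads, 2, head_dim)
cache_pos = torch.tensor([3, 4])

# Advanced-indexing assignment, as in the diff above: scatters k_val into
# the cache along the sequence dimension, in place.
k_out[:, :, cache_pos] = k_val

# Equivalent explicit op on dim 2; same result on CPU/CUDA, but reportedly
# raises NotImplementedError on MPS builds where aten::index_copy is missing.
k_out.index_copy_(2, cache_pos, k_val)
```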
Would be nice to have an actual gate on those failing unit tests with a pointer to a proper issue for now (otherwise it bloats up our CONTRIBUTING.md a lot). Other than that, if you've run the generation recipe (yes, just that one and not every single recipe), this looks good to me.
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #1706      +/-   ##
==========================================
- Coverage   70.67%   67.61%    -3.06%
==========================================
  Files         299      304        +5
  Lines       15251    15610      +359
==========================================
- Hits        10778    10555      -223
- Misses       4473     5055      +582

Flags with carried forward coverage won't be shown.
☔ View full report in Codecov by Sentry.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It’s beautiful
This reverts commit 3fddc56.
Context
What is the purpose of this PR? It adds a new feature.
Closes #1215.
Changelog
Added MPS support. Tested with `generate_v2`, `eleuther_eval`, and `lora_finetune_single_device`.
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

- run pre-commit hooks and linters (make sure you've first installed via `pre-commit install`)
- run unit tests via `pytest tests`
- run recipe tests via `pytest tests -m integration_test`
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it. Here is a docstring example and a tutorial example.