MPS support #1706

Merged: 5 commits merged into pytorch:main from the mps branch, Sep 28, 2024
Conversation

SalmanMohammadi (Collaborator)

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

closes #1215

Changelog

Added MPS support. Tested with the generate_v2, eleuther_eval, and lora_finetune_single_device recipes.
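
For orientation, device selection on Apple silicon in plain PyTorch looks roughly like the sketch below (a minimal sketch, assuming plain PyTorch only; torchtune's own device helper may behave differently):

import torch

# Prefer MPS when the backend is available, else CUDA, else CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

x = torch.ones(4, device=device)
print(device, x.device)  # on Apple silicon this should report the mps device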

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.

  • I did not change any public API
  • I have added an example to docs or docstrings

pytorch-bot bot commented Sep 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1706

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5d45d3c with merge base 4efd7fd:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Sep 27, 2024
@@ -325,6 +321,11 @@ def generate(

# keeps track at a high level if we've already hit a stop token in a sequence so we can early stop
stop_token_reached = torch.zeros(bsz, dtype=torch.bool, device=prompt.device)
stop_tokens = (
Contributor:

Any reason?

Collaborator Author:

see below

@@ -325,6 +321,11 @@ def generate(

# keeps track at a high level if we've already hit a stop token in a sequence so we can early stop
stop_token_reached = torch.zeros(bsz, dtype=torch.bool, device=prompt.device)
stop_tokens = (
torch.tensor(stop_tokens, device=prompt.device, dtype=tokens.dtype)
Collaborator Author:

Not sure why, but outputs from the model on my device were int32 while the prompt is int64, so torch.isin was complaining later on when stop_tokens.dtype != tokens.dtype. I wasn't ready for further investigation.
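
For reference, a minimal sketch of the dtype workaround the diff above applies (variable names are illustrative, not the recipe's actual code): cast the stop tokens to the generated tokens' dtype up front so torch.isin always sees matching dtypes.

import torch

tokens = torch.arange(5, dtype=torch.int32)              # stand-in for model output (int32 on MPS)
stop_tokens = torch.tensor([2, 4], dtype=tokens.dtype)   # cast to match the tokens' dtype
mask = torch.isin(tokens, stop_tokens)                   # no dtype mismatch on any backend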

Contributor:

carry on

Collaborator Author:

import torch

for dvc in ["cpu", "cuda"]:
    print(dvc)
    device = torch.device(dvc)  # use the loop variable (the original hardcoded "cuda")
    x = torch.arange(10, device=device, dtype=torch.int64)
    y = torch.arange(10, device=device, dtype=torch.int32)
    torch.isin(x, y)  # mixed int64/int32 inputs work on CPU and CUDA

So this is fine on CPU/CUDA, BUT it seems like expected behaviour on MPS? There's actually a test for this in core:
https://github.com/pytorch/pytorch/blob/e5a57932f06746a7f6779a97b4e11394b614fa81/test/test_mps.py#L8512


Collaborator Author:

Maybe because the underlying MPS call doesn't support different datatypes?
https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph/equal(_:_:name:)?language=objc

Collaborator Author:

This is probably it, and there's NO WAY I'm loading up Xcode to try it out. I hope you are contented.
https://xkcd.com/356/

CONTRIBUTING.md Outdated
Comment on lines 58 to 61
> [!NOTE]
> For contributors who are developing on MPS, you may find that several of our tests fail because precision differences fall outside of our numerical tolerances. As of 27/09/2024 the following tests will fail for this reason:
<details>
<summary>Failed Tests</summary>
@ebsmothers (Contributor) commented Sep 27, 2024:

Commented in Discord, but why not just (a) add a new method in test_utils.py (is_mps or something like that), (b) decorate each of these test cases with it, and (c) point to an issue tracking all of them in case someone wants to help debug any individual one? For (a) we can also just put in _device.py, no strong preference here
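
A minimal sketch of what that gate could look like (the is_mps name and the pytest-based marker follow the suggestion above; the tracking-issue reference is a placeholder, not from this PR):

import pytest
import torch

def is_mps() -> bool:
    # True when tests would run against the MPS (Apple silicon) backend.
    return torch.backends.mps.is_available()

# Decorate each known-failing test case with this marker, pointing the
# reason at a single tracking issue so individual failures can be debugged.
skip_on_mps = pytest.mark.skipif(
    is_mps(), reason="Known precision failures on MPS; see tracking issue"
)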

Comment on lines +105 to +106
k_out[:, :, cache_pos] = k_val
v_out[:, :, cache_pos] = v_val
Contributor:

Does this create an extra copy or anything? (Not a blocker, just curious)
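
For what it's worth, a quick generic check (not from this PR) suggests this kind of indexed assignment writes into the existing cache storage in place rather than allocating a new cache tensor:

import torch

# Cache-shaped tensors: (batch, num_heads, cache_len, head_dim); sizes are arbitrary.
k_out = torch.zeros(2, 4, 16, 8)
k_val = torch.ones(2, 4, 3, 8)
cache_pos = torch.tensor([5, 6, 7])

ptr_before = k_out.data_ptr()
k_out[:, :, cache_pos] = k_val         # advanced indexing lowers to an in-place index_put_
assert k_out.data_ptr() == ptr_before  # same storage: the cache itself is not copied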


@ebsmothers (Contributor) left a comment:

Would be nice to have an actual gate on those failing unit tests with a pointer to a proper issue for now (otherwise it bloats up our CONTRIBUTING.md a lot). Other than that, if you've run the generation recipe (yes, just that one and not every single recipe), this looks good to me.

codecov-commenter commented Sep 27, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 67.61%. Comparing base (6bc143f) to head (5d45d3c).
Report is 16 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1706      +/-   ##
==========================================
- Coverage   70.67%   67.61%   -3.06%     
==========================================
  Files         299      304       +5     
  Lines       15251    15610     +359     
==========================================
- Hits        10778    10555     -223     
- Misses       4473     5055     +582     
Flag    Coverage Δ
        67.61% <100.00%> (-3.06%) ⬇️

Flags with carried forward coverage won't be shown.


@ebsmothers (Contributor) left a comment:

It’s beautiful

@SalmanMohammadi merged commit 3fddc56 into pytorch:main Sep 28, 2024
17 checks passed
@SalmanMohammadi deleted the mps branch September 28, 2024 09:33
SalmanMohammadi added a commit to SalmanMohammadi/torchtune that referenced this pull request Sep 30, 2024
Labels: CLA Signed (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)
Projects: None yet

Successfully merging this pull request may close these issues: Support for MPS? (#1215)

5 participants