Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MPS] Fix bernoulli for int types #100946

Closed
wants to merge 2 commits into from
Closed

Conversation

malfet
Copy link
Contributor

@malfet malfet commented May 9, 2023

🤖 Generated by Copilot at 069fd23

This pull request enhances the MPS implementation of random operations in Distributions.mm and adds more dtype tests for the bernoulli distribution in test_mps.py. This improves the performance, correctness, and usability of the MPS backend for PyTorch.

Fixes #100717

@malfet malfet requested a review from kulinseth as a code owner May 9, 2023 04:49
@pytorch-bot
Copy link

pytorch-bot bot commented May 9, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/100946

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 93a2b00:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels May 9, 2023
@malfet malfet added the topic: bug fixes topic category label May 9, 2023
test/test_mps.py Outdated
@@ -7218,6 +7218,13 @@ def test_bernoulli(self):
mps_out = torch.bernoulli(all_ones)
self.assertEqual(mps_out, all_ones)

# Check it works for different dtypes
for dtype in [torch.float16, torch.int8, torch.int32, torch.int16, torch.int64]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change looks good.
Can we please add test for random op as well. It has fixes for Bernoulli and random

@malfet malfet force-pushed the malfet/fix-bernoully-for-int-types branch from 069fd23 to c7086d0 Compare May 10, 2023 15:12
@malfet malfet force-pushed the malfet/fix-bernoully-for-int-types branch from c7086d0 to 93a2b00 Compare May 11, 2023 13:50
@malfet malfet added ciflow/mps Run MPS tests (subset of trunk) and removed ciflow/mps Run MPS tests (subset of trunk) labels May 11, 2023
@malfet
Copy link
Contributor Author

malfet commented May 11, 2023

@pytorchbot merge -f "Lint and MPS tests are green"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 mandatory check(s) are pending/not yet run. The first few are:

  • EasyCLA

Dig deeper by viewing the pending checks on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@malfet
Copy link
Contributor Author

malfet commented May 11, 2023

/easycla

@malfet
Copy link
Contributor Author

malfet commented May 11, 2023

@pytorchbot merge -f "Easy CLA, Lint and MPS tests are green"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@malfet malfet deleted the malfet/fix-bernoully-for-int-types branch May 12, 2023 02:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/mps Run MPS tests (subset of trunk) Merged release notes: mps Release notes category topic: bug fixes topic category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Inconsistent casting of types for 'cpu' and 'mps' devices
3 participants