Skip to content

Commit

Permalink
MPS: isin_mps_friendly can support 0D tensors (huggingface#34538)
Browse files Browse the repository at this point in the history
* apply fix

* tested

* make fixup
  • Loading branch information
gante authored and BernardZach committed Dec 5, 2024
1 parent d0f0293 commit 76e95a2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/transformers/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,14 +314,17 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
Args:
elements (`torch.Tensor`): Input elements
test_elements (`torch.Tensor`): The elements to check against.
test_elements (`torch.Tensor` or `int`): The elements to check against.
Returns:
`torch.Tensor`: A boolean tensor of the same shape as `elements` that is True for `elements` in `test_elements`
and False otherwise
"""

if elements.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
test_elements = torch.tensor(test_elements)
if test_elements.ndim == 0:
test_elements = test_elements.unsqueeze(0)
return elements.tile(test_elements.shape[0], 1).eq(test_elements.unsqueeze(1)).sum(dim=0).bool().squeeze()
else:
# Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045
Expand Down
7 changes: 6 additions & 1 deletion tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1711,7 +1711,12 @@ def test_isin_mps_friendly(self):
torch.isin(random_ids, random_test_integer), isin_mps_friendly(random_ids, random_test_integer)
)
)
# We can match against an tensor of integers
# We can match against an 0D tensor
random_test_tensor = torch.randint(0, 100, (1,)).squeeze()
self.assertTrue(
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
)
# We can match against an 1D tensor (with many items)
random_test_tensor = torch.randint(0, 100, (10,))
self.assertTrue(
torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
Expand Down

0 comments on commit 76e95a2

Please sign in to comment.