From 4391007333231d5587bcbc7e53adbc369698e08f Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 31 Oct 2024 15:22:20 +0000 Subject: [PATCH 1/3] apply fix --- src/transformers/pytorch_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index f3663c09902f52..1f68a2bcc02004 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -322,6 +322,8 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int) """ if elements.device.type == "mps" and not is_torch_greater_or_equal_than_2_4: + if test_elements.ndim == 0: + test_elements = test_elements.unsqueeze(0) return elements.tile(test_elements.shape[0], 1).eq(test_elements.unsqueeze(1)).sum(dim=0).bool().squeeze() else: # Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045 From 301f20cb23b9f628a9d2f47536220d8b412271dd Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 31 Oct 2024 15:30:28 +0000 Subject: [PATCH 2/3] tested --- src/transformers/pytorch_utils.py | 3 ++- tests/utils/test_modeling_utils.py | 9 ++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index 1f68a2bcc02004..a808f2cb63e861 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -314,7 +314,7 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int) Args: elements (`torch.Tensor`): Input elements - test_elements (`torch.Tensor`): The elements to check against. + test_elements (`torch.Tensor` or `int`): The elements to check against. Returns: `torch.Tensor`: A boolean tensor of the same shape as `elements` that is True for `elements` in `test_elements` @@ -322,6 +322,7 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int) """ if elements.device.type == "mps" and not is_torch_greater_or_equal_than_2_4: + test_elements = torch.tensor(test_elements) if test_elements.ndim == 0: test_elements = test_elements.unsqueeze(0) return elements.tile(test_elements.shape[0], 1).eq(test_elements.unsqueeze(1)).sum(dim=0).bool().squeeze() diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 0452a10d5d57e6..24c1a3bbc76578 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1711,7 +1711,14 @@ def test_isin_mps_friendly(self): torch.isin(random_ids, random_test_integer), isin_mps_friendly(random_ids, random_test_integer) ) ) - # We can match against an tensor of integers + # We can match against an 0D tensor + random_test_tensor = torch.randint(0, 100, (1,)).squeeze() + self.assertTrue( + torch.equal( + torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor) + ) + ) + # We can match against an 1D tensor (with many items) random_test_tensor = torch.randint(0, 100, (10,)) self.assertTrue( torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor)) From 06d7e18c21c102d6af77a471385939e89a38497c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 31 Oct 2024 15:32:17 +0000 Subject: [PATCH 3/3] make fixup --- tests/utils/test_modeling_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 24c1a3bbc76578..5fd6251224c3ed 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1714,9 +1714,7 @@ def test_isin_mps_friendly(self): # We can match against an 0D tensor random_test_tensor = torch.randint(0, 100, (1,)).squeeze() self.assertTrue( - torch.equal( - torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor) - ) + torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor)) ) # We can match against an 1D tensor (with many items) random_test_tensor = torch.randint(0, 100, (10,))