diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py
index 45afbecdb30..220883e6b4e 100644
--- a/test/test_core_aten_ops.py
+++ b/test/test_core_aten_ops.py
@@ -10,19 +10,6 @@
 import unittest
 
 
-def onlyIfTorchSupportsCUDA(fn):
-  return unittest.skipIf(
-      not torch.cuda.is_available(), reason="requires PyTorch CUDA support")(
-          fn)
-
-
-def onlyIfPJRTDeviceIsCUDA(fn):
-  return unittest.skipIf(
-      os.environ.get("PJRT_DEVICE") not in ("GPU", "CUDA"),
-      reason="requires CUDA as PJRT_DEVICE")(
-          fn)
-
-
 def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True):
   if isinstance(output1, torch.Tensor):
     testcase.assertIsInstance(output2, torch.Tensor)
@@ -4726,27 +4713,6 @@ def test_aten_where_self_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.where.self, args, kwargs)
 
-  def _test_move_tensor_cuda_to_xla(self, cpu_tensor):
-    # Assumes CPU-XLA data movement works.
-    cuda_tensor = cpu_tensor.to("cuda")
-    # Move tensor CUDA -> XLA.
-    xla_tensor = cuda_tensor.to(xm.xla_device())
-    # Move the XLA tensor back to CPU, and check that it is the same as
-    # the original CPU tensor.
-    self.assertTrue(torch.equal(cpu_tensor, xla_tensor.cpu()))
-
-  @onlyIfTorchSupportsCUDA
-  @onlyIfPJRTDeviceIsCUDA
-  def test_aten_move_cuda_to_xla(self):
-    self._test_move_tensor_cuda_to_xla(torch.arange(5))
-
-  @onlyIfTorchSupportsCUDA
-  @onlyIfPJRTDeviceIsCUDA
-  def test_aten_move_scalar_cuda_to_xla(self):
-    # 0-dimensional scalar-tensor
-    # Has a different execution path than other tensors.
-    self._test_move_tensor_cuda_to_xla(torch.tensor(42))
-
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/test/test_operations.py b/test/test_operations.py
index e11c54f86bf..e67722dfec2 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -131,6 +131,19 @@ def _zeros_like(tensor_list):
   return zeros_tensors
 
 
+def onlyIfTorchSupportsCUDA(fn):
+  return unittest.skipIf(
+      not torch.cuda.is_available(), reason="requires PyTorch CUDA support")(
+          fn)
+
+
+def onlyIfPJRTDeviceIsCUDA(fn):
+  return unittest.skipIf(
+      os.environ.get("PJRT_DEVICE") not in ("GPU", "CUDA"),
+      reason="requires CUDA as PJRT_DEVICE")(
+          fn)
+
+
 class TestToXlaTensorArena(test_utils.XlaTestCase):
 
   def test(self):
@@ -2356,6 +2369,27 @@ def test_as_strided_input_larger(self):
 
     self.assertEqual(a, former_a)
 
+  def _test_move_tensor_cuda_to_xla(self, cpu_tensor):
+    # Assumes CPU-XLA data movement works.
+    cuda_tensor = cpu_tensor.to("cuda")
+    # Move tensor CUDA -> XLA.
+    xla_tensor = cuda_tensor.to(xm.xla_device())
+    # Move the XLA tensor back to CPU, and check that it is the same as
+    # the original CPU tensor.
+    self.assertTrue(torch.equal(cpu_tensor, xla_tensor.cpu()))
+
+  @onlyIfTorchSupportsCUDA
+  @onlyIfPJRTDeviceIsCUDA
+  def test_aten_move_cuda_to_xla(self):
+    self._test_move_tensor_cuda_to_xla(torch.arange(5))
+
+  @onlyIfTorchSupportsCUDA
+  @onlyIfPJRTDeviceIsCUDA
+  def test_aten_move_scalar_cuda_to_xla(self):
+    # 0-dimensional scalar-tensor
+    # Has a different execution path than other tensors.
+    self._test_move_tensor_cuda_to_xla(torch.tensor(42))
+
 
 if __name__ == '__main__':
   torch.set_default_dtype(torch.float32)
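
For reference, the relocated tests exercise the CUDA -> XLA -> CPU round trip shown in the minimal standalone sketch below. This is only a sketch, assuming a CUDA-enabled PyTorch build with the CUDA PJRT plugin configured (PJRT_DEVICE=CUDA); under any other configuration the onlyIfTorchSupportsCUDA / onlyIfPJRTDeviceIsCUDA decorators skip these tests.

# Minimal sketch of the path the relocated tests cover (assumes
# torch.cuda.is_available() is True and PJRT_DEVICE=CUDA is set).
import torch
import torch_xla.core.xla_model as xm

cpu_tensor = torch.arange(5)
cuda_tensor = cpu_tensor.to("cuda")           # CPU -> CUDA
xla_tensor = cuda_tensor.to(xm.xla_device())  # CUDA -> XLA
# Moving the XLA tensor back to CPU should preserve the original values.
assert torch.equal(cpu_tensor, xla_tensor.cpu())

# To run only these tests after the move (unittest's -k filtering, Python 3.7+):
#   PJRT_DEVICE=CUDA python test/test_operations.py -v -k test_aten_move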