diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py index 68c2d12184fd..b155cc41aaa3 100644 --- a/test/test_core_aten_ops.py +++ b/test/test_core_aten_ops.py @@ -5,10 +5,24 @@ import torch import torch._export +import os import tempfile import unittest +def onlyIfTorchSupportsCUDA(fn): + return unittest.skipIf( + not torch.cuda.is_available(), reason="requires PyTorch CUDA support")( + fn) + + +def onlyIfPJRTDeviceIsCUDA(fn): + return unittest.skipIf( + os.environ.get("PJRT_DEVICE") not in ("GPU", "CUDA"), + reason="requires CUDA as PJRT_DEVICE")( + fn) + + def diff_output(testcase, output1, output2, rtol, atol): if isinstance(output1, torch.Tensor): testcase.assertIsInstance(output2, torch.Tensor) @@ -4664,6 +4678,18 @@ def test_aten_where_self_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.where.self, args, kwargs) + @onlyIfTorchSupportsCUDA + @onlyIfPJRTDeviceIsCUDA + def test_aten_move_cuda_to_xla(self): + # Assumes CPU-XLA data movement works. + t_cpu = torch.arange(5) + t_cuda = t_cpu.to("cuda") + # Move tensor CUDA -> XLA. + t_xla = t_cuda.to(xm.xla_device()) + # Move the XLA tensor back to CPU, and check that it is the same as + # the original CPU tensor. + self.assertTrue(torch.equal(t_cpu, t_xla.cpu())) + if __name__ == '__main__': unittest.main() diff --git a/torch_xla/csrc/runtime/tensor_source.h b/torch_xla/csrc/runtime/tensor_source.h index ba82c5c8e42e..19665c5627a4 100644 --- a/torch_xla/csrc/runtime/tensor_source.h +++ b/torch_xla/csrc/runtime/tensor_source.h @@ -54,8 +54,13 @@ class AtenSource : public TensorSource { if (target_torch_type != tensor.type().scalarType()) { TORCH_LAZY_COUNTER("AtenSourceDowncasts", 1); } - tensor_ = std::move(tensor.to(target_torch_type, /*non_blocking=*/false, - /*copy=*/true, at::MemoryFormat::Contiguous)); + // TODO(ysiraichi): check, first, if tensor lives in a device that the + // current PjRt client has access. If so, we don't need to go through the + // CPU. + tensor_ = std::move( + tensor.to(at::TensorOptions().device(at::kCPU).dtype(target_torch_type), + /*non_blocking=*/false, + /*copy=*/true, at::MemoryFormat::Contiguous)); } const void* data() const override { return tensor_.const_data_ptr(); }