From 210c8675227935a42b00498e18a1fc23072d635e Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Fri, 8 Dec 2023 02:37:56 -0300
Subject: [PATCH 1/3] Add test.

---
 test/test_core_aten_ops.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py
index 68c2d12184f..b155cc41aaa 100644
--- a/test/test_core_aten_ops.py
+++ b/test/test_core_aten_ops.py
@@ -5,10 +5,24 @@
 import torch
 import torch._export
 
+import os
 import tempfile
 import unittest
 
 
+def onlyIfTorchSupportsCUDA(fn):
+  return unittest.skipIf(
+      not torch.cuda.is_available(), reason="requires PyTorch CUDA support")(
+          fn)
+
+
+def onlyIfPJRTDeviceIsCUDA(fn):
+  return unittest.skipIf(
+      os.environ.get("PJRT_DEVICE") not in ("GPU", "CUDA"),
+      reason="requires CUDA as PJRT_DEVICE")(
+          fn)
+
+
 def diff_output(testcase, output1, output2, rtol, atol):
   if isinstance(output1, torch.Tensor):
     testcase.assertIsInstance(output2, torch.Tensor)
@@ -4664,6 +4678,18 @@ def test_aten_where_self_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.where.self, args, kwargs)
 
+  @onlyIfTorchSupportsCUDA
+  @onlyIfPJRTDeviceIsCUDA
+  def test_aten_move_cuda_to_xla(self):
+    # Assumes CPU-XLA data movement works.
+    t_cpu = torch.arange(5)
+    t_cuda = t_cpu.to("cuda")
+    # Move tensor CUDA -> XLA.
+    t_xla = t_cuda.to(xm.xla_device())
+    # Move the XLA tensor back to CPU, and check that it is the same as
+    # the original CPU tensor.
+    self.assertTrue(torch.equal(t_cpu, t_xla.cpu()))
+
 
 if __name__ == '__main__':
   unittest.main()

From 4145ce17c0bc39ca49c5681c244516ba27e4ffa0 Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Fri, 8 Dec 2023 01:57:55 -0300
Subject: [PATCH 2/3] Move to CPU before actually copying.

---
 torch_xla/csrc/runtime/tensor_source.h | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/torch_xla/csrc/runtime/tensor_source.h b/torch_xla/csrc/runtime/tensor_source.h
index ba82c5c8e42..974ee7c7691 100644
--- a/torch_xla/csrc/runtime/tensor_source.h
+++ b/torch_xla/csrc/runtime/tensor_source.h
@@ -54,9 +54,13 @@ class AtenSource : public TensorSource {
     if (target_torch_type != tensor.type().scalarType()) {
       TORCH_LAZY_COUNTER("AtenSourceDowncasts", 1);
     }
-    tensor_ = std::move(tensor.to(target_torch_type, /*non_blocking=*/false,
-                                  /*copy=*/true, at::MemoryFormat::Contiguous));
-  }
+    // TODO(ysiraichi): check, first, if tensor lives in a device that the current
+    // PjRt client has access. If so, we don't need to go through the CPU.
+    tensor_ = std::move(
+        tensor.to(at::TensorOptions().device(at::kCPU).dtype(target_torch_type),
+                  /*non_blocking=*/false,
+                  /*copy=*/true, at::MemoryFormat::Contiguous));
+    }
 
   const void* data() const override { return tensor_.const_data_ptr(); }
 

From 2136d2b14f57016cc7a3ff15862790efad2164db Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Fri, 8 Dec 2023 04:26:29 -0300
Subject: [PATCH 3/3] Fix lint issues.

---
 torch_xla/csrc/runtime/tensor_source.h | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/torch_xla/csrc/runtime/tensor_source.h b/torch_xla/csrc/runtime/tensor_source.h
index 974ee7c7691..19665c5627a 100644
--- a/torch_xla/csrc/runtime/tensor_source.h
+++ b/torch_xla/csrc/runtime/tensor_source.h
@@ -54,13 +54,14 @@ class AtenSource : public TensorSource {
     if (target_torch_type != tensor.type().scalarType()) {
       TORCH_LAZY_COUNTER("AtenSourceDowncasts", 1);
     }
-    // TODO(ysiraichi): check, first, if tensor lives in a device that the current
-    // PjRt client has access. If so, we don't need to go through the CPU.
+    // TODO(ysiraichi): check, first, if tensor lives in a device that the
+    // current PjRt client has access. If so, we don't need to go through the
+    // CPU.
     tensor_ = std::move(
         tensor.to(at::TensorOptions().device(at::kCPU).dtype(target_torch_type),
                   /*non_blocking=*/false,
                   /*copy=*/true, at::MemoryFormat::Contiguous));
-    }
+  }
 
   const void* data() const override { return tensor_.const_data_ptr(); }
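
Note (not part of the patch series): a minimal Python sketch of the data
movement the series enables, mirroring the new test. It assumes a
CUDA-enabled PyTorch build and PJRT_DEVICE=CUDA; with patch 2 applied,
AtenSource stages the CUDA tensor on the CPU before copying it into the
PjRt device buffer.

    import torch
    import torch_xla.core.xla_model as xm

    t_cpu = torch.arange(5)             # plain CPU tensor
    t_cuda = t_cpu.to("cuda")           # CPU -> CUDA
    # CUDA -> XLA: per the TODO in patch 2, the copy goes through the CPU.
    t_xla = t_cuda.to(xm.xla_device())
    # Round-trip back to CPU preserves the original values.
    assert torch.equal(t_cpu, t_xla.cpu())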