Skip to content

Commit

Permalink
Move CUDA tensors to CPU before moving to XLA. (#6060)
Browse files: browse the repository at this point in the history
  • Loading branch information
ysiraichi authored Dec 8, 2023
1 parent a7d58f7 commit d5df845
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
26 changes: 26 additions & 0 deletions test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,24 @@
import torch
import torch._export

import os
import tempfile
import unittest


def onlyIfTorchSupportsCUDA(fn):
  """Skip the decorated test unless `torch.cuda.is_available()` is true.

  The availability check is evaluated once, at the moment the decorator is
  applied, not each time the test runs.

  Args:
    fn: the test function (or class) to decorate.

  Returns:
    `fn` unchanged when CUDA is available; otherwise a `unittest` skip wrapper
    carrying the reason "requires PyTorch CUDA support".
  """
  return unittest.skipIf(
      not torch.cuda.is_available(), reason="requires PyTorch CUDA support")(
          fn)


def onlyIfPJRTDeviceIsCUDA(fn):
  """Skip the decorated test unless PJRT_DEVICE is set to "GPU" or "CUDA".

  The `PJRT_DEVICE` environment variable is read once, at the moment the
  decorator is applied. An unset variable counts as "not CUDA", so the test
  is skipped.

  Args:
    fn: the test function (or class) to decorate.

  Returns:
    `fn` unchanged when the variable is "GPU" or "CUDA"; otherwise a
    `unittest` skip wrapper carrying the reason "requires CUDA as PJRT_DEVICE".
  """
  return unittest.skipIf(
      os.environ.get("PJRT_DEVICE") not in ("GPU", "CUDA"),
      reason="requires CUDA as PJRT_DEVICE")(
          fn)


def diff_output(testcase, output1, output2, rtol, atol):
if isinstance(output1, torch.Tensor):
testcase.assertIsInstance(output2, torch.Tensor)
Expand Down Expand Up @@ -4664,6 +4678,18 @@ def test_aten_where_self_0(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.where.self, args, kwargs)

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
def test_aten_move_cuda_to_xla(self):
  """Round-trips a tensor CPU -> CUDA -> XLA -> CPU and checks it is unchanged."""
  # Assumes CPU-XLA data movement works.
  t_cpu = torch.arange(5)
  t_cuda = t_cpu.to("cuda")
  # Move tensor CUDA -> XLA.
  t_xla = t_cuda.to(xm.xla_device())
  # Move the XLA tensor back to CPU, and check that it is the same as
  # the original CPU tensor.
  self.assertTrue(torch.equal(t_cpu, t_xla.cpu()))


if __name__ == '__main__':
unittest.main()
9 changes: 7 additions & 2 deletions torch_xla/csrc/runtime/tensor_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,13 @@ class AtenSource : public TensorSource {
if (target_torch_type != tensor.type().scalarType()) {
TORCH_LAZY_COUNTER("AtenSourceDowncasts", 1);
}
tensor_ = std::move(tensor.to(target_torch_type, /*non_blocking=*/false,
/*copy=*/true, at::MemoryFormat::Contiguous));
// TODO(ysiraichi): first check whether the tensor lives on a device that the
// current PjRt client has access to. If so, we don't need to go through the
// CPU.
tensor_ = std::move(
tensor.to(at::TensorOptions().device(at::kCPU).dtype(target_torch_type),
/*non_blocking=*/false,
/*copy=*/true, at::MemoryFormat::Contiguous));
}

// Returns a read-only pointer to the data held by `tensor_`, obtained via
// `tensor_.const_data_ptr()`.
const void* data() const override { return tensor_.const_data_ptr(); }
Expand Down

0 comments on commit d5df845

Please sign in to comment.