Skip to content

Commit

Permalink
Move 0-dimensional tensors to CPU before copying to XLA. (#6071)
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi authored Dec 9, 2023
1 parent 0857f2a commit a80c1e7
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
23 changes: 16 additions & 7 deletions test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4678,17 +4678,26 @@ def test_aten_where_self_0(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.where.self, args, kwargs)

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
def test_aten_move_cuda_to_xla(self):
def _test_move_tensor_cuda_to_xla(self, cpu_tensor):
# Assumes CPU-XLA data movement works.
t_cpu = torch.arange(5)
t_cuda = t_cpu.to("cuda")
cuda_tensor = cpu_tensor.to("cuda")
# Move tensor CUDA -> XLA.
t_xla = t_cuda.to(xm.xla_device())
xla_tensor = cuda_tensor.to(xm.xla_device())
# Move the XLA tensor back to CPU, and check that it is the same as
# the original CPU tensor.
self.assertTrue(torch.equal(t_cpu, t_xla.cpu()))
self.assertTrue(torch.equal(cpu_tensor, xla_tensor.cpu()))

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
def test_aten_move_cuda_to_xla(self):
self._test_move_tensor_cuda_to_xla(torch.arange(5))

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
def test_aten_move_scalar_cuda_to_xla(self):
# 0-dimensional scalar-tensor
# Has a different execution path than other tensors.
self._test_move_tensor_cuda_to_xla(torch.tensor(42))


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ torch::lazy::Value XLATensor::GetIrValueForTensor(
return ScalarOp(std::move(value),
MakeXlaPrimitiveType(tensor.scalar_type(), &device));
}
data = XLAGraphExecutor::Get()->GetDeviceData(tensor, device);
data = XLAGraphExecutor::Get()->GetDeviceData(tensor.cpu(), device);
read_only = true;
} else {
TORCH_LAZY_TIMED("IrValueTensorToXlaData");
Expand Down

0 comments on commit a80c1e7

Please sign in to comment.