From d390f10903f1b3238ab6037c1d9a087144000b4d Mon Sep 17 00:00:00 2001 From: Feng Yuan Date: Fri, 8 Mar 2024 23:20:22 +0800 Subject: [PATCH 1/2] Add aten::clone and simple cases Signed-off-by: Feng Yuan --- src/aten/TensorFactories.cpp | 5 +++++ test/python/examples/test_copy.py | 21 +++++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 test/python/examples/test_copy.py diff --git a/src/aten/TensorFactories.cpp b/src/aten/TensorFactories.cpp index 8504f9a80..835ecd1d0 100644 --- a/src/aten/TensorFactories.cpp +++ b/src/aten/TensorFactories.cpp @@ -28,9 +28,14 @@ Tensor empty_strided_xpu(IntArrayRef size, IntArrayRef stride, c10::optional memory_format) { + return at::native::clone(self, memory_format); +} + TORCH_LIBRARY_IMPL(aten, XPU, m) { m.impl(TORCH_SELECTIVE_NAME("aten::empty.memory_format"), TORCH_FN(at::native::empty_xpu)); m.impl(TORCH_SELECTIVE_NAME("aten::empty_strided"), TORCH_FN(at::native::empty_strided_xpu)); + m.impl(TORCH_SELECTIVE_NAME("aten::clone"), TORCH_FN(at::native::clone_xpu)); } } // namespace at::native diff --git a/test/python/examples/test_copy.py b/test/python/examples/test_copy.py new file mode 100644 index 000000000..d5df007fb --- /dev/null +++ b/test/python/examples/test_copy.py @@ -0,0 +1,21 @@ +import torch +from torch.testing._internal.common_utils import TestCase + +cpu_device = torch.device("cpu") +xpu_device = torch.device("xpu") + + +class TestSimpleCopy(TestCase): + def test_copy_and_clone(self, dtype=torch.float): + a_cpu = torch.randn(16, 64, 28, 28) + b_cpu = torch.randn(16, 64, 28, 28) + a_xpu = a_cpu.to(xpu_device) + b_xpu = b_cpu.to(xpu_device) + # naive + b_cpu.copy_(a_cpu) + b_xpu.copy_(a_xpu) + self.assertEqual(b_cpu, b_xpu.to(cpu_device)) + # clone + permutation + b_cpu = a_cpu.clone(memory_format=torch.channels_last) + b_xpu = a_xpu.clone(memory_format=torch.channels_last) + self.assertEqual(b_cpu, b_xpu.to(cpu_device)) From c4b2277406363692a615ae47cccadad126dbbaa8 Mon Sep 17 00:00:00 2001 From: Feng Yuan Date: Sat, 9 Mar 2024 01:00:20 +0800 Subject: [PATCH 2/2] Fix Loops legacy code-path Signed-off-by: Feng Yuan --- src/aten/sycl/Loops.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aten/sycl/Loops.h b/src/aten/sycl/Loops.h index 4c571a6fb..ea36862fd 100644 --- a/src/aten/sycl/Loops.h +++ b/src/aten/sycl/Loops.h @@ -43,13 +43,13 @@ inline void elementwise_kernel_helper(func_t f, policy_t policy) { template struct ElementwiseKernel { void operator()(sycl::nd_item<1> item) const { - int grpsz = item.get_local_range(0); + int glbsz = item.get_global_range(0); int gid = item.get_global_linear_id(); #pragma unroll for (int i = 0; i < vec_size; i++) { if (gid < numel_) { f_(gid); - gid += grpsz; + gid += glbsz; } } };