Add aten::clone and simple cases
* Add aten::clone and simple cases

Signed-off-by: Feng Yuan <feng1.yuan@intel.com>

* Fix Loops legacy code-path

Signed-off-by: Feng Yuan <feng1.yuan@intel.com>

---------

Signed-off-by: Feng Yuan <feng1.yuan@intel.com>
fengyuan14 authored Mar 8, 2024
1 parent 6340065 commit 9cb104a
Showing 3 changed files with 28 additions and 2 deletions.
5 changes: 5 additions & 0 deletions src/aten/TensorFactories.cpp
@@ -28,9 +28,14 @@ Tensor empty_strided_xpu(IntArrayRef size, IntArrayRef stride, c10::optional<Sca
   return result;
 }
 
+Tensor clone_xpu(const Tensor& self, c10::optional<MemoryFormat> memory_format) {
+  return at::native::clone(self, memory_format);
+}
+
 TORCH_LIBRARY_IMPL(aten, XPU, m) {
   m.impl(TORCH_SELECTIVE_NAME("aten::empty.memory_format"), TORCH_FN(at::native::empty_xpu));
   m.impl(TORCH_SELECTIVE_NAME("aten::empty_strided"), TORCH_FN(at::native::empty_strided_xpu));
+  m.impl(TORCH_SELECTIVE_NAME("aten::clone"), TORCH_FN(at::native::clone_xpu));
 }
 
 } // namespace at::native
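The new clone_xpu entry simply forwards to the shared at::native::clone path, which only requires working allocation and copy_ kernels on the backend. As a rough illustration of why that forwarding is enough, here is a hypothetical sketch (not the commit's code, and not the real at::native::clone, which also preserves strides for non-overlapping-and-dense tensors): a clone can be composed from empty_like plus copy_.

// Hypothetical composition of clone from allocation + copy, for illustration only.
#include <ATen/ATen.h>

at::Tensor clone_sketch(const at::Tensor& self,
                        c10::optional<at::MemoryFormat> memory_format) {
  // Treat an unset format as Preserve so the layout of `self` is kept.
  auto fmt = memory_format.value_or(at::MemoryFormat::Preserve);
  // Allocate an uninitialized tensor with the requested memory format,
  // then fill it with copy_, which performs any layout conversion.
  at::Tensor result = at::empty_like(self, self.options(), fmt);
  result.copy_(self);
  return result;
}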
4 changes: 2 additions & 2 deletions src/aten/sycl/Loops.h
@@ -43,13 +43,13 @@ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
 template <int vec_size, typename func_t>
 struct ElementwiseKernel {
   void operator()(sycl::nd_item<1> item) const {
-    int grpsz = item.get_local_range(0);
+    int glbsz = item.get_global_range(0);
     int gid = item.get_global_linear_id();
 #pragma unroll
     for (int i = 0; i < vec_size; i++) {
       if (gid < numel_) {
         f_(gid);
-        gid += grpsz;
+        gid += glbsz;
       }
     }
   };
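This hunk is the "Fix Loops legacy code-path" part of the commit: each work-item starts at its global linear id, so the per-iteration stride must be the global range, not the work-group size, or different work-items revisit the same elements. The plain C++ emulation below is not from the commit; numel, global_size, and vec_size are made-up launch values for the demo. It counts how often each element is processed under both strides.

// Standalone emulation of the vectorized elementwise loop in Loops.h:
// work-item gid0 processes gid0, gid0 + stride, ... for up to vec_size steps.
#include <cstdio>
#include <vector>

static void run(int stride, const char* label) {
  const int numel = 10;       // elements to process
  const int global_size = 8;  // 2 hypothetical groups of 4 work-items each
  const int vec_size = 2;     // iterations per work-item
  std::vector<int> hits(numel, 0);
  for (int gid0 = 0; gid0 < global_size; ++gid0) {  // every work-item
    int gid = gid0;                                 // get_global_linear_id()
    for (int i = 0; i < vec_size; ++i) {
      if (gid < numel) {
        hits[gid] += 1;  // stands in for f_(gid)
        gid += stride;
      }
    }
  }
  int duplicated = 0, missed = 0;
  for (int h : hits) { duplicated += (h > 1); missed += (h == 0); }
  std::printf("%s: %d element(s) processed twice, %d missed\n", label, duplicated, missed);
}

int main() {
  run(4, "stride = group size (old)");     // 4 elements get processed twice
  run(8, "stride = global range (fixed)"); // each element processed exactly once
  return 0;
}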
21 changes: 21 additions & 0 deletions test/python/examples/test_copy.py
@@ -0,0 +1,21 @@
+import torch
+from torch.testing._internal.common_utils import TestCase
+
+cpu_device = torch.device("cpu")
+xpu_device = torch.device("xpu")
+
+
+class TestSimpleCopy(TestCase):
+    def test_copy_and_clone(self, dtype=torch.float):
+        a_cpu = torch.randn(16, 64, 28, 28)
+        b_cpu = torch.randn(16, 64, 28, 28)
+        a_xpu = a_cpu.to(xpu_device)
+        b_xpu = b_cpu.to(xpu_device)
+        # naive
+        b_cpu.copy_(a_cpu)
+        b_xpu.copy_(a_xpu)
+        self.assertEqual(b_cpu, b_xpu.to(cpu_device))
+        # clone + permutation
+        b_cpu = a_cpu.clone(memory_format=torch.channels_last)
+        b_xpu = a_xpu.clone(memory_format=torch.channels_last)
+        self.assertEqual(b_cpu, b_xpu.to(cpu_device))
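For reference, the "clone + permutation" case exercises the memory-format argument: cloning a contiguous NCHW tensor with channels_last yields an NHWC-strided copy with identical values. Below is a small ATen C++ sketch of the same check, not part of the commit; it assumes a libtorch build and runs on CPU.

// Sketch of what the channels_last clone check verifies (CPU only, for illustration).
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor a = at::randn({16, 64, 28, 28});              // NCHW, contiguous
  at::Tensor b = a.clone(at::MemoryFormat::ChannelsLast);  // same values, NHWC strides
  std::cout << std::boolalpha
            << b.is_contiguous(at::MemoryFormat::ChannelsLast) << "\n"  // true
            << a.equal(b) << std::endl;                                 // true
  return 0;
}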
