From 6922d412acba13502f44c5113422d6afe523c5d8 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre <105610547+danielvegamyhre@users.noreply.github.com>
Date: Mon, 13 Nov 2023 16:18:35 -0800
Subject: [PATCH] Lower AtenFull op (#5781)

* lower full

* update test for full op

* formatting
---
 codegen/xla_native_functions.yaml   |  1 +
 test/cpp/test_aten_xla_tensor_1.cpp |  3 +--
 torch_xla/csrc/aten_xla_type.cpp    | 18 ++++++++++++++++++
 3 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml
index bd57e18b1a3b..e0d917db6849 100644
--- a/codegen/xla_native_functions.yaml
+++ b/codegen/xla_native_functions.yaml
@@ -192,6 +192,7 @@ supported:
   - floor_divide
   - fmod.Scalar
   - fmod.Tensor
+  - full
  - gather
   - gelu
   - gelu_backward
diff --git a/test/cpp/test_aten_xla_tensor_1.cpp b/test/cpp/test_aten_xla_tensor_1.cpp
index a5db45185f79..e95651604d7c 100644
--- a/test/cpp/test_aten_xla_tensor_1.cpp
+++ b/test/cpp/test_aten_xla_tensor_1.cpp
@@ -139,8 +139,7 @@ TEST_F(AtenXlaTensorTest, TestFull) {
   });

   ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::fill_", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::full", cpp_test::GetIgnoredCounters());
 }

 TEST_F(AtenXlaTensorTest, TestFullLike) {
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index dfa7d74f7888..1fb8390db067 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1332,6 +1332,24 @@ at::Tensor XLANativeFunctions::fmod(const at::Tensor& self,
   });
 }

+at::Tensor XLANativeFunctions::full(at::IntArrayRef size,
+                                    const at::Scalar& fill_value,
+                                    c10::optional<at::ScalarType> dtype,
+                                    c10::optional<at::Layout> layout,
+                                    c10::optional<at::Device> device,
+                                    c10::optional<bool> pin_memory) {
+  TORCH_LAZY_FN_COUNTER("xla::");
+  // Fall back to CPU if layout or pin_memory are not default
+  if (layout.value_or(at::Layout::Strided) != at::Layout::Strided ||
+      pin_memory.value_or(false)) {
+    return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP(full)>::call(
+        size, fill_value, dtype, layout, device, pin_memory);
+  }
+  return bridge::AtenFromXlaTensor(tensor_methods::full(
+      absl::Span<const int64_t>(size), fill_value,
+      GetXlaDeviceOrCurrent(device), at::dtype_or_default(dtype)));
+}
+
 at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim,
                                       const at::Tensor& index,
                                       bool /* sparse_grad */) {