diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml
index bd57e18b1a3..e0d917db684 100644
--- a/codegen/xla_native_functions.yaml
+++ b/codegen/xla_native_functions.yaml
@@ -192,6 +192,7 @@ supported:
   - floor_divide
   - fmod.Scalar
   - fmod.Tensor
+  - full
   - gather
   - gelu
   - gelu_backward
diff --git a/test/cpp/test_aten_xla_tensor_1.cpp b/test/cpp/test_aten_xla_tensor_1.cpp
index a5db45185f7..e95651604d7 100644
--- a/test/cpp/test_aten_xla_tensor_1.cpp
+++ b/test/cpp/test_aten_xla_tensor_1.cpp
@@ -139,8 +139,7 @@ TEST_F(AtenXlaTensorTest, TestFull) {
   });

   ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::fill_", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::full", cpp_test::GetIgnoredCounters());
 }

 TEST_F(AtenXlaTensorTest, TestFullLike) {
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index fa2197093dc..f6d73f8ca1d 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1331,6 +1331,24 @@ at::Tensor XLANativeFunctions::fmod(const at::Tensor& self,
   });
 }

+at::Tensor XLANativeFunctions::full(at::IntArrayRef size,
+                                    const at::Scalar& fill_value,
+                                    c10::optional<at::ScalarType> dtype,
+                                    c10::optional<at::Layout> layout,
+                                    c10::optional<at::Device> device,
+                                    c10::optional<bool> pin_memory) {
+  TORCH_LAZY_FN_COUNTER("xla::");
+  // Fall back to CPU if layout or pin_memory are not default.
+  if (layout.value_or(at::Layout::Strided) != at::Layout::Strided ||
+      pin_memory.value_or(false)) {
+    return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP(full)>::call(
+        size, fill_value, dtype, layout, device, pin_memory);
+  }
+  return bridge::AtenFromXlaTensor(tensor_methods::full(
+      absl::Span<const int64_t>(size), fill_value,
+      GetXlaDeviceOrCurrent(device), at::dtype_or_default(dtype)));
+}
+
 at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim,
                                       const at::Tensor& index,
                                       bool /* sparse_grad */) {
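
For reference, a minimal sketch of how the new lowering can be observed from Python, assuming a working torch_xla install. The `xla::full` counter name matches the one asserted in the updated C++ test; inspecting it via `torch_xla.debug.metrics` is only one way to check, not part of this change.

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()

# With this change, torch.full on an XLA device should hit the xla::full
# lowering directly instead of decomposing into xla::empty + xla::fill_.
t = torch.full((2, 3), 3.5, device=device)
print(t)

# The metrics report should now include an xla::full counter.
print(met.metrics_report())
```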