Skip to content

Commit

Permalink
Lower AtenFull op (pytorch#5781)
Browse files Browse the repository at this point in the history
* lower full

* update test for full op

* formatting
  • Loading branch information
danielvegamyhre authored and chunnienc committed Dec 14, 2023
1 parent e89521e commit d21c024
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
1 change: 1 addition & 0 deletions codegen/xla_native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ supported:
- floor_divide
- fmod.Scalar
- fmod.Tensor
- full
- gather
- gelu
- gelu_backward
Expand Down
3 changes: 1 addition & 2 deletions test/cpp/test_aten_xla_tensor_1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,7 @@ TEST_F(AtenXlaTensorTest, TestFull) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::fill_", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::full", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestFullLike) {
Expand Down
18 changes: 18 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1332,6 +1332,24 @@ at::Tensor XLANativeFunctions::fmod(const at::Tensor& self,
});
}

at::Tensor XLANativeFunctions::full(at::IntArrayRef size,
const at::Scalar& fill_value,
c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
TORCH_LAZY_FN_COUNTER("xla::");
// Fall back to CPU if layout or pin_memory are not default
if (layout.value_or(at::Layout::Strided) != at::Layout::Strided ||
pin_memory.value_or(false)) {
return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP(full)>::call(
size, fill_value, dtype, layout, device, pin_memory);
}
return bridge::AtenFromXlaTensor(tensor_methods::full(
absl::Span<const int64_t>(size), fill_value,
GetXlaDeviceOrCurrent(device), at::dtype_or_default(dtype)));
}

at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim,
const at::Tensor& index,
bool /* sparse_grad */) {
Expand Down

0 comments on commit d21c024

Please sign in to comment.