Revert "Revert updating mlir_native_functions.cpp signature (#1281)"
This reverts commit a1ace06.
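(This re-lands the signature updates from #1281. Judging from the diff below, the change tracks upstream PyTorch's migration of size and shape arguments from at::IntArrayRef / int64_t to at::SymIntArrayRef / c10::SymInt.)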
henrytwo committed Aug 26, 2022
1 parent 0e3ddba · commit 2d835e8
Showing 2 changed files with 22 additions and 13 deletions.
6 changes: 0 additions & 6 deletions e2e_testing/torchscript/xfail_sets.py
@@ -435,12 +435,6 @@
    "NewOnesModuleFloat3D_basic",
    "NewOnesModuleInt2D_basic",
    "NewOnesModuleInt3D_basic",
-    "NewZerosModuleDefaultDtype_basic",
-    "NewZerosModuleFalsePinMemory_basic",
-    "NewZerosModuleFloat2D_basic",
-    "NewZerosModuleFloat3D_basic",
-    "NewZerosModuleInt2D_basic",
-    "NewZerosModuleInt3D_basic",
    "OnesLikeModule_defaultDtype",
    "OnesLikeModule_falsePinMemory",
    "OnesLikeModule_float",
29 changes: 22 additions & 7 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp
@@ -302,10 +302,14 @@ at::Tensor LazyNativeFunctions::_to_copy(
};

at::Tensor LazyNativeFunctions::empty(
-    at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
-    c10::optional<at::Layout> layout, c10::optional<at::Device> device,
+    at::SymIntArrayRef sym_size,
+    c10::optional<at::ScalarType> dtype,
+    c10::optional<at::Layout> layout,
+    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory,
    c10::optional<at::MemoryFormat> memory_format) {
+  // TODO: support this directly
+  auto size = c10::asIntArrayRefSlow(sym_size);
  const auto device_type = torch::lazy::getBackend()->EagerFallbackDeviceType();
  at::TensorOptions options = at::TensorOptions()
                                  .device(c10::Device(device_type))
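Aside: the asIntArrayRefSlow call above unwraps the symbolic sizes into plain integers so the rest of the function can keep using the concrete-shape path. A minimal sketch of the pattern, assuming the c10 API of this era of PyTorch (my_empty is a hypothetical wrapper, not part of this commit):

#include <ATen/ATen.h>
#include <c10/core/SymIntArrayRef.h>

// Hypothetical wrapper showing the unwrap-up-front pattern used above.
at::Tensor my_empty(at::SymIntArrayRef sym_size) {
  // asIntArrayRefSlow asserts that every SymInt holds a concrete integer,
  // so this pattern only works while no symbolic tracing is in flight.
  auto size = c10::asIntArrayRefSlow(sym_size);
  return at::empty(size); // hand off to the existing IntArrayRef-based path
}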
@@ -317,8 +321,9 @@ at::Tensor LazyNativeFunctions::empty(
  // See Note [Lazy Tensor Functionalization]
  if (c10::impl::tls_local_dispatch_key_set().excluded_.has(
          c10::DispatchKey::Functionalize)) {
-    // Invariant: if the functionalization key is in the exclude set, then we're expected
-    // to return an ordinary tensor, which will be "lifted" into a functional wrapper later.
+    // Invariant: if the functionalization key is in the exclude set, then we're
+    // expected to return an ordinary tensor, which will be "lifted" into a
+    // functional wrapper later.
    return tensor;
  } else {
    auto wrapped = at::functionalization::impl::to_functional_tensor(tensor);
@@ -331,7 +336,13 @@ at::Tensor LazyNativeFunctions::empty_strided(
    c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
    c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
  TORCH_LAZY_FN_COUNTER("lazy::");
-  at::Tensor t = empty(size, dtype, layout, device, pin_memory, c10::nullopt);
+  at::Tensor t = empty(
+      c10::SymIntArrayRef::fromIntArrayRef(size),
+      dtype,
+      layout,
+      device,
+      pin_memory,
+      c10::nullopt);
  return t.as_strided(size, stride, /*storage_offset=*/0);
}
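Aside: this is the caller-side counterpart of the unwrap above; concrete sizes are wrapped into symbolic form so the SymInt-typed empty overload can be reused. A sketch of the call pattern, assuming the c10::SymIntArrayRef::fromIntArrayRef helper from this era of PyTorch (it reinterprets the same int64_t storage as SymInts rather than copying):

std::vector<int64_t> size = {2, 3, 4};
at::Tensor t = LazyNativeFunctions::empty(
    c10::SymIntArrayRef::fromIntArrayRef(size), // view over the int64_t data
    /*dtype=*/c10::nullopt,
    /*layout=*/c10::nullopt,
    /*device=*/c10::nullopt,
    /*pin_memory=*/c10::nullopt,
    /*memory_format=*/c10::nullopt);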

@@ -350,7 +361,8 @@ LazyNativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) {
at::Tensor LazyNativeFunctions::_unsafe_view(
    const at::Tensor& self, at::IntArrayRef size) {
  TORCH_LAZY_FN_COUNTER("lazy::");
-  return LazyNativeFunctions::view_copy(self, size);
+  return LazyNativeFunctions::view_copy(
+      self, c10::SymIntArrayRef::fromIntArrayRef(size));
}
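(The same fromIntArrayRef wrap appears here because the generated view_copy declaration now takes at::SymIntArrayRef, while _unsafe_view itself still receives a concrete at::IntArrayRef.)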

// This is needed by the torch.tensor constructor.
@@ -386,7 +398,10 @@ at::Tensor LazyNativeFunctions::new_empty_strided(
}

at::Tensor LazyNativeFunctions::narrow_copy(
-    const at::Tensor& self, int64_t dim, int64_t start, int64_t length) {
+    const at::Tensor& self,
+    int64_t dim,
+    c10::SymInt start,
+    c10::SymInt length) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      narrow_copy)>::call(self, dim, start, length);
}
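Aside: start and length move from int64_t to c10::SymInt, which wraps either a concrete integer or a symbolic placeholder. A minimal illustration, assuming the SymInt API of this era:

c10::SymInt start{1};  // constructed from a plain integer
c10::SymInt length{2};
// expect_int() asserts the value is concrete and returns it as int64_t.
int64_t concrete_length = length.expect_int();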
