From 4ac255ae4db09ed321a64c8f03402a2019ce4669 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 28 Nov 2023 16:23:31 -0300 Subject: [PATCH] Fix `as_strided` for inputs smaller than the arguments specification. (#5914) * Add test. * Create `base_` tensor for views. * Use base tensor in `as_strided` operation. * Set base tensor of `as_strided`. * Fix lint errors. * Fix for disabled functionalization. * Address review. --- test/test_operations.py | 10 ++++++++++ torch_xla/csrc/aten_xla_bridge.cpp | 16 ++++++++++++++++ torch_xla/csrc/aten_xla_bridge.h | 7 +++++++ torch_xla/csrc/aten_xla_type.cpp | 30 ++++++++++++++++++++++-------- torch_xla/csrc/tensor.cpp | 7 ++++--- torch_xla/csrc/tensor.h | 8 ++++++++ 6 files changed, 67 insertions(+), 11 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index 4e8ebcedee4..fbac40dbba9 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2296,6 +2296,16 @@ def from_tensors(self, tensors): self.assertEqual(xdata.batch_sizes.device, torch.device('cpu')) self.assertEqual(xdata.data.device, xla_device) + def test_as_strided_input_larger(self): + size = (5, 5) + device = xm.xla_device() + + a = torch.ones(size, device=device) + small_a = a[:, ::2] + former_a = small_a.as_strided(size, (5, 1), 0) + + self.assertEqual(a, former_a) + if __name__ == '__main__': torch.set_default_dtype(torch.float32) diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp index 009a3b84550..e30e83ee3d1 100644 --- a/torch_xla/csrc/aten_xla_bridge.cpp +++ b/torch_xla/csrc/aten_xla_bridge.cpp @@ -440,5 +440,21 @@ std::vector CreateXlaTensors( return xtensors; } +const at::Tensor& GetRootBase(const at::Tensor& tensor) { + auto xla_tensor = TryGetXlaTensor(tensor); + if (xla_tensor && xla_tensor->Base().defined()) { + return GetRootBase(xla_tensor->Base()); + } else { + return tensor; + } +} + +XLATensorPtr SetBaseTensor(XLATensorPtr tensor, const at::Tensor& base) { + XLA_CHECK(base.device().is_xla()) + << "base tensor on unexpected device: " << base.device(); + tensor->SetBase(GetRootBase(base)); + return tensor; +} + } // namespace bridge } // namespace torch_xla diff --git a/torch_xla/csrc/aten_xla_bridge.h b/torch_xla/csrc/aten_xla_bridge.h index bfe4e9e614d..7d6188809c0 100644 --- a/torch_xla/csrc/aten_xla_bridge.h +++ b/torch_xla/csrc/aten_xla_bridge.h @@ -148,6 +148,13 @@ auto TupleAtenFromXlaTensors(const std::vector& tensors) { return TupleAtenFromXlaTensorsImpl(tensors, std::make_index_sequence{}); } +// Returns the deepest base tensor for a given tensor. +// If the base tensor is not defined, returns the tensor itself. +const at::Tensor& GetRootBase(const at::Tensor& tensor); +// Sets the base tensor of a given XLATensor. Convenient function +// to be used when returning tensors. +XLATensorPtr SetBaseTensor(XLATensorPtr tensor, const at::Tensor& base); + } // namespace bridge } // namespace torch_xla diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 466645af7c0..80a7225d735 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -694,7 +694,12 @@ at::Tensor XLANativeFunctions::as_strided_copy( const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional storage_offset) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + // Retrieve the base tensor, if there's one. + // This function actually operates on the tensor's storage. Since XLA does not + // expose the actual storage, we use the originally allocated tensor. + const at::Tensor& base = bridge::GetXlaTensor(self)->Base(); + const at::Tensor& tensor = base.defined() ? base : self; + XLATensorPtr self_tensor = bridge::GetXlaTensor(tensor); auto xsize = XlaHelpers::I64List(size); auto xstride = XlaHelpers::I64List(stride); if (!AsStrided::StrideIsSupported(self_tensor->shape(), xsize, xstride, @@ -703,9 +708,14 @@ at::Tensor XLANativeFunctions::as_strided_copy( &xla_cpu_fallback, ATEN_OP(as_strided)>::call(self, size, stride, storage_offset); } - return bridge::AtenFromXlaTensor(tensor_methods::as_strided( - self_tensor, std::move(xsize), std::move(xstride), - XlaHelpers::I64Optional(storage_offset))); + // Sets the base tensor as tensor. + // Even though this function copies (without aliasing) tensor, it's still + // treated as a view function in the functionalization layer. + return bridge::AtenFromXlaTensor(bridge::SetBaseTensor( + tensor_methods::as_strided(self_tensor, std::move(xsize), + std::move(xstride), + XlaHelpers::I64Optional(storage_offset)), + tensor)); } at::Tensor XLANativeFunctions::as_strided_scatter( @@ -2791,8 +2801,10 @@ at::Tensor XLANativeFunctions::slice_copy(const at::Tensor& self, int64_t dim, TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); int64_t start_val = start.has_value() ? start.value() : 0; int64_t end_val = end.has_value() ? end.value() : INT64_MAX; - return bridge::AtenFromXlaTensor(tensor_methods::slice( - bridge::GetXlaTensor(self), dim, start_val, end_val, step)); + return bridge::AtenFromXlaTensor(bridge::SetBaseTensor( + tensor_methods::slice(bridge::GetXlaTensor(self), dim, start_val, end_val, + step), + self)); } at::Tensor XLANativeFunctions::slice_scatter( @@ -3724,13 +3736,15 @@ at::Tensor XLANativeFunctions::as_strided( const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional storage_offset) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + const auto& base = bridge::GetXlaTensor(self)->Base(); + const auto& tensor = base.defined() ? base : self; + XLATensorPtr self_tensor = bridge::GetXlaTensor(tensor); auto xsize = XlaHelpers::I64List(size); auto xstride = XlaHelpers::I64List(stride); if (!AsStrided::StrideIsSupported(self_tensor->shape(), xsize, xstride, storage_offset.value_or(0))) { return at::native::call_fallback_fn< - &xla_cpu_fallback, ATEN_OP(as_strided)>::call(self, size, stride, + &xla_cpu_fallback, ATEN_OP(as_strided)>::call(tensor, size, stride, storage_offset); } return bridge::AtenFromXlaTensor(tensor_methods::as_strided( diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 4a97aad68b7..7fdd998174c 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -139,9 +139,10 @@ XLATensor::XLATensor(std::shared_ptr view, XLATensor::XLATensor(std::shared_ptr data) : torch::lazy::LazyTensor(data), data_(std::move(data)), - storage_(c10::Storage({}, 0, - c10::DataPtr(nullptr, bridge::XlaDeviceToAtenDevice( - data_->device)))) {} + storage_(c10::Storage( + {}, 0, + c10::DataPtr(nullptr, bridge::XlaDeviceToAtenDevice(data_->device)))), + base_() {} auto XLATensor::data() const -> const std::shared_ptr& { XLA_CHECK(data_ != nullptr) << "Trying to access a null cursor"; diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 83db2e95df6..e0c654ee44a 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -277,6 +277,9 @@ class XLATensor : public torch::lazy::LazyTensor { void SetStorage(const c10::Storage& storage) { storage_ = storage; } const c10::Storage& Storage() const { return storage_; } + void SetBase(const at::Tensor& base) { base_ = base; } + const at::Tensor& Base() const { return base_; } + int64_t GetHandle() const; // Override to enable SPMD. @@ -337,6 +340,11 @@ class XLATensor : public torch::lazy::LazyTensor { // points to the same storage, and thus alias of each other. // FIXME(alanwaketan): Remove this once we have functionalization (bdhirsh). c10::Storage storage_; + // Base tensor for view and view_copy operations. This is used mainly for + // operations such as as_strided, which operates on the allocated storage. + // Since XLATensor doesn't actually expose the storage, we have to run the + // operation on the originally created tensor. + at::Tensor base_; }; } // namespace torch_xla