Skip to content

Commit

Permalink
Fix as_strided for inputs smaller than the argument specification. (
Browse files Browse the repository at this point in the history
…#5914)

* Add test.

* Create `base_` tensor for views.

* Use base tensor in `as_strided` operation.

* Set base tensor of `as_strided`.

* Fix lint errors.

* Fix for disabled functionalization.

* Address review.
  • Loading branch information
ysiraichi authored Nov 28, 2023
1 parent c03afb1 commit 4ac255a
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 11 deletions.
10 changes: 10 additions & 0 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2296,6 +2296,16 @@ def from_tensors(self, tensors):
self.assertEqual(xdata.batch_sizes.device, torch.device('cpu'))
self.assertEqual(xdata.data.device, xla_device)

def test_as_strided_input_larger(self):
  """as_strided on a strided view must operate on the base storage.

  `strided_view` only covers half the columns of `original`; calling
  as_strided with the full size/strides should still recover the whole
  original tensor, because as_strided addresses the allocated storage,
  not the view.
  """
  full_size = (5, 5)
  xla_dev = xm.xla_device()

  original = torch.ones(full_size, device=xla_dev)
  strided_view = original[:, ::2]
  reconstructed = strided_view.as_strided(full_size, (5, 1), 0)

  self.assertEqual(original, reconstructed)


if __name__ == '__main__':
torch.set_default_dtype(torch.float32)
Expand Down
16 changes: 16 additions & 0 deletions torch_xla/csrc/aten_xla_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,5 +440,21 @@ std::vector<at::Tensor> CreateXlaTensors(
return xtensors;
}

// Follows the chain of base tensors starting at `tensor` and returns the
// deepest one. A tensor that is not an XLA tensor, or whose XLATensor has
// no defined base, is its own root and is returned unchanged.
const at::Tensor& GetRootBase(const at::Tensor& tensor) {
  auto xla_tensor = TryGetXlaTensor(tensor);
  if (!xla_tensor || !xla_tensor->Base().defined()) {
    return tensor;
  }
  return GetRootBase(xla_tensor->Base());
}

// Attaches the root of `base`'s base-tensor chain to `tensor`, then hands
// `tensor` back so the call can be nested inside a return expression.
// `base` must already live on an XLA device.
XLATensorPtr SetBaseTensor(XLATensorPtr tensor, const at::Tensor& base) {
  XLA_CHECK(base.device().is_xla())
      << "base tensor on unexpected device: " << base.device();
  const at::Tensor& root = GetRootBase(base);
  tensor->SetBase(root);
  return tensor;
}

} // namespace bridge
} // namespace torch_xla
7 changes: 7 additions & 0 deletions torch_xla/csrc/aten_xla_bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ auto TupleAtenFromXlaTensors(const std::vector<XLATensorPtr>& tensors) {
return TupleAtenFromXlaTensorsImpl(tensors, std::make_index_sequence<N>{});
}

// Returns the deepest base tensor for a given tensor, following the
// chain of bases recursively.
// If the tensor has no defined base, returns the tensor itself.
const at::Tensor& GetRootBase(const at::Tensor& tensor);
// Sets the base tensor of a given XLATensor to the root base of `base`.
// Convenience function to be used when returning tensors, so the result
// can be wrapped in a single return expression.
XLATensorPtr SetBaseTensor(XLATensorPtr tensor, const at::Tensor& base);

} // namespace bridge
} // namespace torch_xla

Expand Down
30 changes: 22 additions & 8 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,12 @@ at::Tensor XLANativeFunctions::as_strided_copy(
const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride,
c10::optional<int64_t> storage_offset) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
// Retrieve the base tensor, if there's one.
// This function actually operates on the tensor's storage. Since XLA does not
// expose the actual storage, we use the originally allocated tensor.
const at::Tensor& base = bridge::GetXlaTensor(self)->Base();
const at::Tensor& tensor = base.defined() ? base : self;
XLATensorPtr self_tensor = bridge::GetXlaTensor(tensor);
auto xsize = XlaHelpers::I64List(size);
auto xstride = XlaHelpers::I64List(stride);
if (!AsStrided::StrideIsSupported(self_tensor->shape(), xsize, xstride,
Expand All @@ -703,9 +708,14 @@ at::Tensor XLANativeFunctions::as_strided_copy(
&xla_cpu_fallback, ATEN_OP(as_strided)>::call(self, size, stride,
storage_offset);
}
return bridge::AtenFromXlaTensor(tensor_methods::as_strided(
self_tensor, std::move(xsize), std::move(xstride),
XlaHelpers::I64Optional(storage_offset)));
// Sets the base tensor as tensor.
// Even though this function copies (without aliasing) tensor, it's still
// treated as a view function in the functionalization layer.
return bridge::AtenFromXlaTensor(bridge::SetBaseTensor(
tensor_methods::as_strided(self_tensor, std::move(xsize),
std::move(xstride),
XlaHelpers::I64Optional(storage_offset)),
tensor));
}

at::Tensor XLANativeFunctions::as_strided_scatter(
Expand Down Expand Up @@ -2791,8 +2801,10 @@ at::Tensor XLANativeFunctions::slice_copy(const at::Tensor& self, int64_t dim,
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
int64_t start_val = start.has_value() ? start.value() : 0;
int64_t end_val = end.has_value() ? end.value() : INT64_MAX;
return bridge::AtenFromXlaTensor(tensor_methods::slice(
bridge::GetXlaTensor(self), dim, start_val, end_val, step));
return bridge::AtenFromXlaTensor(bridge::SetBaseTensor(
tensor_methods::slice(bridge::GetXlaTensor(self), dim, start_val, end_val,
step),
self));
}

at::Tensor XLANativeFunctions::slice_scatter(
Expand Down Expand Up @@ -3724,13 +3736,15 @@ at::Tensor XLANativeFunctions::as_strided(
const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride,
c10::optional<int64_t> storage_offset) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
const auto& base = bridge::GetXlaTensor(self)->Base();
const auto& tensor = base.defined() ? base : self;
XLATensorPtr self_tensor = bridge::GetXlaTensor(tensor);
auto xsize = XlaHelpers::I64List(size);
auto xstride = XlaHelpers::I64List(stride);
if (!AsStrided::StrideIsSupported(self_tensor->shape(), xsize, xstride,
storage_offset.value_or(0))) {
return at::native::call_fallback_fn<
&xla_cpu_fallback, ATEN_OP(as_strided)>::call(self, size, stride,
&xla_cpu_fallback, ATEN_OP(as_strided)>::call(tensor, size, stride,
storage_offset);
}
return bridge::AtenFromXlaTensor(tensor_methods::as_strided(
Expand Down
7 changes: 4 additions & 3 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,10 @@ XLATensor::XLATensor(std::shared_ptr<View> view,
XLATensor::XLATensor(std::shared_ptr<Data> data)
: torch::lazy::LazyTensor(data),
data_(std::move(data)),
storage_(c10::Storage({}, 0,
c10::DataPtr(nullptr, bridge::XlaDeviceToAtenDevice(
data_->device)))) {}
storage_(c10::Storage(
{}, 0,
c10::DataPtr(nullptr, bridge::XlaDeviceToAtenDevice(data_->device)))),
base_() {}

auto XLATensor::data() const -> const std::shared_ptr<Data>& {
XLA_CHECK(data_ != nullptr) << "Trying to access a null cursor";
Expand Down
8 changes: 8 additions & 0 deletions torch_xla/csrc/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,9 @@ class XLATensor : public torch::lazy::LazyTensor {
void SetStorage(const c10::Storage& storage) { storage_ = storage; }
const c10::Storage& Storage() const { return storage_; }

// Records the tensor this one was derived from (see `base_` below).
void SetBase(const at::Tensor& base) { base_ = base; }
// Returns the recorded base tensor; an undefined (default-constructed)
// tensor when no base has been set.
const at::Tensor& Base() const { return base_; }

int64_t GetHandle() const;

// Override to enable SPMD.
Expand Down Expand Up @@ -337,6 +340,11 @@ class XLATensor : public torch::lazy::LazyTensor {
// points to the same storage, and thus alias of each other.
// FIXME(alanwaketan): Remove this once we have functionalization (bdhirsh).
c10::Storage storage_;
// Base tensor for view and view_copy operations. This is used mainly for
// operations such as as_strided, which operates on the allocated storage.
// Since XLATensor doesn't actually expose the storage, we have to run the
// operation on the originally created tensor.
at::Tensor base_;
};

} // namespace torch_xla
Expand Down

0 comments on commit 4ac255a

Please sign in to comment.