Fix as_strided for inputs smaller than the arguments specification. #5914

Merged (7 commits) on Nov 28, 2023
10 changes: 10 additions & 0 deletions test/test_operations.py
@@ -2296,6 +2296,16 @@ def from_tensors(self, tensors):
self.assertEqual(xdata.batch_sizes.device, torch.device('cpu'))
self.assertEqual(xdata.data.device, xla_device)

  def test_as_strided_input_larger(self):
    size = (5, 5)
    device = xm.xla_device()

    a = torch.ones(size, device=device)
    # Take a smaller, non-contiguous view of `a`: size (5, 3), stride (5, 2).
    small_a = a[:, ::2]
    # as_strided is defined against the underlying storage, so requesting the
    # full (5, 5) layout from the smaller view should reconstruct `a`.
    former_a = small_a.as_strided(size, (5, 1), 0)

    self.assertEqual(a, former_a)


if __name__ == '__main__':
torch.set_default_dtype(torch.float32)
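The new test exercises exactly the situation in the PR title: small_a is a (5, 3) non-contiguous view of a (5, 5) tensor, while the as_strided arguments describe the full (5, 5) layout. In eager PyTorch, as_strided is resolved against the underlying storage rather than against the viewing tensor, so the call succeeds and reconstructs a. A minimal CPU-only sketch of that eager behavior (for illustration only; not part of the PR):

import torch

size = (5, 5)
a = torch.ones(size)        # contiguous: stride (5, 1), storage_offset 0
small_a = a[:, ::2]         # smaller view: size (5, 3), stride (5, 2)

# as_strided addresses the storage, so requesting the full (5, 5) layout
# from the smaller view recovers the original tensor.
former_a = small_a.as_strided(size, (5, 1), 0)
assert torch.equal(a, former_a)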
14 changes: 14 additions & 0 deletions torch_xla/csrc/aten_xla_bridge.cpp
@@ -440,5 +440,19 @@ std::vector<at::Tensor> CreateXlaTensors(
return xtensors;
}

const at::Tensor& GetRootBase(const at::Tensor& tensor) {
auto xla_tensor = TryGetXlaTensor(tensor);
if (xla_tensor && xla_tensor->Base().defined()) {
return GetRootBase(xla_tensor->Base());
} else {
return tensor;
}
}

XLATensorPtr SetBaseTensor(XLATensorPtr tensor, const at::Tensor& base) {
tensor->SetBase(GetRootBase(base));
return tensor;
}

} // namespace bridge
} // namespace torch_xla
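GetRootBase follows the chain of recorded bases until it reaches a tensor with no base, i.e. the originally allocated tensor; SetBaseTensor then records that root on a freshly created XLATensor. As a rough analogue of the same walk, here is a Python sketch using PyTorch's internal Tensor._base attribute purely for illustration (the PR itself records the base explicitly on XLATensor rather than relying on _base):

import torch

def root_base(t: torch.Tensor) -> torch.Tensor:
  # Follow the view chain down to the tensor that owns the storage,
  # mirroring what bridge::GetRootBase does with XLATensor::Base().
  while t._base is not None:
    t = t._base
  return t

a = torch.ones(5, 5)
v = a[:, ::2][1:]    # a view of a view
assert root_base(v) is a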
7 changes: 7 additions & 0 deletions torch_xla/csrc/aten_xla_bridge.h
@@ -148,6 +148,13 @@ auto TupleAtenFromXlaTensors(const std::vector<XLATensorPtr>& tensors) {
return TupleAtenFromXlaTensorsImpl(tensors, std::make_index_sequence<N>{});
}

// Returns the deepest base tensor for a given tensor.
// If the base tensor is not defined, returns the tensor itself.
const at::Tensor& GetRootBase(const at::Tensor& tensor);
// Sets the base tensor of a given XLATensor. Convenience function
// to be used when returning tensors.
XLATensorPtr SetBaseTensor(XLATensorPtr tensor, const at::Tensor& base);
Collaborator:
Do we ever expect base to be on a non-XLA device? If not, can we add an explicit check?

Collaborator (Author):
Good question. I don't think so, since we got to an XLA-dispatched kernel. Will add the check.


} // namespace bridge
} // namespace torch_xla

30 changes: 22 additions & 8 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -694,7 +694,12 @@ at::Tensor XLANativeFunctions::as_strided_copy(
const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride,
c10::optional<int64_t> storage_offset) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
// Retrieve the base tensor, if there's one.
// This function actually operates on the tensor's storage. Since XLA does not
// expose the actual storage, we use the originally allocated tensor.
const auto& base = bridge::GetXlaTensor(self)->Base();
const auto& tensor = base.defined() ? base : self;
XLATensorPtr self_tensor = bridge::GetXlaTensor(tensor);
auto xsize = XlaHelpers::I64List(size);
auto xstride = XlaHelpers::I64List(stride);
if (!AsStrided::StrideIsSupported(self_tensor->shape(), xsize, xstride,
@@ -703,9 +708,14 @@
&xla_cpu_fallback, ATEN_OP(as_strided)>::call(self, size, stride,
storage_offset);
}
return bridge::AtenFromXlaTensor(tensor_methods::as_strided(
self_tensor, std::move(xsize), std::move(xstride),
XlaHelpers::I64Optional(storage_offset)));
// Sets the base tensor to `tensor`.
// Even though this function copies (without aliasing) `tensor`, it is still
// treated as a view operation by the functionalization layer.
return bridge::AtenFromXlaTensor(bridge::SetBaseTensor(
tensor_methods::as_strided(self_tensor, std::move(xsize),
std::move(xstride),
XlaHelpers::I64Optional(storage_offset)),
tensor));
}

at::Tensor XLANativeFunctions::as_strided_scatter(
@@ -2791,8 +2801,10 @@ at::Tensor XLANativeFunctions::slice_copy(const at::Tensor& self, int64_t dim,
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
int64_t start_val = start.has_value() ? start.value() : 0;
int64_t end_val = end.has_value() ? end.value() : INT64_MAX;
return bridge::AtenFromXlaTensor(tensor_methods::slice(
bridge::GetXlaTensor(self), dim, start_val, end_val, step));
return bridge::AtenFromXlaTensor(bridge::SetBaseTensor(
tensor_methods::slice(bridge::GetXlaTensor(self), dim, start_val, end_val,
step),
self));
}

at::Tensor XLANativeFunctions::slice_scatter(
@@ -3724,13 +3736,15 @@ at::Tensor XLANativeFunctions::as_strided(
const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride,
c10::optional<int64_t> storage_offset) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
const auto& base = bridge::GetXlaTensor(self)->Base();
const auto& tensor = base.defined() ? base : self;
XLATensorPtr self_tensor = bridge::GetXlaTensor(tensor);
auto xsize = XlaHelpers::I64List(size);
auto xstride = XlaHelpers::I64List(stride);
if (!AsStrided::StrideIsSupported(self_tensor->shape(), xsize, xstride,
storage_offset.value_or(0))) {
return at::native::call_fallback_fn<
&xla_cpu_fallback, ATEN_OP(as_strided)>::call(self, size, stride,
&xla_cpu_fallback, ATEN_OP(as_strided)>::call(tensor, size, stride,
storage_offset);
}
return bridge::AtenFromXlaTensor(tensor_methods::as_strided(
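slice_copy gets the same treatment so that a tensor produced by slicing still remembers its root: when a later as_strided call arrives on the sliced result, the explicit storage_offset has to be interpreted against the original allocation, not against the slice. The eager semantics this mirrors, again as a CPU-only sketch rather than code from the PR:

import torch

a = torch.arange(25.0).reshape(5, 5)   # stride (5, 1), storage_offset 0
b = a[2:]                              # slice view: size (3, 5), storage_offset 10

# The explicit storage_offset of 0 refers to the underlying storage,
# not to b, so the full tensor is reconstructed from the slice.
c = b.as_strided((5, 5), (5, 1), 0)
assert torch.equal(a, c)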
7 changes: 4 additions & 3 deletions torch_xla/csrc/tensor.cpp
@@ -139,9 +139,10 @@ XLATensor::XLATensor(std::shared_ptr<View> view,
XLATensor::XLATensor(std::shared_ptr<Data> data)
: torch::lazy::LazyTensor(data),
data_(std::move(data)),
storage_(c10::Storage({}, 0,
c10::DataPtr(nullptr, bridge::XlaDeviceToAtenDevice(
data_->device)))) {}
storage_(c10::Storage(
{}, 0,
c10::DataPtr(nullptr, bridge::XlaDeviceToAtenDevice(data_->device)))),
base_() {}

auto XLATensor::data() const -> const std::shared_ptr<Data>& {
XLA_CHECK(data_ != nullptr) << "Trying to access a null cursor";
8 changes: 8 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -277,6 +277,9 @@ class XLATensor : public torch::lazy::LazyTensor {
void SetStorage(const c10::Storage& storage) { storage_ = storage; }
const c10::Storage& Storage() const { return storage_; }

void SetBase(const at::Tensor& base) { base_ = base; }
const at::Tensor& Base() const { return base_; }

int64_t GetHandle() const;

// Override to enable SPMD.
@@ -337,6 +340,11 @@
// points to the same storage, and thus alias of each other.
// FIXME(alanwaketan): Remove this once we have functionalization (bdhirsh).
c10::Storage storage_;
// Base tensor for view and view_copy operations. This is used mainly for
// operations such as as_strided, which operate on the allocated storage.
// Since XLATensor doesn't actually use the storage, we have to run the
// operation on the originally created tensor.
at::Tensor base_;
};

} // namespace torch_xla