Lowering as_strided errors for input tensors smaller than size-stride specs. #5719

Closed
ysiraichi opened this issue Oct 21, 2023 · 14 comments · Fixed by #5914

@ysiraichi
Collaborator

🐛 Bug

The following usage of as_strided raises an error during lowering.

import torch
import torch_xla.core.xla_model as xm

x = torch.randn(20, device=xm.xla_device())
y = x[10:]
z = y.as_strided((20,), (1,), 0)
print(z)
Traceback (most recent call last):
  File "bug-as-strided.py", line 10, in <module>
    print(z)
  File "torch/_tensor.py", line 442, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
  File "torch/_tensor_str.py", line 664, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
  File "torch/_tensor_str.py", line 430, in _str_intern
    self = self.to("cpu")
RuntimeError: Error while lowering: [] aten::as_strided, xla_shape=f32[20]{0}, size=(20), stride=(1), storage_offset=0
Error: torch_xla/csrc/ops/as_strided.cpp:33 : Check failed: storage_offset + slice_size <= input_element_count (20 vs. 10)
*** Begin stack trace ***
        tsl::CurrentStackTrace[abi:cxx11]()

        torch_xla::AsStrided::Lower(torch_xla::LoweringContext*) const
        torch_xla::LoweringContext::LowerNode(torch::lazy::Node const*)
        torch_xla::LoweringContext::LoweringContext(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, torch::lazy::BackendDevice, c10::ArrayRef<torch::lazy::Node const*>, std::unordered_map<torch::lazy::Node const*, torch::lazy::Util::EmitStatus, std::hash<torch::lazy::Node const*>, std::equal_to<torch::lazy::Node const*>, std::allocator<std::pair<torch::lazy::Node const* const, torch::lazy::Util::EmitStatus> > >)

This error shows up when trying to run hf_Reformer from Torchbench with openxla as the backend. As far as I understand, the problem is that AOTAutograd is calling as_strided -- I'm not entirely sure why. This problem seems to be related to the limitations of reshape functions in XLA, as suggested in #2964.

Expected behavior

I would expect it to break earlier, say, when XLANativeFunctions::as_strided is executed, instead of when it gets to the lowering step. Or, maybe better than that, we could fall back to CPU while issuing a warning that "as_strided is creating a copy, which may not be optimal...".
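
For comparison, a minimal sketch (not part of the original report): the same program in eager CPU mode runs fine, because as_strided there is resolved against the underlying 20-element storage rather than against y's 10-element view. This is the behavior a correct fallback would need to reproduce.

import torch

x = torch.randn(20)
y = x[10:]                          # view into x's 20-element storage, offset 10
z = y.as_strided((20,), (1,), 0)    # re-views the whole storage from offset 0
print(torch.equal(z, x))            # True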

Environment

PyTorch/XLA: c9a1324 (Oct 3)

@ysiraichi
Collaborator Author

@JackCaoG any thoughts?

@JackCaoG
Collaborator

Right, I took a look at the logic at

XLA_CHECK_LE(storage_offset + slice_size, input_element_count);

We should be able to check this at a higher level. Can we check where this as_strided node is being created? I suspect it is in

at::Tensor XLANativeFunctions::as_strided_copy(
    const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride,
    c10::optional<int64_t> storage_offset) {
  TORCH_LAZY_FN_COUNTER("xla::");
  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
  auto xsize = XlaHelpers::I64List(size);
  auto xstride = XlaHelpers::I64List(stride);
  if (!AsStrided::StrideIsSupported(self_tensor->shape(), xsize, xstride,
                                    storage_offset.value_or(0))) {
    return at::native::call_fallback_fn<
        &xla_cpu_fallback, ATEN_OP(as_strided)>::call(self, size, stride,
                                                      storage_offset);
  }
  return bridge::AtenFromXlaTensor(tensor_methods::as_strided(
      self_tensor, std::move(xsize), std::move(xstride),
      XlaHelpers::I64Optional(storage_offset)));
}

Maybe we just need to expand AsStrided::StrideIsSupported so that it falls back here.
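
A rough Python illustration (hypothetical, not the actual torch_xla code) of the kind of bounds condition the lowering check enforces and that StrideIsSupported could test up front:

# Hypothetical sketch: the largest linear index an as_strided view with these
# specs would touch must stay within the input's element count; otherwise the
# op cannot be lowered in place and should fall back.
def fits_in_input(input_numel, size, stride, storage_offset=0):
    if any(s == 0 for s in size):
        return True  # an empty view touches no storage
    last_index = storage_offset + sum((s - 1) * st for s, st in zip(size, stride))
    return last_index < input_numel

# The repro: y has 10 elements, but the requested view needs indices 0..19.
print(fits_in_input(10, (20,), (1,), 0))  # False -> should fall back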

@ysiraichi
Collaborator Author

It's exactly there!

@JackCaoG
Collaborator

ok, let's make StrideIsSupported catch this case so we fall back early then.

@ysiraichi
Collaborator Author

@JackCaoG I don't think this will work. The fallback will still fail because (in the example) only y (not x) will be materialized on CPU. i.e. the storage of y will not contain as many bytes as x's storage.

@ysiraichi
Collaborator Author

I mean: it does surface the error in the graph-building phase (instead of at lowering). But the code still doesn't work, for the reason mentioned above. Do you still think this solution should be merged?

@JackCaoG
Collaborator

That's weird... I thought that during fallback we move all XLA tensors to CPU, perform the operation on CPU, and then move the result back to XLA. Is that not the case?

@ysiraichi
Collaborator Author

As far as I understand, if we execute tensor.as_strided(...) we will move tensor to CPU, correct? If so:

# x is created with capacity for 20 elements
x = torch.randn(20, device=xm.xla_device())

# y has shape (10,), but its storage is the same as x
y = x[10:]

# y is moved to CPU
# y has shape (10,), so a CPU tensor of shape (10,) will be created
#     - i.e. the CPU tensor will have a storage capacity for 10 elements
#     - can't view it as (20,) with as_strided, since the storage doesn't have this capacity
z = y.as_strided((20,), (1,), 0)
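
A CPU-only sketch of that failure mode (the 10-element stand-in below is hypothetical, mimicking what the fallback would materialize for y):

import torch

# Stand-in for what the fallback materializes: only y's 10 elements,
# detached from x's original 20-element storage.
y_cpu = torch.randn(10)

# Requesting a 20-element view over a 10-element storage fails on CPU too:
# it raises a RuntimeError because the view is out of bounds for the storage.
z = y_cpu.as_strided((20,), (1,), 0)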

@JackCaoG
Collaborator

oh ok, I see the problem... I guess it is one of those cases that's very hard to implement correctly. @bdhirsh Given that functionalization now lives upstream, in the case of

x = torch.randn(20, device=xm.xla_device())
y = x[10:]
z = y.as_strided((20,), (1,), 0)

can we just expand y from [10] to [20]? I am a bit confused about what the expected behavior of this as_strided_copy is after functionalization..

@lezcano
Collaborator

lezcano commented Nov 6, 2023

In your example, you are explicitly asking for an offset of 0 in the as_strided call, so that's exactly what you are getting.

A good way to think about as_strided is that it acts not on the tensor, but on the underlying storage. In this case, the underlying storage of y allows for a size of (20,), a stride of (1,) and an offset of 0 (that's exactly x).

Since as_strided works on the storage, and not on the tensor itself, we would need to move the base tensor (or the whole storage) to CPU. After moving the storage to CPU, we also need the offset of the initial tensor, as this is the only other piece of data used to implement as_strided.
https://github.com/pytorch/pytorch/blob/2bc1378d7be563fa9b3050bb0e0fefd6e55a9e81/aten/src/ATen/native/TensorShape.cpp#L1165-L1172
In many cases we may be able to prove statically that we don't need to copy the whole thing, but yeah.
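
A small eager-CPU illustration of that storage-plus-offset view (a sketch, not part of the original comment):

import torch

x = torch.randn(20)
y = x[10:]

# y is fully described by (x's storage, offset 10, size (10,), stride (1,)).
assert y.untyped_storage().data_ptr() == x.untyped_storage().data_ptr()
assert y.storage_offset() == 10
assert y.shape == (10,) and y.stride() == (1,)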

Now, this is the sort of thing that will surely have plenty of edge cases when we mix a few views with a few in-place ops... Surely @bdhirsh will have a better understanding of all these.

@ysiraichi
Collaborator Author

I feel like this is a more complex version of falling back to CPU for the whole computation of z. Maybe we could introduce a pass for identifying this sub-graph. What do you think?

@ysiraichi
Collaborator Author

Nevermind. I think I finally understood what you were saying. That said, I don't think we even need to move things to the CPU. We only need to apply the operation on the base tensor.
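
An eager-CPU sketch of that idea (hypothetical helper; whether the base/view relationship is still visible for XLA tensors after functionalization is part of the question):

import torch

def as_strided_via_base(t, size, stride, storage_offset=0):
    # Hypothetical helper: if t is a view, redo the op on its base so the
    # full storage is available; an explicit as_strided offset is
    # storage-relative, so the result is the same as calling it on any
    # tensor sharing that storage.
    base = t._base if t._base is not None else t
    return base.as_strided(size, stride, storage_offset)

x = torch.randn(20)
y = x[10:]
z = as_strided_via_base(y, (20,), (1,), 0)
assert torch.equal(z, x)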

@bdhirsh
Collaborator

bdhirsh commented Nov 20, 2023

We have a tentative PoR (plan of record) to get AOTAutograd to stop emitting as_strided() in all cases (which I'm hoping to get around to early next year). There isn't a detailed design doc yet (I'll have one before working on it), but the high-level idea is written down here with Ed: https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit
