Dynamo (openxla) fails when returning tensor.expand. #5837

Closed
ysiraichi opened this issue Nov 21, 2023 · 5 comments

@ysiraichi
Collaborator

🐛 Bug

In the following code, `foo` tries to expand the input tensor `x`, of shape [3], into a tensor of shape [2, 3].

```python
import torch
import torch_xla.core.xla_model as xm

@torch.compile(backend="openxla")
def foo(x):
    # [3] -> [2, 3]: prepend a new leading dimension of size 2.
    return x.expand(2, *x.shape)

x = torch.arange(3, device=xm.xla_device())
# Printing copies the result to CPU, which triggers the error below.
print(foo(x))
```

```
Traceback (most recent call last):
  File "examples/bug-expand.py", line 15, in <module>
    print(foo(x))
  File "torch/_tensor.py", line 442, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
  File "torch/_tensor_str.py", line 664, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
  File "torch/_tensor_str.py", line 430, in _str_intern
    self = self.to("cpu")
RuntimeError: Error while lowering: [] aten::as_strided, xla_shape=s64[2,3]{1,0}, size=(2, 3), stride=(3, 1), storage_offset=0
Error: torch_xla/csrc/ops/as_strided.cpp:26 : Check failed: storage_offset + slice_size <= input_element_count (6 vs. 3)
```
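The `(6 vs. 3)` in the lowering error is plain element-count arithmetic. The sketch below reproduces it, assuming (as the reported numbers suggest) that `slice_size` in `as_strided.cpp` is the product of the requested sizes and that strides are not part of this check. `as_strided_slice_check` is a hypothetical name, not a torch_xla function.

```python
from math import prod


def as_strided_slice_check(input_numel, size, storage_offset):
    """Hypothetical re-creation of the lowering-time check in the error message.

    Assumption: slice_size is the number of elements of the requested view
    (the product of `size`); strides are ignored.
    Returns (passes, lhs, rhs) for the check lhs <= rhs.
    """
    slice_size = prod(size)
    lhs = storage_offset + slice_size
    return lhs <= input_numel, lhs, input_numel


# x = torch.arange(3) has 3 elements; the requested view is [2, 3] at offset 0.
print(as_strided_slice_check(input_numel=3, size=(2, 3), storage_offset=0))
# (False, 6, 3)
```

If that assumption holds, any `as_strided` whose shape has more elements than its input is rejected at this point regardless of strides.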
Full error dump (the Python traceback and error message are identical to the ones above; the native stack trace follows):

```
*** Begin stack trace ***
        tsl::CurrentStackTrace[abi:cxx11]()
        torch_xla::AsStrided::Lower(torch_xla::LoweringContext*) const
        torch_xla::LoweringContext::LowerNode(torch::lazy::Node const*)
        torch_xla::LoweringContext::LoweringContext(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, torch::lazy::BackendDevice, c10::ArrayRef<torch::lazy::Node const*>, std::unordered_map<torch::lazy::Node const*, torch::lazy::Util::EmitStatus, std::hash<torch::lazy::Node const*>, std::equal_to<torch::lazy::Node const*>, std::allocator<std::pair<torch::lazy::Node const* const, torch::lazy::Util::EmitStatus> > >)
        torch_xla::XLAGraphExecutor::Compile(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, absl::lts_20230125::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, torch::lazy::LazyGraphExecutor::SyncTensorCollection const&, torch::lazy::LazyGraphExecutor::PostOrderData*, std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&)
        torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230125::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
        torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230125::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, bool, bool, bool)
        torch_xla::XLATensor::ApplyPendingGraph()
        torch_xla::XLATensor::GetXlaData()
        torch_xla::XLATensor::ToTensor(bool)
        torch_xla::XLANativeFunctions::_to_copy(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)
        at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)
        at::_ops::_to_copy::call(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)
        at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)
        at::_ops::_to_copy::call(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)
        at::native::to(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, bool, c10::optional<c10::MemoryFormat>)
        at::_ops::to_dtype_layout::call(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, bool, c10::optional<c10::MemoryFormat>)
        at::Tensor::to(c10::TensorOptions, bool, bool, c10::optional<c10::MemoryFormat>) const
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        PyObject_Str
        PyFile_WriteObject
        _PyEval_EvalFrameDefault
        PyEval_EvalCode
        _PyRun_SimpleFileObject
        _PyRun_AnyFileObject
        Py_RunMain
        Py_BytesMain
        __libc_start_main
*** End stack trace ***
```

Affected Benchmark

hf_Reformer

Environment

  • PyTorch version: 63fc48257a02f8e28b79d13def7a7139589d4176 (Nov 2)
  • PyTorch/XLA version: d5d023063bfa8ecb4629f621f9b5890bc8396f58 (Nov 9)

Additional context

At first, this issue hits the same error as #5719. However, after surfacing the error at tracing time (instead of at lowering time), I get the following error:

```
Traceback (most recent call last):
  File "examples/bug-expand.py", line 15, in <module>
    print(foo(x))
  File "torch/_dynamo/eval_frame.py", line 410, in _fn
    return fn(*args, **kwargs)
  File "examples/bug-expand.py", line 9, in foo
    @torch.compile(backend="openxla")
  File "torch/_dynamo/eval_frame.py", line 410, in _fn
    return fn(*args, **kwargs)
  File "torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "torch/_functorch/aot_autograd.py", line 4832, in forward
    return compiled_fn(full_args)
  File "torch/_functorch/aot_autograd.py", line 1948, in g
    return f(*args)
  File "torch/_functorch/aot_autograd.py", line 3154, in runtime_wrapper
    regenerated_out = gen_alias_from_base(aliased_base_tensor, o_, o_grad)
  File "torch/_functorch/aot_autograd.py", line 842, in gen_alias_from_base
    aliased_out = aliased_base_tensor.as_strided(size, stride, storage_offset)
RuntimeError: setStorage: sizes [2, 3], strides [3, 1], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 24
```

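This looks really odd. I think AOTAutograd believes the output is a view of the input, so it tries to regenerate it by calling `as_strided` on the input (its base) with the output's sizes, strides, and storage offset. In reality, however, the result is not a view, because we call `expand_copy_symint`. A genuine `expand` view of `x` would have strides (0, 1), which fit in the 3-element storage; the strides (3, 1) used here are those of a contiguous [2, 3] tensor, which needs 6 elements (48 bytes instead of the 24 available).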
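The eager-side numbers (48 vs. 24 bytes) can be checked with a toy model of strided views over a flat storage. The helper names are hypothetical: `required_storage_elems` uses the usual strided-layout bound (offset + 1 + sum((size - 1) * stride) elements, or 0 for an empty view), and `as_strided_2d` is a pure-Python, 2-D-only stand-in for `Tensor.as_strided`, not PyTorch's implementation.

```python
ITEMSIZE = 8  # bytes per int64 element, as in the error message


def required_storage_elems(size, stride, storage_offset=0):
    """Elements a strided view needs: offset + 1 + sum((s - 1) * st); 0 if empty."""
    if any(s == 0 for s in size):
        return 0
    return storage_offset + 1 + sum((s - 1) * st for s, st in zip(size, stride))


def as_strided_2d(storage, size, stride, storage_offset=0):
    """Toy, 2-D-only as_strided over a flat list; raises if it would read out of bounds."""
    need = required_storage_elems(size, stride, storage_offset)
    if need > len(storage):
        raise IndexError(f"view needs {need} elements, storage has {len(storage)}")
    rows, cols = size
    return [
        [storage[storage_offset + i * stride[0] + j * stride[1]] for j in range(cols)]
        for i in range(rows)
    ]


base = [0, 1, 2]  # storage of x = torch.arange(3)

# Strides from the error: those of a contiguous [2, 3] tensor.
print(required_storage_elems((2, 3), (3, 1)) * ITEMSIZE, len(base) * ITEMSIZE)  # 48 24

# Strides of a genuine expand view: stride 0 along the new leading dimension.
print(as_strided_2d(base, (2, 3), (0, 1)))  # [[0, 1, 2], [0, 1, 2]]

try:
    as_strided_2d(base, (2, 3), (3, 1))
except IndexError as err:
    print(err)  # view needs 6 elements, storage has 3
```

In this model, regenerating an output from its base only works when the recorded strides describe a real view of that base; strides taken from a materialized copy generally do not.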

@ysiraichi
Collaborator Author

cc @JackCaoG @miladm

@JackCaoG
Collaborator

Seems like another storage/stride issue?

@ysiraichi
Collaborator Author

I don't think this is the case. I think the problem here is that, for some reason, AOTAutograd expects the output to be a view of the input. However, instead of calling expand, XLA calls expand_copy_symint, which is not a view.

@JackCaoG
Collaborator

Oh... I think `expand` is one of those ops that are supposed to be views, but we didn't implement it as a view. Can you check if that's the case?

@ysiraichi
Collaborator Author

Not sure. As far as I can see, since `XLA_DISABLE_FUNCTIONALIZATION` is not set, aren't all operations (including the view ones) implemented as non-view ops? One thing I don't quite understand is: how are views implemented in XLA?
