unsqueeze returns wrong results when input is a sliced tensor with non-trivial offset #7859
Comments
@tengyifei, do you want to take a look at this one?
OK, I can repro the issue with
@ysiraichi, I know what the issue is: when we run with
Note that in the
Yes. That said, this is a known problem with

```python
>>> import torch
>>> import torch_xla.core.xla_model as xm
>>> dev = xm.xla_device()
>>> x = torch.arange(1, 6, device=dev)
# x = [1, 2, 3, 4, 5]
>>> a = x[2:]
# a = [3, 4, 5]
>>> a.as_strided((2,), (1,))
# Expected: [3, 4]
# Output: [1, 2]
```

The specific cause of this problem is that the

This issue is complementary to the failures you are facing in your PR:

```python
>>> x = torch.arange(1, 6, device=dev)
# x = [1, 2, 3, 4, 5]
>>> a = x[2:]
# a = [3, 4, 5]
>>> a.as_strided((4,), (1,), storage_offset=0)
# Expected: [1, 2, 3, 4]
# Output: Error: we want 4 elements, but 'a' only has 3.
```

In the end, we will face at least one of these problems, unless we keep track of the
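As a rough illustration of the two problems above, here is a toy, pure-Python model of strided views. It is not PyTorch/XLA code: `View`, `slice_dim0`, `as_strided`, `as_strided_drop_offset` and `as_strided_materialized` are hypothetical names. It assumes, as in eager PyTorch, that `x[2:]` shares `x`'s storage with `storage_offset=2`, and that `as_strided` with `storage_offset=None` keeps the input's own offset. Keeping both the base storage and the offset gives the expected result for both snippets; dropping the offset reproduces the first failure, and dropping the base storage reproduces the second.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple


def contiguous_strides(size: Sequence[int]) -> Tuple[int, ...]:
    # Row-major strides for a freshly allocated tensor of this size.
    strides, acc = [], 1
    for n in reversed(size):
        strides.append(acc)
        acc *= n
    return tuple(reversed(strides))


def _flatten(item):
    # Turn the nested lists produced by View.tolist() into a flat list.
    return [y for x in item for y in _flatten(x)] if isinstance(item, list) else [item]


@dataclass
class View:
    """Toy strided view: shared flat storage plus size/stride/offset metadata."""
    storage: List[int]
    size: Tuple[int, ...]
    stride: Tuple[int, ...]
    storage_offset: int = 0

    def tolist(self):
        # Element (i_0, ..., i_k) lives at storage_offset + sum(i_d * stride_d).
        def build(dim: int, pos: int):
            if dim == len(self.size):
                return self.storage[pos]
            return [build(dim + 1, pos + i * self.stride[dim])
                    for i in range(self.size[dim])]
        return build(0, self.storage_offset)


def slice_dim0(v: View, start: int) -> View:
    # v[start:] shares storage; only the size and the offset change.
    return View(v.storage, (v.size[0] - start,) + v.size[1:], v.stride,
                v.storage_offset + start * v.stride[0])


def as_strided(v: View, size, stride, storage_offset: Optional[int] = None) -> View:
    # Keep the base storage and, by default, the input view's own offset.
    offset = v.storage_offset if storage_offset is None else storage_offset
    if all(n > 0 for n in size):
        last = offset + sum((n - 1) * s for n, s in zip(size, stride))
        if last >= len(v.storage):
            raise IndexError(f"needs storage index {last}, storage has {len(v.storage)}")
    return View(v.storage, tuple(size), tuple(stride), offset)


def as_strided_drop_offset(v: View, size, stride, storage_offset=None) -> View:
    # Failure mode 1: the base storage is reused but the view's offset is forgotten.
    return as_strided(View(v.storage, v.size, v.stride, 0), size, stride, storage_offset)


def as_strided_materialized(v: View, size, stride, storage_offset=None) -> View:
    # Failure mode 2: the view is copied into its own storage first, so the
    # elements of the base tensor outside the slice can no longer be reached.
    own = View(_flatten(v.tolist()), v.size, contiguous_strides(v.size), 0)
    return as_strided(own, size, stride, storage_offset)


x = View([1, 2, 3, 4, 5], (5,), (1,))
a = slice_dim0(x, 2)  # a = [3, 4, 5], storage_offset = 2
print(as_strided(a, (2,), (1,)).tolist())                    # [3, 4]
print(as_strided(a, (4,), (1,), storage_offset=0).tolist())  # [1, 2, 3, 4]
print(as_strided_drop_offset(a, (2,), (1,)).tolist())        # [1, 2]  (problem 1)
try:
    as_strided_materialized(a, (4,), (1,), storage_offset=0)  # problem 2
except IndexError as e:
    print("error:", e)
```

Each strategy passes one snippet and fails the other, which matches the observation that at least one of these problems shows up unless both pieces of information are kept.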
🐛 Bug
unsqueeze returns wrong results when input is a sliced tensor with non-trivial offset.
NOTE: This bug only happens when
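A toy model of what a correct `unsqueeze` has to do with such an input may help: it is a view op, so it only inserts a size-1 dimension into the size/stride metadata and must carry the input's `storage_offset` through unchanged. The sketch below is pure Python and assumes a 1-D slice `a = x[2:]` of `x = [1, 2, 3, 4, 5]`; `unsqueeze_geometry` and `read` are hypothetical helpers, not PyTorch or XLA APIs. The stride chosen for the new dimension (`size[dim] * stride[dim]`, or 1 when appending) is an assumption modelled on eager PyTorch; it does not change which elements are read, because that dimension has size 1.

```python
from typing import List, Sequence, Tuple

Geometry = Tuple[Tuple[int, ...], Tuple[int, ...], int]  # (size, stride, storage_offset)


def unsqueeze_geometry(size: Sequence[int], stride: Sequence[int],
                       storage_offset: int, dim: int) -> Geometry:
    # Insert a size-1 dimension at `dim`; the storage and offset are untouched.
    ndim = len(size)
    if dim < 0:
        dim += ndim + 1
    new_stride = 1 if dim >= ndim else size[dim] * stride[dim]
    new_size = tuple(size[:dim]) + (1,) + tuple(size[dim:])
    new_strides = tuple(stride[:dim]) + (new_stride,) + tuple(stride[dim:])
    return new_size, new_strides, storage_offset


def read(storage: List[int], geom: Geometry):
    # Materialize a view as nested lists: index -> offset + sum(i_d * stride_d).
    size, stride, offset = geom

    def build(dim: int, pos: int):
        if dim == len(size):
            return storage[pos]
        return [build(dim + 1, pos + i * stride[dim]) for i in range(size[dim])]

    return build(0, offset)


storage = [1, 2, 3, 4, 5]   # x
a = ((3,), (1,), 2)         # a = x[2:] -> size 3, stride 1, offset 2
print(read(storage, a))                              # [3, 4, 5]
print(read(storage, unsqueeze_geometry(*a, dim=0)))  # [[3, 4, 5]]
print(read(storage, unsqueeze_geometry(*a, dim=1)))  # [[3], [4], [5]]
# Hypothetical failure mode (not the confirmed cause here): ignoring the
# offset reads from the start of x instead of from a.
size, stride, _ = unsqueeze_geometry(*a, dim=0)
print(read(storage, (size, stride, 0)))              # [[1, 2, 3]]
```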
To Reproduce
Run this from the command line:
```shell
export XLA_DISABLE_FUNCTIONALIZATION=1
```
Run this code:
The actual result is:
Expected behavior
This is the expected result:
Environment
Additional context
The bug is not reproducible with XLA_DISABLE_FUNCTIONALIZATION=0.