
unsqueeze returns wrong results when input is a sliced tensor with non-trivial offset #7859

Closed
yoavhacohen opened this issue Aug 15, 2024 · 4 comments · Fixed by #7864

@yoavhacohen

🐛 Bug

unsqueeze returns wrong results when input is a sliced tensor with non-trivial offset.
NOTE: This bug only happens when functionalization is disabled (XLA_DISABLE_FUNCTIONALIZATION=1).

To Reproduce

Run this from the command line:

export XLA_DISABLE_FUNCTIONALIZATION=1

Run this code:

import torch

try:
    # noinspection PyUnresolvedReferences
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.runtime as xr
    from torch_xla.utils import checkpoint as xla_checkpoint
    _torch_xla_available = True
except ImportError:
    _torch_xla_available = False

if _torch_xla_available:
    device = xm.xla_device()
else:
    assert torch.cuda.is_available()
    device = torch.device("cuda")
    
x = torch.randn(size=[2], device=device)
a1, b1 = x.chunk(2)
a2, b2 = x[0:1], x[1:2]
a3, b3 = x[0].unsqueeze(0), x[1].unsqueeze(0)
a4, b4 = x[0, None], x[1, None]

def is_ok(a, b):
    return "OK" if a.squeeze() == x[0] and b.squeeze() == x[1] else "Not OK"

print(f"x: {x}")
print(f"a1, b1: {a1}, {b1}, a1.squeeze(), b1.squeeze(): {a1.squeeze()}, {b1.squeeze()} - {is_ok(a1, b1)}")
print(f"a2, b2: {a2}, {b2}, a2.squeeze(), b2.squeeze(): {a2.squeeze()}, {b2.squeeze()} - {is_ok(a2, b2)}")
print(f"a3, b3: {a3}, {b3}, a3.squeeze(), b3.squeeze(): {a3.squeeze()}, {b3.squeeze()} - {is_ok(a3, b3)}")
print(f"a4, b4: {a4}, {b4}, a4.squeeze(), b4.squeeze(): {a4.squeeze()}, {b4.squeeze()} - {is_ok(a4, b4)}")

The actual result is:

x: tensor([ 1.3556, -0.1277], device='xla:0')
a1, b1: tensor([1.3556], device='xla:0'), tensor([-0.1277], device='xla:0'), a1.squeeze(), b1.squeeze(): 1.3555867671966553, 1.3555867671966553 - Not OK
a2, b2: tensor([1.3556], device='xla:0'), tensor([-0.1277], device='xla:0'), a2.squeeze(), b2.squeeze(): 1.3555867671966553, 1.3555867671966553 - Not OK
a3, b3: tensor([1.3556], device='xla:0'), tensor([-0.1277], device='xla:0'), a3.squeeze(), b3.squeeze(): 1.3555867671966553, -0.12774525582790375 - OK
a4, b4: tensor([1.3556], device='xla:0'), tensor([-0.1277], device='xla:0'), a4.squeeze(), b4.squeeze(): 1.3555867671966553, -0.12774525582790375 - OK

Expected behavior

This is the expected result:

x: tensor([-0.5049, -0.0846], device='cuda:0')
a1, b1: tensor([-0.5049], device='cuda:0'), tensor([-0.0846], device='cuda:0'), a1.squeeze(), b1.squeeze(): -0.5049338340759277, -0.08459508419036865 - OK
a2, b2: tensor([-0.5049], device='cuda:0'), tensor([-0.0846], device='cuda:0'), a2.squeeze(), b2.squeeze(): -0.5049338340759277, -0.08459508419036865 - OK
a3, b3: tensor([-0.5049], device='cuda:0'), tensor([-0.0846], device='cuda:0'), a3.squeeze(), b3.squeeze(): -0.5049338340759277, -0.08459508419036865 - OK
a4, b4: tensor([-0.5049], device='cuda:0'), tensor([-0.0846], device='cuda:0'), a4.squeeze(), b4.squeeze(): -0.5049338340759277, -0.08459508419036865 - OK

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
  • torch_xla version: 2.4 (probably in other versions as well)

Additional context

The bug is not reproducible with XLA_DISABLE_FUNCTIONALIZATION=0
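
A possible workaround until this is fixed, based on the repro output above: the indexing-based variants (a3/b3 and a4/b4) return correct values even with functionalization disabled, so one-element slices can be built via indexing + unsqueeze instead of chunk/slicing. A minimal sketch:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(size=[2], device=device)

# Instead of: a, b = x.chunk(2)  (wrong after .squeeze() under
# XLA_DISABLE_FUNCTIONALIZATION=1), use indexing + unsqueeze,
# which the output above shows is correct:
a, b = x[0].unsqueeze(0), x[1].unsqueeze(0)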

@JackCaoG
Collaborator

@tengyifei do you want to take a look at this one?

@JackCaoG JackCaoG self-assigned this Aug 15, 2024
@JackCaoG
Collaborator

ok I can repro the issue with XLA_DISABLE_FUNCTIONALIZATION=1... looking into it.

@JackCaoG
Collaborator

@ysiraichi I know what the issue is: when we run with XLA_DISABLE_FUNCTIONALIZATION=1, the IR looks like this

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(size=[2]).to(device)
a1, b1 = x.chunk(2)
print(torch_xla._XLAC._get_xla_tensors_text([b1]))
print(torch_xla._XLAC._get_xla_tensors_hlo([b1]))

print(torch_xla._XLAC._get_xla_tensors_text([b1.squeeze()]))
print(torch_xla._XLAC._get_xla_tensors_hlo([b1.squeeze()]))

With functionalization:
IR {
  %0 = f32[2]{0} xla::device_data(), xla_shape=f32[2]{0}
  %1 = (f32[1]{0}, f32[1]{0}) aten::split(%0), num_outputs=2, xla_shape=(f32[1]{0}, f32[1]{0}), ROOT=0
}

IR {
  %0 = f32[2]{0} xla::device_data(), xla_shape=f32[2]{0}
  %1 = (f32[1]{0}, f32[1]{0}) aten::split(%0), num_outputs=2, xla_shape=(f32[1]{0}, f32[1]{0})
  %2 = f32[] aten::view(%1.1), xla_shape=f32[], ROOT=0
}

Without functionalization:
IR {
  %0 = f32[2]{0} xla::device_data(), xla_shape=f32[2]{0}
  %1 = f32[1]{0} xla::select(%0), xla_shape=f32[1]{0}, ROOT=0
}

IR {
  %0 = f32[2]{0} xla::device_data(), xla_shape=f32[2]{0}
  %1 = f32[] aten::as_strided(%0), xla_shape=f32[], ROOT=0
}

Note that in the without-functionalization case, it is a select followed by an as_strided, and the as_strided ends up operating directly on the device data. You added the logic in #5914 where you set the as_strided to operate on the base tensor of the view instead of on the view tensor, but at least in this case the view op should happen on b1, not on x. I am going to fix this issue, but can you verify whether the assumption that as_strided should happen on the storage is correct (or whether it is only correct in the functionalization case)?
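
To make the offset issue concrete, here is a small eager-mode sketch (plain CPU tensors, not the actual XLA lowering) of what that as_strided node needs to encode: squeezing b1 is equivalent to an as_strided on the base storage with storage_offset=1, and dropping the offset reads x[0] instead of x[1].

import torch

x = torch.randn(2)
_, b1 = x.chunk(2)               # b1 is a view with storage_offset() == 1

good = x.as_strided((), (), 1)   # same value as b1.squeeze(): reads x[1]
bad = x.as_strided((), (), 0)    # what the broken IR effectively does: reads x[0]

print(b1.squeeze().item(), good.item(), bad.item())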

@ysiraichi
Collaborator

Yes. as_strided should be run on the storage (reference).

That said, this is a known problem with as_strided on XLA (I tried fixing it here). See the following example:

>>> x = torch.arange(1, 6, device=dev)
# x = [1, 2, 3, 4, 5]

>>> a = x[2:]
# a = [3, 4, 5]

>>> a.as_strided((2,), (1,))
# Expected: [3, 4]
# Output: [1, 2]

The specific cause of this problem is that the storage_offset argument is not given. Whenever this happens, the value used is self.storage_offset(). However, since there's no notion of views in PyTorch/XLA, it is 0 for every XLA tensor.
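
For comparison, an eager CPU tensor does carry that offset; a quick check (CPU only; per the explanation above, an XLA tensor would report 0):

>>> x = torch.arange(1, 6)       # CPU: [1, 2, 3, 4, 5]
>>> a = x[2:]                    # a = [3, 4, 5]
>>> a.storage_offset()
2
>>> a.as_strided((2,), (1,))     # default offset is a.storage_offset() == 2
tensor([3, 4])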

This issue is complementary to the failures you are facing in your PR:

>>> x = torch.arange(1, 6, device=dev)
# x = [1, 2, 3, 4, 5]

>>> a = x[2:]
# a = [3, 4, 5]

>>> a.as_strided((4,), (1,), storage_offset=0)
# Expected: [1, 2, 3, 4]
# Output: Error: we want 4 elements, but 'a' only has 3.

In the end, we will face at least one of these problems, unless we keep track of the storage_offset of tensor views.
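
For illustration, on CPU (where both the base storage and the view's storage_offset are available) the two calls above can each be satisfied; this is just eager semantics, not a proposed PyTorch/XLA implementation:

>>> x = torch.arange(1, 6)   # [1, 2, 3, 4, 5]
>>> a = x[2:]                # view into x with storage_offset() == 2
>>> a.as_strided((2,), (1,), a.storage_offset())   # first example above
tensor([3, 4])
>>> a.as_strided((4,), (1,), 0)                    # second example above
tensor([1, 2, 3, 4])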
