Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-land: dynamo expand test with view-replay. #6958

Merged
merged 6 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions test/dynamo/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,21 @@ def foo(x):
self.assertEqual(expected.dtype, actual.dtype)
self.assertEqual(expected.device, actual.device)

def test_return_expand(self):

def foo(x):
return x.expand(2, -1)

optfoo = torch.compile(backend="openxla")(foo)

t = torch.arange(10)
Xt = t.to(xm.xla_device())

expected = foo(t)
actual = optfoo(Xt)

self.assertEqual(expected, actual.cpu())


if __name__ == '__main__':
test = unittest.main()
Expand Down
6 changes: 6 additions & 0 deletions torch_xla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,12 @@ def _init_xla_lazy_backend():
# TODO @wonjoo come up with a long term fix in Dynamo.
torch._dynamo.config.automatic_dynamic_shapes = False

# Activate view-replay on AOTAutograd.
# See: https://github.com/pytorch/pytorch/pull/124488
import torch._functorch.config

torch._functorch.config.view_replay_for_aliased_outputs = True

from .stablehlo import save_as_stablehlo, save_torch_model_as_stablehlo

from .experimental import plugins
Expand Down
Loading