diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index 946ae914b04..3a3eb3d43f1 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -668,6 +668,21 @@ def foo(x): self.assertEqual(expected.dtype, actual.dtype) self.assertEqual(expected.device, actual.device) + def test_return_expand(self): + + def foo(x): + return x.expand(2, -1) + + optfoo = torch.compile(backend="openxla")(foo) + + t = torch.arange(10) + Xt = t.to(xm.xla_device()) + + expected = foo(t) + actual = optfoo(Xt) + + self.assertEqual(expected, actual.cpu()) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index ebc0af6c7ad..d2c4e1a3aca 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -186,6 +186,12 @@ def _init_xla_lazy_backend(): # TODO @wonjoo come up with a long term fix in Dynamo. torch._dynamo.config.automatic_dynamic_shapes = False +# Activate view-replay on AOTAutograd. +# See: https://github.com/pytorch/pytorch/pull/124488 +import torch._functorch.config + +torch._functorch.config.view_replay_for_aliased_outputs = True + from .stablehlo import save_as_stablehlo, save_torch_model_as_stablehlo from .experimental import plugins