diff --git a/test/dynamo/test_bridge.py b/test/dynamo/test_bridge.py
index a419f9335db..a40688a9164 100644
--- a/test/dynamo/test_bridge.py
+++ b/test/dynamo/test_bridge.py
@@ -240,6 +240,17 @@ def foo(x):
     x = torch.randint(0, 10, (10,), device=device)
     foo(x)
 
+  def test_inputs_not_computed(self):
+
+    @torch.compile(backend="openxla")
+    def foo(x):
+      return x * 2
+
+    device = xm.xla_device()
+    x = torch.rand(5, device=device)
+    x = x.unsqueeze(dim=-1)
+    foo(x)
+
 
 if __name__ == "__main__":
   from torch._dynamo.test_case import run_tests
diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index d9c13c6ec69..e7a24a8e5f3 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -455,6 +455,12 @@ def call_module(self, target, args, kwargs):
 
 
 def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
+  # Synchronize xla_args, so that each FunctionalTensorWrapper argument updates its
+  # value reference before actually computing it.
+  for a in xla_args:
+    if isinstance(a, torch.Tensor) and torch._is_functional_tensor(a):
+      torch._functionalize_sync(a)
+
   # This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids
   xm.mark_step()
 
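
Note (not part of the diff): below is a minimal standalone sketch of the synchronization
pattern the patch adds, assuming a working torch_xla install; the helper name
`sync_functional_args` is hypothetical, not from the patch. Per the diff's own comments,
syncing lets each FunctionalTensorWrapper argument update its value reference before
`xm.mark_step()` registers tensor ids in graph_input_tensor_ids, which is presumably what
the `unsqueeze` view in the new test previously broke.

import torch
import torch_xla.core.xla_model as xm


def sync_functional_args(xla_args):
  # Same loop the patch adds to extract_compiled_graph: flush any pending
  # functionalization update so each wrapper points at its current value.
  # Outside a functionalization context this is a no-op, since plain eager
  # tensors fail the torch._is_functional_tensor check.
  for a in xla_args:
    if isinstance(a, torch.Tensor) and torch._is_functional_tensor(a):
      torch._functionalize_sync(a)


if __name__ == "__main__":
  device = xm.xla_device()
  x = torch.rand(5, device=device)
  x = x.unsqueeze(dim=-1)  # view op, like the one exercised by test_inputs_not_computed
  sync_functional_args([x])
  xm.mark_step()  # x's tensor id can now show up among the graph inputs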