From 6b62e8b60f93f005a2a92e1a16c2124c609e976b Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Sat, 18 Nov 2023 15:18:38 -0300
Subject: [PATCH 1/2] Sync `xla_args` before actually computing them.

---
 test/dynamo/test_bridge.py      | 11 +++++++++++
 torch_xla/core/dynamo_bridge.py |  6 ++++++
 2 files changed, 17 insertions(+)

diff --git a/test/dynamo/test_bridge.py b/test/dynamo/test_bridge.py
index a419f9335db..a40688a9164 100644
--- a/test/dynamo/test_bridge.py
+++ b/test/dynamo/test_bridge.py
@@ -240,6 +240,17 @@ def foo(x):
     x = torch.randint(0, 10, (10,), device=device)
     foo(x)

+  def test_inputs_not_computed(self):
+
+    @torch.compile(backend="openxla")
+    def foo(x):
+      return x * 2
+
+    device = xm.xla_device()
+    x = torch.rand(5, device=device)
+    x = x.unsqueeze(dim=-1)
+    foo(x)
+

 if __name__ == "__main__":
   from torch._dynamo.test_case import run_tests

diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index d9c13c6ec69..0b8f0b64dc7 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -455,6 +455,12 @@ def call_module(self, target, args, kwargs):


 def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
+  # Synchronize xla_args, so that each FunctionalTensorWrapper argument updates its
+  # value reference before actually computing it.
+  for a in xla_args:
+    if isinstance(a, torch.Tensor):
+      torch._sync(a)
+
   # This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids
   xm.mark_step()


From 99e696673ceb67bd8b0a1d9b4af0273adf765a60 Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Sat, 18 Nov 2023 15:18:38 -0300
Subject: [PATCH 2/2] Only sync `FunctionalTensorWrapper` tensors.

---
 torch_xla/core/dynamo_bridge.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index 0b8f0b64dc7..e7a24a8e5f3 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -458,8 +458,8 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
   # Synchronize xla_args, so that each FunctionalTensorWrapper argument updates its
   # value reference before actually computing it.
   for a in xla_args:
-    if isinstance(a, torch.Tensor):
-      torch._sync(a)
+    if isinstance(a, torch.Tensor) and torch._is_functional_tensor(a):
+      torch._functionalize_sync(a)

   # This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids
   xm.mark_step()
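
For reference, the scenario the new test exercises, written out as a standalone script. This is a minimal sketch that mirrors the patched test; the explicit sync at the end only illustrates what extract_compiled_graph() now does internally for each argument and is not something user code needs to call.

import torch
import torch_xla.core.xla_model as xm

@torch.compile(backend="openxla")
def foo(x):
  return x * 2

device = xm.xla_device()
x = torch.rand(5, device=device)
# unsqueeze() returns a view; under functionalization it is a
# FunctionalTensorWrapper whose value has not been materialized yet.
x = x.unsqueeze(dim=-1)

# Roughly what the patched extract_compiled_graph() now does for each
# argument before calling xm.mark_step():
if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x):
  torch._functionalize_sync(x)

foo(x)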