Skip to content

Commit

Permalink
Sync xla_args before computation. (#5823)
Browse files Browse the repository at this point in the history
* Sync `xla_args` before actually computing them.

* Only sync `FunctionalTensorWrapper` tensors.
  • Loading branch information
ysiraichi authored and bhavya01 committed Apr 22, 2024
1 parent 1a6e265 commit 6e143bf
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
11 changes: 11 additions & 0 deletions test/dynamo/test_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,17 @@ def foo(x):
x = torch.randint(0, 10, (10,), device=device)
foo(x)

def test_inputs_not_computed(self):

@torch.compile(backend="openxla")
def foo(x):
return x * 2

device = xm.xla_device()
x = torch.rand(5, device=device)
x = x.unsqueeze(dim=-1)
foo(x)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
Expand Down
6 changes: 6 additions & 0 deletions torch_xla/core/dynamo_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,12 @@ def call_module(self, target, args, kwargs):


def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
# Synchronize xla_args, so that each FunctionalTensorWrapper argument updates its
# value reference before actually computing it.
for a in xla_args:
if isinstance(a, torch.Tensor) and torch._is_functional_tensor(a):
torch._functionalize_sync(a)

# This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids
xm.mark_step()

Expand Down

0 comments on commit 6e143bf

Please sign in to comment.