From 67d00adb6c6653cfb9d7788800185d573a62954c Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Mon, 16 Oct 2023 19:13:19 -0300
Subject: [PATCH] Filter tensor arguments from traced model. (#5689)

* Filter tensor arguments from traced model.

This PR filters the tensor arguments from the list of arguments that will
be given to the model.

**Problem:** the dynamo bridge assumed all arguments were tensors.

**Solution:** filter the tensor arguments so that we collect tensor
information correctly.

* Add test.

* Fix lint issues.

* Simplified test.

* Use `openxla` instead of `openxla_eval` backend.

* Rename variables for readability.

* Use `openxla_eval` instead of `openxla`.
---
 test/dynamo/test_bridge.py      | 18 ++++++++++++++++++
 torch_xla/core/dynamo_bridge.py | 29 ++++++++++++++++++-----------
 2 files changed, 36 insertions(+), 11 deletions(-)

diff --git a/test/dynamo/test_bridge.py b/test/dynamo/test_bridge.py
index 778d77591e4..8ca735073be 100644
--- a/test/dynamo/test_bridge.py
+++ b/test/dynamo/test_bridge.py
@@ -207,6 +207,24 @@ class TorchXLAReuseGraphTest(torch._dynamo.test_case.TestCase):
   test_training_linear = make_training_test(LinearModule)
   test_training_maxpool = make_training_test(MaxPoolModule)
 
+  def test_non_tensor_args_for_partition(self):
+
+    class Emb(torch.nn.Embedding):
+
+      def __init__(self):
+        super().__init__(num_embeddings=10, embedding_dim=10, padding_idx=0)
+
+    device = xm.xla_device()
+    module = Emb()
+    module.to(device)
+
+    @torch.compile(backend="openxla_eval")
+    def foo(x):
+      return module(x)
+
+    x = torch.randint(0, 10, (10,), device=device)
+    foo(x)
+
 
 if __name__ == "__main__":
   from torch._dynamo.test_case import run_tests
diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index f30f7d8ef8f..d9c13c6ec69 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -230,16 +230,19 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
       for xla_arg in xla_args
   ]
 
-  args_tensor_ids = [
-      torch_xla._XLAC._xla_get_tensor_id(xla_arg) for xla_arg in xla_args
-  ]
+  index_and_xla_tensor_args = [(i, xla_arg)
+                               for i, xla_arg in enumerate(xla_args)
+                               if isinstance(xla_arg, torch.Tensor)]
+
+  index_and_tensor_ids = [(index, torch_xla._XLAC._xla_get_tensor_id(xla_arg))
+                          for index, xla_arg in index_and_xla_tensor_args]
 
   if dynamo_debug:
     print(f"Graph module:\n{xla_model.code}")
-    print(f"args_tensor_ids {args_tensor_ids}")
+    print(f"args_tensor_ids {index_and_tensor_ids}")
 
   tensor_id_to_arg_idx = {
-      tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)
+      tensor_id: index for index, tensor_id in index_and_tensor_ids
   }
 
   if xr.is_spmd():
@@ -258,15 +261,16 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
 
   # If an arg is updated in place by the model, we need to include it as part of the graph result.
  xla_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
-      xla_args)
+      [tensor for _, tensor in index_and_xla_tensor_args])
   xla_args_need_update = []
   arg_index_to_need_update_index = {}
   for i, need_update in enumerate(xla_args_need_update_bool):
     # Don't add an in-place updated argument to the list if it's already
     # being returned
-    if need_update and id(xla_args[i]) not in xla_out_ids:
-      arg_index_to_need_update_index[i] = len(xla_args_need_update)
-      xla_args_need_update.append(xla_args[i])
+    index, tensor = index_and_xla_tensor_args[i]
+    if need_update and id(tensor) not in xla_out_ids:
+      arg_index_to_need_update_index[index] = len(xla_args_need_update)
+      xla_args_need_update.append(tensor)
 
   args_and_out = tuple(xla_args_need_update) + tuple(xla_out)
 
@@ -325,7 +329,8 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
 def extract_internal(xla_model: torch.fx.GraphModule):
   if dynamo_debug:
     for xla_arg in xla_model.xla_args:
-      print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
+      if isinstance(xla_arg, torch.Tensor):
+        print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
   xm.mark_step()
   (xla_args_sharding_spec, args_and_out, graph_hash,
    arg_index_to_need_update_index, none_remover, graph_input_matcher,
@@ -347,7 +352,9 @@ def optimized_mod(*args):
 
     # mark_step needs to be blocking since we want to access the args'
     # XLAData and they can't be placeholders.
-    if any(torch_xla._XLAC._check_tensor_need_materialization(args)):
+    if any(
+        torch_xla._XLAC._check_tensor_need_materialization(
+            [a for a in args if isinstance(a, torch.Tensor)])):
       xm.mark_step(wait=True)
 
     # If input sharding has changed from the previous program, dynamo currently can
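
Note on the core idiom: the fix replaces direct iteration over `xla_args` with
an index-preserving filter, so tensor metadata is collected only for tensor
arguments while non-tensor arguments keep their original positions. Below is a
minimal standalone sketch of that idiom, not part of the patch; the argument
list is made up, and Python's built-in `id()` stands in for
`torch_xla._XLAC._xla_get_tensor_id`, which requires real XLA tensors:

```python
import torch


def filter_tensor_args(args):
  # Keep only the torch.Tensor arguments, recording each one's original
  # position so results can be mapped back to the full argument list.
  return [(i, a) for i, a in enumerate(args) if isinstance(a, torch.Tensor)]


# Hypothetical mixed argument list: tensors interleaved with an int and None,
# the situation the dynamo bridge previously mishandled.
args = (torch.zeros(2), 3, None, torch.ones(2))
index_and_tensor_args = filter_tensor_args(args)
assert [i for i, _ in index_and_tensor_args] == [0, 3]

# Mirror of tensor_id_to_arg_idx from the patch, mapping a tensor's id back
# to its original argument index; id() stands in for the XLA tensor id here.
tensor_id_to_arg_idx = {id(t): i for i, t in index_and_tensor_args}
assert tensor_id_to_arg_idx[id(args[3])] == 3
```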