diff --git a/test/dynamo/test_bridge.py b/test/dynamo/test_bridge.py index 74a9c69b973..b7cb4db5e8c 100644 --- a/test/dynamo/test_bridge.py +++ b/test/dynamo/test_bridge.py @@ -266,6 +266,42 @@ def foo(device): self._compile_and_check(foo, (xm.xla_device(),)) + def test_index_flag_unsupported(self): + # The indices of the index operation are represented as + # a list of objects. If any non-XLA tensors appear, the + # index operation should be flagged as unsupported, since + # their arguments might be turned into placeholders of the + # partition FX graph. + + def foo(xt, t): + return xt[t] + + device = xm.xla_device() + xt = torch.rand(5, device=device) + t = torch.randint(0, 5, (3,)) + self._compile_and_check(foo, (xt, t)) + + def test_stack_flag_unsupported(self): + # Explicit list of tensors arguments. + + def foo(t): + return torch.stack([t]) + + t = torch.randint(0, 5, (3,)) + self._compile_and_check(foo, (t,)) + + def test_cpu_flag_unsupported(self): + # Nodes that return CPU tensors should also be flagged as + # unsupported, since their outputs could be turned into + # outputs of the partition FX graph. + + def foo(t): + return t.cpu() + + device = xm.xla_device() + t = torch.randint(0, 5, (3,), device=device) + self._compile_and_check(foo, (t,)) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py index e99624532fc..400c6aa388e 100644 --- a/torch_xla/core/dynamo_bridge.py +++ b/torch_xla/core/dynamo_bridge.py @@ -420,30 +420,51 @@ def optimized_mod(*args): return optimized_mod -class FallBackNodeCollector(torch.fx.Interpreter): +class UnsupportedNodesCollector(torch.fx.Interpreter): def __init__(self, module): super().__init__(module) - self._fallback_ops = [] + self._unsupported_nodes = [] def run_node(self, n: torch.fx.Node): metrics.clear_counters() result = super().run_node(n) fallback_ops = get_fallback_ops() if len(fallback_ops) > 0: - self._fallback_ops.append(n) + self._unsupported_nodes.append(n) else: - # if inputs are non-xla tensors, it should be executed on CPU - if n.op in ["call_function", "call_module", "call_method"]: - args, kwargs = self.fetch_args_kwargs_from_env(n) - for arg in args: - if isinstance(arg, torch.Tensor) and not is_xla_tensor(arg): - self._fallback_ops.append(n) - break + # Check whether the tensors contained in value are all XLA tensors. + def all_tensors_on_xla_device(value): + if isinstance(value, torch.Tensor): + return is_xla_tensor(value) + if isinstance(value, (list, tuple)): + return all(all_tensors_on_xla_device(v) for v in value) + # Not a tensor nor a container. + return True + + # Check whether the current node is supported or not. + # + # A supported node has the following characteristics: + # - a node whose result is a composition of XLA tensors: + # avoids non-XLA tensors as FX graph return value. + result_is_supported = all_tensors_on_xla_device(result) + + # - a node that whose tensor arguments are XLA tensors: + # avoids non-XLA tensors as FX graph arguments. + args, kwargs = self.fetch_args_kwargs_from_env(n) + args_are_supported = all( + all_tensors_on_xla_device(v) + for v in itertools.chain(args, kwargs.values())) + + # If the current node is NOT supported, we add it to + # the _unsupported_nodes list. + if not (result_is_supported and args_are_supported): + self._unsupported_nodes.append(n) + return result - def get_fallback_ops(self): - return self._fallback_ops + def get_unsupported_nodes(self): + return self._unsupported_nodes class InputCollector(torch.fx.Interpreter): @@ -518,11 +539,11 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args): ] # execute model once to collect fallback ops - collector = FallBackNodeCollector(xla_model) + collector = UnsupportedNodesCollector(xla_model) collector.run(*xla_args) - fallback_ops = collector.get_fallback_ops() - if (ptxla_debug or dynamo_debug) and len(fallback_ops) > 0: - print('Dynamo fallback ops are' + str(fallback_ops) + + unsupported_nodes = collector.get_unsupported_nodes() + if (ptxla_debug or dynamo_debug) and len(unsupported_nodes) > 0: + print('Dynamo fallback ops are' + str(unsupported_nodes) + '. Please open a GitHub issue with the above op lowering requests.') # This logic, needed for supporting in-place operations, is a duplicate of @@ -545,7 +566,7 @@ class XlaOperatorSupport(torch.fx.passes.operator_support.OperatorSupport): def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: return node.op in [ "call_function", "call_module", "call_method" - ] and (node not in fallback_ops or node.target == operator.getitem) + ] and (node not in unsupported_nodes or node.target == operator.getitem) # partition the model supported_ops = XlaOperatorSupport()