Rename FallBackNodeCollector into UnsupportedNodesCollector.
- Add a check for the return value
- Add comments
ysiraichi committed Jan 10, 2024
1 parent b6fa636 commit 65cd98a
Showing 1 changed file with 38 additions and 17 deletions.
torch_xla/core/dynamo_bridge.py: 38 additions & 17 deletions
@@ -420,30 +420,51 @@ def optimized_mod(*args):
   return optimized_mod


-class FallBackNodeCollector(torch.fx.Interpreter):
+class UnsupportedNodesCollector(torch.fx.Interpreter):

   def __init__(self, module):
     super().__init__(module)
-    self._fallback_ops = []
+    self._unsupported_nodes = []

   def run_node(self, n: torch.fx.Node):
     metrics.clear_counters()
     result = super().run_node(n)
     fallback_ops = get_fallback_ops()
     if len(fallback_ops) > 0:
-      self._fallback_ops.append(n)
+      self._unsupported_nodes.append(n)
     else:
-      # if inputs are non-xla tensors, it should be executed on CPU
-      if n.op in ["call_function", "call_module", "call_method"]:
-        args, kwargs = self.fetch_args_kwargs_from_env(n)
-        for arg in args:
-          if isinstance(arg, torch.Tensor) and not is_xla_tensor(arg):
-            self._fallback_ops.append(n)
-            break
+      # Check whether the tensors contained in value are all XLA tensors.
+      def are_all_tensors_xla(value):
+        if isinstance(value, torch.Tensor):
+          return is_xla_tensor(value)
+        if isinstance(value, (list, tuple)):
+          return all(are_all_tensors_xla(v) for v in value)
+        # Not a tensor nor a container.
+        return True
+
+      # Check whether the current node is supported or not.
+      #
+      # A supported node has the following characteristics:
+      # - a node whose result is a composition of XLA tensors:
+      #   avoids non-XLA tensors as FX graph return value.
+      result_is_supported = are_all_tensors_xla(result)
+
+      # - a node whose tensor arguments are XLA tensors:
+      #   avoids non-XLA tensors as FX graph arguments.
+      args, kwargs = self.fetch_args_kwargs_from_env(n)
+      args_are_supported = all(
+          are_all_tensors_xla(v)
+          for v in itertools.chain(args, kwargs.values()))
+
+      # If the current node is NOT supported, we add it to
+      # the _unsupported_nodes list.
+      if not (result_is_supported and args_are_supported):
+        self._unsupported_nodes.append(n)
+
     return result

-  def get_fallback_ops(self):
-    return self._fallback_ops
+  def get_unsupported_nodes(self):
+    return self._unsupported_nodes


 class InputCollector(torch.fx.Interpreter):
@@ -518,11 +539,11 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
   ]

   # execute model once to collect fallback ops
-  collector = FallBackNodeCollector(xla_model)
+  collector = UnsupportedNodesCollector(xla_model)
   collector.run(*xla_args)
-  fallback_ops = collector.get_fallback_ops()
-  if (ptxla_debug or dynamo_debug) and len(fallback_ops) > 0:
-    print('Dynamo fallback ops are' + str(fallback_ops) +
+  unsupported_nodes = collector.get_unsupported_nodes()
+  if (ptxla_debug or dynamo_debug) and len(unsupported_nodes) > 0:
+    print('Dynamo fallback ops are' + str(unsupported_nodes) +
           '. Please open a GitHub issue with the above op lowering requests.')

   # This logic, needed for supporting in-place operations, is a duplicate of
@@ -545,7 +566,7 @@ class XlaOperatorSupport(torch.fx.passes.operator_support.OperatorSupport):
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
       return node.op in [
           "call_function", "call_module", "call_method"
-      ] and (node not in fallback_ops or node.target == operator.getitem)
+      ] and (node not in unsupported_nodes or node.target == operator.getitem)

   # partition the model
   supported_ops = XlaOperatorSupport()
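
For reference, here is a minimal sketch (not part of this commit) of how the renamed collector can feed an FX partitioning pass, in the spirit of extract_compiled_graph above. It assumes a working torch_xla installation and that UnsupportedNodesCollector is importable from torch_xla.core.dynamo_bridge as defined in this file; ToyModel is a hypothetical example module, and the partitioner is the stock CapabilityBasedPartitioner from torch.fx (the operator.getitem special case is omitted for brevity):

import torch
import torch.fx
import torch_xla.core.xla_model as xm
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
from torch_xla.core.dynamo_bridge import UnsupportedNodesCollector


class ToyModel(torch.nn.Module):  # hypothetical toy module

  def forward(self, x):
    return torch.relu(x) + 1


device = xm.xla_device()
gm = torch.fx.symbolic_trace(ToyModel())
xla_args = (torch.randn(4, 4, device=device),)

# Run the graph once: nodes that trigger fallback counters, or whose tensor
# arguments/results are not XLA tensors, are recorded as unsupported.
collector = UnsupportedNodesCollector(gm)
collector.run(*xla_args)
unsupported_nodes = set(collector.get_unsupported_nodes())


class XlaSupport(OperatorSupport):
  # Same gating idea as XlaOperatorSupport above, minus the getitem case.

  def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
    return node.op in ["call_function", "call_module", "call_method"
                      ] and node not in unsupported_nodes


partitioner = CapabilityBasedPartitioner(
    gm, XlaSupport(), allows_single_node_partition=True)
partitions = partitioner.propose_partitions()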
