diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index 1d657a5ab..83763a8ac 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -3,7 +3,7 @@ from __future__ import annotations import math -from typing import Callable +from typing import Callable, Sequence import numpy as np @@ -11,6 +11,21 @@ from onnxscript.optimizer import basic_constant_propagation +def display_nodes(nodes: Sequence[ir.Node]) -> None: + """Display a list of nodes in the order they appear in the graph.""" + if nodes: + graph = nodes[0].graph + if graph: + # Display nodes in same order as in graph: + # Currently doesn't handle (control-flow) subgraphs + for node in graph: + if node in nodes: + node.display() + else: + for node in nodes: + node.display() + + def display_slice(x: ir.Value | ir.Node, backward: bool = True, depth_limit: int = 5) -> None: """Display the (backward or forward) subgraph from a given value or node upto a certain depth.""" slice = [] @@ -33,17 +48,7 @@ def visit(node: ir.Node, depth): visit(x, 0) elif isinstance(x, ir.Value) and x.producer() is not None: visit(x.producer(), 0) # type: ignore[arg-type] - if slice: - graph = slice[0].graph - if graph: - # Display nodes in same order as in graph: - # Currently doesn't handle (control-flow) subgraphs - for node in graph: - if node in slice: - node.display() - else: - for node in reversed(slice): - node.display() + display_nodes(slice) def get_const_value(value: ir.Value) -> ir.TensorProtocol | None: diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index a961ae872..333cb489d 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -5,9 +5,11 @@ import abc import contextlib import dataclasses +import enum import inspect import itertools import math +from collections import defaultdict from typing import ( Any, Callable, @@ -328,13 +330,17 @@ def __init__(self) -> None: self.outputs: list[ir.Value] = [] # For a failed match, _reason is a string that describes the reason for the failure. self._reason: str = "" + # Track the node that caused the failure. + # TODO: May be useful to extend this to be a collection of Nodes and Values. + self._failure_node: ir.Node | None = None def __bool__(self): return self._success - def fail(self, reason: str = "") -> MatchResult: + def fail(self, reason: str = "", node: ir.Node | None = None) -> MatchResult: self._success = False self._reason = reason + self._failure_node = node return self @property @@ -536,18 +542,23 @@ def matches(self, node: ir.Node, match: MatchResult) -> MatchResult: We check the domain, op_type, and attributes of the node, but not the inputs. """ # TODO(rama): Ensure we handle "" and "onnx.ai" correctly. - if not self.domain.matches(node.domain): - return match.fail(f"Domain mismatch: expected {self.domain}, got {node.domain}.") if not self.op.matches(node.op_type): - return match.fail(f"OpType mismatch: expected {self.op}, got {node.op_type}.") + return match.fail( + f"OpType mismatch: expected {self.op}, got {node.op_type}.", node + ) + if not self.domain.matches(node.domain): + return match.fail( + f"Domain mismatch: expected {self.domain}, got {node.domain}.", node + ) for name, attr_pattern in self.attributes.items(): attr_value = node.attributes.get(name) if attr_value is None: - return match.fail(f"Attribute {name} not found in node.") + return match.fail(f"Attribute {name} not found in node.", node) if not attr_pattern.matches(attr_value): return match.fail( - f"Attribute {name} mismatch: expected {attr_pattern}, got {attr_value}." + f"Attribute {name} mismatch: expected {attr_pattern}, got {attr_value}.", + node, ) if attr_pattern.name is not None: if not match.bind(attr_pattern.name, attr_value): @@ -557,7 +568,7 @@ def matches(self, node: ir.Node, match: MatchResult) -> MatchResult: for name in node.attributes: # TODO: Support matching default nodes for attributes. if name not in self.attributes: - return match.fail(f"Attribute {name} not expected in node.") + return match.fail(f"Attribute {name} not expected in node.", node) return match @@ -945,8 +956,10 @@ def match( model: ir.Model, graph_or_function: ir.Graph | ir.Function, node: ir.Node, + *, verbose: int = 0, remove_nodes: bool = True, + tracer: MatchingTracer | None = None, ) -> MatchResult: """Match the pattern against the subgraph ending at the given node.""" @@ -957,13 +970,14 @@ def __str__(self) -> str: class SimplePatternMatcher(PatternMatcher): def __init__(self, pattern: GraphPattern) -> None: super().__init__(pattern) + self._current_node: ir.Node | None = None - def fail(self, reason: str) -> bool: + def fail(self, reason: str, node: ir.Node | None = None) -> bool: if self._verbose: if self._matched: # Print only if at least one node successfully matched. count = len(self._matched) print(f"Match failed after {count} nodes: {reason}") - self._match.fail(reason) + self._match.fail(reason, node or self._current_node) return False def _match_constant(self, pattern_constant: Constant, value: ir.Value) -> bool: @@ -1025,7 +1039,7 @@ def _match_constant(self, pattern_constant: Constant, value: ir.Value) -> bool: def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool: """Matches a pattern subgraph against subgraph rooted at node.""" - + self._current_node = node # Graph-matching: we do not allow the same pattern node to be matched against # different graph nodes. if pattern_node in self._matched: @@ -1039,6 +1053,7 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool: if self._verbose: print(f"Matched: {node.op_type}") + match.nodes.append(node) self._matched[pattern_node] = node # TODO: Revisit this to handle optional trailing inputs better. @@ -1067,7 +1082,6 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool: if not self._bind_value(output_value_pattern, node.outputs[i]): return False - match.nodes.append(node) return True def _bind_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> bool: @@ -1115,6 +1129,7 @@ def _init_match(self, verbose: int) -> None: self._verbose = verbose self._matched: dict[NodePattern, ir.Node] = {} self._match: MatchResult = MatchResult() + self._current_node = None def _get_output_values(self) -> list[ir.Value] | None: """Get values bound to the output variables of the pattern.""" @@ -1163,8 +1178,10 @@ def _match_single_output_node( output_values = self._get_output_values() if output_values is None: + # TODO(rama): Is this a valid (useful) case? return match if check_removable and not _valid_to_replace(match.nodes, output_values): + # TODO(rama): Match status should be updated to reflect failure reason. return match.fail("Matched nodes have other uses preventing replacement.") match.outputs.extend(output_values) @@ -1200,8 +1217,10 @@ def match( model: ir.Model, graph_or_function: ir.Graph | ir.Function, node: ir.Node, + *, verbose: int = 0, remove_nodes: bool = True, + tracer: MatchingTracer | None = None, ) -> MatchResult: """Match the pattern against the subgraph ending at the given node. @@ -1218,7 +1237,7 @@ def match( matching in the presence of subgraphs (control-flow) can introduce some complications which require careful consideration. """ - + self._tracer = tracer if self.pattern.has_single_output_node: self._init_match(verbose) return self._match_single_output_node( @@ -1268,6 +1287,8 @@ def __init__( verbose: int = 0, name: str | None = None, remove_nodes: bool = True, + graph_pre_visitor: Callable[[], None] | None = None, + graph_post_visitor: Callable[[], None] | None = None, ) -> None: """Create a rewrite rule. @@ -1284,6 +1305,10 @@ def __init__( verbose: The verbosity level of the rule. name: An optional name for the pattern that will show up in verbose logging. remove_nodes: If True, the matched nodes will be removed from the graph. + graph_pre_visitor: A function that will be called before applying the + rewriting to the top-level graph or a function. + graph_post_visitor: A function that will be called after the rewriting + is complete for a graph or function. """ if not isinstance(target_pattern, GraphPattern): @@ -1308,20 +1333,20 @@ def __init__( self._verbose = verbose self.name = name self.remove_nodes = remove_nodes + self.graph_pre_visitor = graph_pre_visitor + self.graph_post_visitor = graph_post_visitor def __str__(self) -> str: - if self.name: - return f"{self.__class__.__name__}(..., name={self.name!r})" - return ( - f"{self.__class__.__name__}({self._target_pattern}, {self._replacement_pattern})" - ) + return self.name if self.name else "Anonymous Rule" def try_rewrite( self, model: ir.Model, graph_or_function: ir.Graph | ir.Function, node: ir.Node, + *, verbose: int | None = None, + tracer: MatchingTracer | None = None, ) -> ReplacementSubgraph | None: """If the node matches the pattern, then replace the node with the replacement pattern.""" if verbose and verbose > 2: @@ -1337,9 +1362,17 @@ def try_rewrite( if var.name not in match.bindings: match.bindings[var.name] = None if not self._condition_function(context, **match.bindings): + if tracer: + tracer.log( + self, graph_or_function, node, match, MatchStatus.CONDITION_FAILED + ) return None replacement_subgraph = self._replacement_pattern.get_replacement(match) if replacement_subgraph is None: + if tracer: + tracer.log( + self, graph_or_function, node, match, MatchStatus.REPLACEMENT_FAILED + ) return None if len(replacement_subgraph.new_outputs) != self._target_pattern.num_outputs: raise ValueError( @@ -1349,15 +1382,26 @@ def try_rewrite( # TODO(rama): Remove the opset imports from deleted nodes? _update_opset_imports(graph_or_function, replacement_subgraph) _update_opset_imports(model.graph, replacement_subgraph) + if tracer: + tracer.log(self, graph_or_function, node, match, MatchStatus.SUCCESS) return replacement_subgraph + if tracer: + tracer.log(self, graph_or_function, node, match, MatchStatus.NO_MATCH) return None def apply_to_model( - self, model: ir.Model, *, commute: bool = False, verbose: int | None = None + self, + model: ir.Model, + *, + commute: bool = False, + verbose: int | None = None, + debug: bool = False, ): # A convenience method to apply the rule to a model. We use a RewriteRuleSet to # handle commutative rules. - return RewriteRuleSet([self], commute=commute).apply_to_model(model, verbose=verbose) + return RewriteRuleSet([self], commute=commute).apply_to_model( + model, verbose=verbose, debug=debug + ) def commute(self) -> Sequence[RewriteRule]: def replace_pattern(new_pattern): @@ -1370,6 +1414,10 @@ def replace_pattern(new_pattern): self._condition_function, matcher_class(new_pattern), self._verbose, + self.name, + self.remove_nodes, + self.graph_pre_visitor, + self.graph_post_visitor, ) return [replace_pattern(p) for p in self._target_pattern.commute()] @@ -1451,12 +1499,16 @@ class RewriteRuleClassBase: @classmethod def rule(cls, *args, **kwargs): instance = cls(*args, **kwargs) + setup = instance.setup if hasattr(instance, "setup") else None + cleanup = instance.cleanup if hasattr(instance, "cleanup") else None return RewriteRule( instance.pattern, instance.rewrite, instance.check, name=instance.name, remove_nodes=instance.remove_nodes, + graph_pre_visitor=setup, + graph_post_visitor=cleanup, ) def __init__(self, name: str | None = None, remove_nodes: bool = True) -> None: @@ -1484,16 +1536,34 @@ def _apply_to_graph_or_function( self, model: ir.Model, graph_or_function: ir.Graph | ir.Function, + *, verbose: int | None, + tracer: MatchingTracer | None = None, ) -> int: + """ + Apply the rewrite rules to the given graph or function. + + Args: + model: The model to which the rewrite rules are applied. + graph_or_function: The graph or function to which the rewrite rules are applied. + verbose: The verbosity level. Defaults to None. + tracer: The tracer for debugging. Defaults to None. + + Returns: + The number of rewrite rules applied. + """ count = 0 # NOTE: Rules should be prioritized in the order they are added to the RewriteRuleSet. # And the graph is applied in order. for rule in self.rules: + if rule.graph_pre_visitor: + rule.graph_pre_visitor() for node in graph_or_function: - delta = rule.try_rewrite(model, graph_or_function, node, verbose=verbose) - if delta is None: + delta = rule.try_rewrite( + model, graph_or_function, node, verbose=verbose, tracer=tracer + ) + if delta is None or tracer is not None: continue assert isinstance(delta, ReplacementSubgraph) # TODO: This does not yet handle the problem of determining the correct insertion point @@ -1510,17 +1580,115 @@ def _apply_to_graph_or_function( delta.new_outputs, ) count += 1 + if rule.graph_post_visitor: + rule.graph_post_visitor() return count - def apply_to_model(self, model: ir.Model, verbose: int | None = None) -> int: + def apply_to_model( + self, model: ir.Model, *, verbose: int | None = None, debug: bool = False + ) -> int: + """Apply the rewrite rules in the set to the model. + + Args: + model: The model to which the rewrite rules are applied. + verbose: The verbosity level of messages. Defaults to None. + debug: Whether to enable debugging. Defaults to False. In the + debug mode, no changes are made to the model, only a report is produced at + the end about the best matches found. + + Returns: + The number of applications of rewrite rules. + """ assert isinstance(model, ir.Model) + tracer = MatchingTracer() if debug else None onnxscript.optimizer.basic_constant_propagation(model.graph) - count = self._apply_to_graph_or_function(model, model.graph, verbose=verbose) + count = self._apply_to_graph_or_function( + model, model.graph, verbose=verbose, tracer=tracer + ) for function in model.functions.values(): onnxscript.optimizer.basic_constant_propagation(function) - count += self._apply_to_graph_or_function(model, function, verbose=verbose) + count += self._apply_to_graph_or_function( + model, function, verbose=verbose, tracer=tracer + ) + if tracer: + tracer.report() return count def __iter__(self): yield from self.rules + + +class MatchStatus(enum.IntEnum): + """The status of a pattern-matching operation.""" + + NO_MATCH = 0 # No successful match found for entire pattern graph + CONDITION_FAILED = 1 # Subsequent validation check failed + REPLACEMENT_FAILED = 2 # Replacement subgraph could not be created + SUCCESS = 3 # A successful match was found + + +@dataclasses.dataclass +class MatchInfo: + """The status of a pattern-matching operation. An extension of MatchResult.""" + + match_result: MatchResult + root_node: ir.Node + container: ir.Graph | ir.Function + status: MatchStatus + + def score(self) -> int: + """Return a score for the match.""" + return len(self.match_result.nodes) + int(self.status.value) * 100 + + +class MatchingTracer: + """A debugging helper class to trace the matching of a pattern against a graph. + + This is used to track the best matches found for each rule, and to report the + results at the end of the matching. + """ + + def __init__(self) -> None: + self._log: dict[RewriteRule, list[MatchInfo]] = defaultdict(list) + + def log( + self, + rule: RewriteRule, + container: ir.Graph | ir.Function, + node: ir.Node, + match_result: MatchResult, + status: MatchStatus, + ) -> None: + this_match = MatchInfo(match_result, node, container, status) + this_score = this_match.score() + if this_score == 0: + return + best_matches = self._log[rule] + if best_matches: + if this_score < best_matches[0].score(): + return + if this_score > best_matches[0].score(): + best_matches.clear() + best_matches.append(this_match) + + def report(self) -> None: + import onnxscript.rewriter._ir_utils as ir_utils + + print("===") + for rule, matches in self._log.items(): + if not matches: + continue + print(f"Rule: {rule}") + print(f"Best score: {matches[0].score()}") + for match in matches: + print(f"Status: {match.status}") + if match.status == MatchStatus.NO_MATCH: + print("Graph matching failed: " + match.match_result.reason) + node = match.match_result._failure_node + if node: + print("Failure at or around node:") + node.display() + print("Matched nodes:") + ir_utils.display_nodes(match.match_result.nodes) + print("===") diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 0247949f5..1803ab670 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -476,6 +476,73 @@ def test_model(x: FLOAT[1024]) -> FLOAT[1024]: self.assertEqual(model.graph.node(0).op_type, "ReplacedNone") self.assertEqual(model.graph.node(1).op_type, "ReplacedNotNone") + def test_graph_visitor(self): + class ReplaceFoo(pattern.RewriteRuleClassBase): + def __init__(self): + super().__init__() + self.replacement = None + + def pattern(self, op): + return op.Foo() + + def rewrite(self, op): + if self.replacement is None: + self.replacement = op.Bar() + return self.replacement + + rule = ReplaceFoo.rule() + + @script() + def test_model(x: FLOAT[1024]) -> FLOAT[1024]: + # Pattern should match following call + t1 = op.Foo() + # as well as this one + t2 = op.Foo() + z = op.Add(t1, t2) + return z + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + + count = rule.apply_to_model(model) + self.assertEqual(count, 2) + self.assertEqual(len(model.graph), 2) + self.assertEqual(model.graph.node(0).op_type, "Bar") + self.assertEqual(model.graph.node(1).op_type, "Add") + + def test_debug_mode(self): + def source_pattern(op, x): + t1 = op.Abs(x) + t2 = op.Neg(t1) + t3 = op.Exp(t2) + return t3 + + def replacement(op, x): + return op.Something(x) + + rule = pattern.RewriteRule(source_pattern, replacement) + + @script() + def test_model(x: FLOAT[1024]) -> FLOAT[1024]: + a2 = op.Abs(x) # match-1 fails here + a3 = op.Exp(a2) # match-1 starts here + b1 = op.Neg(a3) # match-2 fails here + b2 = op.Neg(b1) # match-2 (partially) succeeds here + b3 = op.Exp(b2) # match-2 starts here + return b3 + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + + output_buffer = io.StringIO() + with contextlib.redirect_stdout(output_buffer): + count = rule.apply_to_model(model, debug=True) + captured_output = output_buffer.getvalue() + + self.assertEqual(count, 0) + # Not a robust test. But test serves to ensure that debug mode is producing something. + self.assertIn("OpType mismatch: expected Abs, got Neg", captured_output) + class PatternBuilderTest(unittest.TestCase): def test_pattern_builder_context(self):