Skip to content

Commit

Permalink
A couple of extensions to rewriter (#2001)
Browse files Browse the repository at this point in the history
Extends the rewriter with a couple of features:

* A debugging mode to perform the pattern matching (without any graph
modifications) and to report instances that get the best score for a
match (even if incomplete). Helps quickly identify causes for mismatches
when we expect a match.

* Rewrite-rules can now specify a pre/post visitor method called before
applying it to a graph/function. This is useful for rules that need to
create "cached" values that are reused across multiple instances of the
pattern.

---------

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
  • Loading branch information
gramalingam and justinchuby authored Jan 9, 2025
1 parent a942e95 commit c2103e7
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 36 deletions.
29 changes: 17 additions & 12 deletions onnxscript/rewriter/_ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,29 @@
from __future__ import annotations

import math
from typing import Callable
from typing import Callable, Sequence

import numpy as np

import onnxscript.ir as ir
from onnxscript.optimizer import basic_constant_propagation


def display_nodes(nodes: Sequence[ir.Node]) -> None:
"""Display a list of nodes in the order they appear in the graph."""
if nodes:
graph = nodes[0].graph
if graph:
# Display nodes in same order as in graph:
# Currently doesn't handle (control-flow) subgraphs
for node in graph:
if node in nodes:
node.display()
else:
for node in nodes:
node.display()


def display_slice(x: ir.Value | ir.Node, backward: bool = True, depth_limit: int = 5) -> None:
"""Display the (backward or forward) subgraph from a given value or node upto a certain depth."""
slice = []
Expand All @@ -33,17 +48,7 @@ def visit(node: ir.Node, depth):
visit(x, 0)
elif isinstance(x, ir.Value) and x.producer() is not None:
visit(x.producer(), 0) # type: ignore[arg-type]
if slice:
graph = slice[0].graph
if graph:
# Display nodes in same order as in graph:
# Currently doesn't handle (control-flow) subgraphs
for node in graph:
if node in slice:
node.display()
else:
for node in reversed(slice):
node.display()
display_nodes(slice)


def get_const_value(value: ir.Value) -> ir.TensorProtocol | None:
Expand Down
Loading

0 comments on commit c2103e7

Please sign in to comment.