Skip to content

Commit

Permalink
Refactor builder out as an utility (#1772)
Browse files Browse the repository at this point in the history
Move the IR builder utility out as a separate utility.

TODO:
* Should this be moved into the `ir` folder?
* Should `Builder` be merged with the `Tape` class?
* Eventually, merge this with trace-mode onnxscript and expose it to end
users.
  • Loading branch information
gramalingam authored Aug 7, 2024
1 parent 3b95a44 commit 2dd69db
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 60 deletions.
49 changes: 46 additions & 3 deletions onnxscript/rewriter/_tape.py → onnxscript/ir/_tape.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from __future__ import annotations

from typing import Iterable, Mapping, Sequence
from typing import Any, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple

from onnxscript import ir
from onnxscript.ir import _convenience
Expand All @@ -19,8 +19,8 @@ class Tape(Iterable[ir.Node]):
def __init__(self) -> None:
self._nodes: list[ir.Node] = []

def __iter__(self) -> Sequence[ir.Node]:
return self._nodes
def __iter__(self) -> Iterator[ir.Node]:
return iter(self._nodes)

@property
def nodes(self) -> Sequence[ir.Node]:
Expand Down Expand Up @@ -59,3 +59,46 @@ def op_multi_output(
self._nodes.append(node)

return node.outputs


# A type representing the domains/versions used in creating nodes in IR.
UsedOpsets = List[Tuple[str, Optional[int]]]


class Builder(Tape):
"""An extension of the tape that provides a more convenient API for constructing the IR."""

def __init__(self):
super().__init__()
self._used_opsets: UsedOpsets = []

def __getattr__(self, op_type: str) -> Any:
return lambda *args, **kwargs: self._make_node(op_type, args, kwargs)

def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]):
domain = kwargs.pop("_domain", "")
version = kwargs.pop("_version", None)
outputs = kwargs.pop("_outputs", 1)
if isinstance(outputs, Sequence):
num_outputs = len(outputs)
else:
assert isinstance(outputs, int)
num_outputs = outputs

self._used_opsets.append((domain, version))
if num_outputs == 1:
value = super().op(op_type, inputs=inputs, attributes=kwargs, domain=domain)
if isinstance(outputs, Sequence):
value.name = outputs[0]
return value
values = super().op_multi_output(
op_type, inputs=inputs, attributes=kwargs, domain=domain, num_outputs=num_outputs
)
if isinstance(outputs, Sequence):
for value, name in zip(values, outputs):
value.name = name
return values

@property
def used_opsets(self) -> UsedOpsets:
return self._used_opsets
61 changes: 4 additions & 57 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
Callable,
Iterable,
Iterator,
List,
MutableSequence,
Optional,
Protocol,
Sequence,
Tuple,
Expand All @@ -24,8 +22,8 @@
)

from onnxscript import ir
from onnxscript.ir import _convenience
from onnxscript.rewriter import _ir_utils, _tape
from onnxscript.ir import _convenience, _tape
from onnxscript.rewriter import _ir_utils

T = TypeVar("T")

Expand Down Expand Up @@ -818,58 +816,7 @@ def _valid_to_replace(
return True


# A type representing the domains/versions used in creating a replacement subgraph
UsedOpsets = List[Tuple[str, Optional[int]]]


class RewriterContext:
"""Context parameter used to build the replacement pattern."""

# TODO(justinchuby): Merge with the rest of pattern building methods
def __init__(self):
self._tape = _tape.Tape()
self._used_opsets: UsedOpsets = []

def __getattr__(self, op_type: str) -> Any:
return lambda *args, **kwargs: self._make_node(op_type, args, kwargs)

def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]):
# TODO(rama): some of the following logic should move into the tape.
domain = kwargs.pop("_domain", "")
version = kwargs.pop("_version", None)
outputs = kwargs.pop("_outputs", 1)
if isinstance(outputs, Sequence):
num_outputs = len(outputs)
else:
assert isinstance(outputs, int)
num_outputs = outputs

self._used_opsets.append((domain, version))
if num_outputs == 1:
value = self._tape.op(op_type, inputs=inputs, attributes=kwargs, domain=domain)
if isinstance(outputs, Sequence):
value.name = outputs[0]
return value
values = self._tape.op_multi_output(
op_type, inputs=inputs, attributes=kwargs, domain=domain, num_outputs=num_outputs
)
if isinstance(outputs, Sequence):
for value, name in zip(values, outputs):
value.name = name
return values

@property
def nodes(self) -> Sequence[ir.Node]:
# TODO(rama): The current tape-based implementation will not track nodes added
# via overloaded operators, eg., `x + y`. One possible way to fix this is to
# have values/nodes know which tape they belong to (instead of a graph/function).
# However, it is unclear we need this feature for rewriting: we could also
# identify the nodes to be inserted from the replacement values (by tracing back).
return self._tape.nodes

@property
def used_opsets(self) -> UsedOpsets:
return self._used_opsets
RewriterContext = _tape.Builder


@dataclasses.dataclass
Expand All @@ -879,7 +826,7 @@ class ReplacementSubgraph:
match: MatchResult
new_outputs: Sequence[ir.Value]
new_nodes: Sequence[ir.Node]
used_opsets: UsedOpsets
used_opsets: _tape.UsedOpsets


def always_true(*args, **kwargs) -> bool:
Expand Down

0 comments on commit 2dd69db

Please sign in to comment.