Skip to content

Commit

Permalink
Update base for Migrate pattern rewrite commit on "[Test] Hack to ena…
Browse files Browse the repository at this point in the history
…ble optimizer/rewriter integration into dynamo_export"

- To test in torchbench.
- Somehow lintrunner changed unrelated files in this commit.

[ghstack-poisoned]
  • Loading branch information
BowenBao committed Apr 3, 2024
1 parent c48fdca commit 09ca34f
Show file tree
Hide file tree
Showing 7 changed files with 1,884 additions and 45 deletions.
1 change: 1 addition & 0 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ exclude_patterns = [
'onnxscript/_legacy_ir/protobuilder.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/layernorm.py', # FIXME
'onnxscript/ir/serde.py', # FIXME
'onnxrewriter/rewriter/pattern/generic_pattern_test.py', # FIXME
]
command = [
'python',
Expand Down
44 changes: 41 additions & 3 deletions onnxscript/_legacy_ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,14 @@ def __str__(self) -> str:
]
)

@property
def input_names(self) -> list[str]:
return [_.name for _ in self.original_graph_proto.input]

@property
def output_names(self) -> list[str]:
return [_.name for _ in self.original_graph_proto.output]


class Function:
def __init__(self, function_proto: onnx.FunctionProto):
Expand Down Expand Up @@ -272,15 +280,45 @@ def to_proto(self) -> onnx.AttributeProto:


class Node:
def __init__(self, node_proto: onnx.NodeProto) -> None:
def __init__(
self,
node_proto: onnx.NodeProto,
populate_io: bool = False,
) -> None:
self.original_node_proto = node_proto
self.domain: str = node_proto.domain
self.version: int | None = None
self.op_type: str = node_proto.op_type
self.inputs: list[Value | None] = []
self.outputs: list[Value | None] = []
if populate_io:
self.inputs: list[Value | None] = [Value(i) for i in node_proto.input]
self.outputs: list[Value | None] = [Value(i) for i in node_proto.output]
else:
self.inputs: list[Value | None] = []
self.outputs: list[Value | None] = []
self.attributes: dict[str, int | float | RefAttr | Graph | list[Graph]] = {}

def __repr__(self) -> str:
return (
f"{self.op_type}({','.join(self.original_node_proto.input)})"
f"->{','.join(self.original_node_proto.output)}"
)

@property
def name(self) -> str:
return self.original_node_proto.name

@property
def input_names(self):
return self.original_node_proto.input

@property
def output_names(self):
return self.original_node_proto.output

@property
def attribute(self):
return self.original_node_proto.attribute

def get_attribute(self, name: str) -> int | float | None:
return self.attributes.get(name, None)

Expand Down
6 changes: 3 additions & 3 deletions onnxscript/_legacy_ir/irbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,14 @@ def process_initializer(self, init: onnx.TensorProto):
def process_node(self, node):
node_ir = ir.Node(node)
self.current_graph_or_function.nodes.append(node_ir)
for input in node.input:
value = self.lookup(input)
for name in node.input:
value = self.lookup(name)
node_ir.inputs.append(value)
if value is not None:
value.uses.append(node_ir)
else:
# TODO(titaiwang): Do something more than warnings?
warnings.warn(f"Use of undefined variable '{input}'.", stacklevel=1)
warnings.warn(f"Use of undefined variable {name!r}.", stacklevel=1)
for index, output in enumerate(node.output):
newvalue = ir.Value(name=output, node=node_ir, output_index=index)
if self._current_function is not None:
Expand Down
Loading

0 comments on commit 09ca34f

Please sign in to comment.