Skip to content

Commit

Permalink
Update base for Migrate pattern rewrite commit on "Skip full model sh…
Browse files Browse the repository at this point in the history
…ape inference if model > 2GB | feat(optimizer)"

[ghstack-poisoned]
  • Loading branch information
BowenBao committed Apr 3, 2024
1 parent a4a4632 commit 00b07e6
Show file tree
Hide file tree
Showing 7 changed files with 1,884 additions and 45 deletions.
1 change: 1 addition & 0 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ exclude_patterns = [
'onnxscript/_legacy_ir/protobuilder.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/layernorm.py', # FIXME
'onnxscript/ir/serde.py', # FIXME
'onnxrewriter/rewriter/pattern/generic_pattern_test.py', # FIXME
]
command = [
'python',
Expand Down
44 changes: 41 additions & 3 deletions onnxscript/_legacy_ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,14 @@ def __str__(self) -> str:
]
)

@property
def input_names(self) -> list[str]:
return [_.name for _ in self.original_graph_proto.input]

@property
def output_names(self) -> list[str]:
return [_.name for _ in self.original_graph_proto.output]


class Function:
def __init__(self, function_proto: onnx.FunctionProto):
Expand Down Expand Up @@ -272,15 +280,45 @@ def to_proto(self) -> onnx.AttributeProto:


class Node:
def __init__(self, node_proto: onnx.NodeProto) -> None:
def __init__(
self,
node_proto: onnx.NodeProto,
populate_io: bool = False,
) -> None:
self.original_node_proto = node_proto
self.domain: str = node_proto.domain
self.version: int | None = None
self.op_type: str = node_proto.op_type
self.inputs: list[Value | None] = []
self.outputs: list[Value | None] = []
if populate_io:
self.inputs: list[Value | None] = [Value(i) for i in node_proto.input]
self.outputs: list[Value | None] = [Value(i) for i in node_proto.output]
else:
self.inputs: list[Value | None] = []
self.outputs: list[Value | None] = []
self.attributes: dict[str, int | float | RefAttr | Graph | list[Graph]] = {}

def __repr__(self) -> str:
return (
f"{self.op_type}({','.join(self.original_node_proto.input)})"
f"->{','.join(self.original_node_proto.output)}"
)

@property
def name(self) -> str:
return self.original_node_proto.name

@property
def input_names(self):
return self.original_node_proto.input

@property
def output_names(self):
return self.original_node_proto.output

@property
def attribute(self):
return self.original_node_proto.attribute

def get_attribute(self, name: str) -> int | float | None:
return self.attributes.get(name, None)

Expand Down
6 changes: 3 additions & 3 deletions onnxscript/_legacy_ir/irbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,14 @@ def process_initializer(self, init: onnx.TensorProto):
def process_node(self, node):
node_ir = ir.Node(node)
self.current_graph_or_function.nodes.append(node_ir)
for input in node.input:
value = self.lookup(input)
for name in node.input:
value = self.lookup(name)
node_ir.inputs.append(value)
if value is not None:
value.uses.append(node_ir)
else:
# TODO(titaiwang): Do something more than warnings?
warnings.warn(f"Use of undefined variable '{input}'.", stacklevel=1)
warnings.warn(f"Use of undefined variable {name!r}.", stacklevel=1)
for index, output in enumerate(node.output):
newvalue = ir.Value(name=output, node=node_ir, output_index=index)
if self._current_function is not None:
Expand Down
Loading

0 comments on commit 00b07e6

Please sign in to comment.