Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Migration][DO NOT MERGE] Restructure rewriter core under onnxscript #1329

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 296 additions & 0 deletions onnxscript/ir/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""In-memory intermediate representation for ONNX graphs."""
from __future__ import annotations

__all__ = [
# Modules
Expand Down Expand Up @@ -107,3 +108,298 @@
TypeProtocol,
ValueProtocol,
)

import dataclasses

Check notice

Code scanning / lintrunner

PYLINT/C0411 Note

standard import "import dataclasses" should be placed before "from onnxscript.ir import serde" (wrong-import-order)
See wrong-import-order. To disable, use # pylint: disable=wrong-import-order
from collections import deque

Check notice

Code scanning / lintrunner

PYLINT/C0411 Note

standard import "from collections import deque" should be placed before "from onnxscript.ir import serde" (wrong-import-order)
See wrong-import-order. To disable, use # pylint: disable=wrong-import-order
from typing import List, Tuple, Union

Check notice

Code scanning / lintrunner

PYLINT/C0411 Note

standard import "from typing import List, Tuple, Union" should be placed before "from onnxscript.ir import serde" (wrong-import-order)
See wrong-import-order. To disable, use # pylint: disable=wrong-import-order

import numpy as np

Check notice

Code scanning / lintrunner

PYLINT/C0411 Note

third party import "import numpy as np" should be placed before "from onnxscript.ir import serde" (wrong-import-order)
See wrong-import-order. To disable, use # pylint: disable=wrong-import-order
import onnx

Check notice

Code scanning / lintrunner

PYLINT/C0411 Note

third party import "import onnx" should be placed before "from onnxscript.ir import serde" (wrong-import-order)
See wrong-import-order. To disable, use # pylint: disable=wrong-import-order


class Unknown:
"""A special value used to indicate that a value is not a statically known constant.

We use this instead of None because None is a valid constant value (since ONNX
supports the Optional type).
"""

instance = None

def __init__(self) -> None:
if Unknown.instance is not None:
raise ValueError("Unknown.instance is already set")

Check warning on line 131 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L131

Added line #L131 was not covered by tests
Unknown.instance = self


# Singleton instance of Unknown
unknown = Unknown()
NotConstant = unknown

# ConcreteValue: This type represents constant values that an ONNX variable can take.
# TODO: Extend this to a recursive type to handle lists of tensors, etc., support optionals,
# maps, etc.
# TODO (rama): The value is sometimes stored as a numpy array, and sometimes as an ONNX TensorProto.
# A uniform representation would be helpful, but we should avoid unnecessary conversions for
# large tensors. Should be cleaned up in the new IR.
ConcreteValue = Union[onnx.TensorProto, np.ndarray, Unknown, None]

# SymbolicValue: This information is used to enable partial-evaluation and specialization
# of sequence operations, as well as elimination of redundant Identity ops.
# The symbolic value of a variable X can be:
# - a string with the value "Y", indicating that "X" is a copy of "Y"
# - a list of strings, indicating that "X" is a list of tensors, with their symbolic values
# Eg., the symbolic value ["A", "B", "C"] indicates that the value of X is equal to
# "SequenceConstruct(A, B, C)".
# TODO: Technically, SymbolicValue should be a recursive type to handle lists of lists of
# tensors, etc. However, we currently only handle lists of tensors.

SymbolicValue = Union[str, List[str]]

FunctionId = Tuple[str, str, str]


def get_function_id(function: onnx.FunctionProto) -> FunctionId:
return (function.domain, function.name, getattr(function, "overload", ""))

Check warning on line 163 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L163

Added line #L163 was not covered by tests


def get_function_id_from_node(node: onnx.NodeProto) -> FunctionId:
return (node.domain, node.op_type, getattr(node, "overload", ""))

Check warning on line 167 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L167

Added line #L167 was not covered by tests


@dataclasses.dataclass
class StaticValueInfo:
name: str
value: ConcreteValue = NotConstant
type: onnx.TypeProto | None = None
symbolic_value: SymbolicValue | None = None

def is_copy(self) -> bool:
return isinstance(self.symbolic_value, str)

Check warning on line 178 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L178

Added line #L178 was not covered by tests

def tensor_shape_proto(self) -> onnx.TensorShapeProto | None:
"""Returns the shape of a tensor or None.

A return value of None could mean that the type is unknown or that the type is not a tensor
or that the tensor shape (that is, even the rank) is unknown.
"""
type = self.type

Check warning on line 186 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L186

Added line #L186 was not covered by tests
if type and type.HasField("tensor_type") and type.tensor_type.HasField("shape"):
return type.tensor_type.shape
return None

Check warning on line 189 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L188-L189

Added lines #L188 - L189 were not covered by tests

@property
def shape(self) -> list[str | int | None] | None:
"""Returns the shape in a list.

Str means that the shape is dynamic.
"""
type = self.type

Check warning on line 197 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L197

Added line #L197 was not covered by tests
if type and type.HasField("tensor_type") and type.tensor_type.HasField("shape"):
dims = []

Check warning on line 199 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L199

Added line #L199 was not covered by tests
for dim in type.tensor_type.shape.dim:
if dim.HasField("dim_param"):
dims.append(dim.dim_param)

Check warning on line 202 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L202

Added line #L202 was not covered by tests
elif dim.HasField("dim_value"):
dims.append(dim.dim_value)

Check warning on line 204 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L204

Added line #L204 was not covered by tests
else:
dims.append(None)
return dims

Check warning on line 207 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L206-L207

Added lines #L206 - L207 were not covered by tests
if self.value_as_np_array is not None:
return list(self.value_as_np_array.shape)
return None

Check warning on line 210 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L209-L210

Added lines #L209 - L210 were not covered by tests

@property
def element_type(self) -> int | None:
"""Returns the element type of a tensor, or None if type is not known or is not a tensor."""
type = self.type

Check warning on line 215 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L215

Added line #L215 was not covered by tests
if type and type.HasField("tensor_type"):
return type.tensor_type.elem_type
return None

Check warning on line 218 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L217-L218

Added lines #L217 - L218 were not covered by tests

def identity_merge_from(self, other: StaticValueInfo) -> None:
"""Merge the value of other into self.

This models the effect of an identity (copy) operation.
This will update static-analysis information based on incoming value.
"""
if not isinstance(other, StaticValueInfo):
raise TypeError(f"Cannot merge {other} into {self}.")

Check warning on line 227 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L227

Added line #L227 was not covered by tests
if other.value is not NotConstant:
self.value = other.value

Check warning on line 229 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L229

Added line #L229 was not covered by tests
# TODO: merge and combine best shape information from both types.
if other.tensor_shape_proto() is not None and other.element_type is not None:
self.type = other.type

Check warning on line 232 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L232

Added line #L232 was not covered by tests
# We cannot copy symbolic value across different scopes.

# WIP: Extensions towards new IR: Note that the default construction of StaticValueInfo
# does not fill in the following fields. These fields are filled in by the IRBuilder
# which constructs the IR from the ONNX model.
node: Node | None = None
uses: list[Node] = dataclasses.field(default_factory=list)
output_index: int | None = None
is_output: bool = False

@property
def const_value(self) -> ConcreteValue:
return self.value

Check warning on line 245 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L245

Added line #L245 was not covered by tests

@property
def value_as_np_array(self) -> np.ndarray | None:
if isinstance(self.value, np.ndarray):
return self.value

Check warning on line 250 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L250

Added line #L250 was not covered by tests
if isinstance(self.value, onnx.TensorProto):
return onnx.numpy_helper.to_array(self.value)
return None

Check warning on line 253 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L252-L253

Added lines #L252 - L253 were not covered by tests

def def_node(self) -> Node | None:
return self.node

Check warning on line 256 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L256

Added line #L256 was not covered by tests

def def_index(self) -> int:
return self.output_index

Check warning on line 259 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L259

Added line #L259 was not covered by tests

def is_same_as(self, other: StaticValueInfo) -> bool:
"""Returns true if this value represents the same IR object as the other value.

This is *not* value-equality, but rather object-equality.
"""
return self is other

Check warning on line 266 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L266

Added line #L266 was not covered by tests

def __str__(self) -> str:
shape = self.shape

Check warning on line 269 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L269

Added line #L269 was not covered by tests
if shape is not None:
shape = [str(dim) for dim in shape]
shape_str = f"[{', '.join(shape)}]"

Check warning on line 272 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L272

Added line #L272 was not covered by tests
else:
shape_str = "None"
return (

Check warning on line 275 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L274-L275

Added lines #L274 - L275 were not covered by tests
f"StaticValueInfo({self.name}, shape:{shape_str}, dtype:{self.element_type}, "
f"{'has const value' if self.value is not unknown else 'no const value'}.)"
)


Value = StaticValueInfo

Check warning

Code scanning / lintrunner

RUFF/F811 Warning

Redefinition of unused Value from line 88.
See https://docs.astral.sh/ruff/rules/redefined-while-unused


class Model:

Check warning

Code scanning / lintrunner

RUFF/F811 Warning

Redefinition of unused Model from line 79.
See https://docs.astral.sh/ruff/rules/redefined-while-unused

Check failure

Code scanning / lintrunner

PYLINT/E0102 Error

class already defined line 59 (function-redefined)
See function-redefined. To disable, use # pylint: disable=function-redefined
def __init__(self) -> None:
self.gen_var_counter: int = 0

Check warning on line 286 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L286

Added line #L286 was not covered by tests

def set(
self,
model_proto: onnx.ModelProto,
graph: Graph,
functions: list[Function],
version_map: dict[str, int],
) -> None:
"""TODO. This is a temporary patch."""
self.original_model_proto = model_proto
self.graph = graph
self.functions = functions
self.version_map = version_map

Check warning on line 299 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L296-L299

Added lines #L296 - L299 were not covered by tests

def make_new_name(self):
# Temporary hack.
self.gen_var_counter += 1
return f"_gen_{self.gen_var_counter}"

Check warning on line 304 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L303-L304

Added lines #L303 - L304 were not covered by tests

def __str__(self) -> str:
# TODO: Naive string representation for debugging. Need to improve this.
return "\n".join(

Check warning on line 308 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L308

Added line #L308 was not covered by tests
[
f"ModelGraph: {self.graph}",
f"Functions: {self.functions}",
f"VersionMap: {self.version_map}",
]
)


class Graph:

Check warning

Code scanning / lintrunner

RUFF/F811 Warning

Redefinition of unused Graph from line 77.
See https://docs.astral.sh/ruff/rules/redefined-while-unused

Check failure

Code scanning / lintrunner

PYLINT/E0102 Error

class already defined line 59 (function-redefined)
See function-redefined. To disable, use # pylint: disable=function-redefined
def __init__(self, graph_proto: onnx.GraphProto):
self.original_graph_proto = graph_proto
self.nodes: deque[Node] = deque()
self.values: dict[str, Value] = {}

Check warning on line 321 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L319-L321

Added lines #L319 - L321 were not covered by tests

@property
def name(self) -> str:
return self.original_graph_proto.name

Check warning on line 325 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L325

Added line #L325 was not covered by tests

def __str__(self) -> str:
return "\n".join(
[
"Graph",
f"Nodes: {[str(n) for n in self.nodes]}",
f"Values: {[str(v) for v in self.values]}",
]
)


class Function:

Check warning

Code scanning / lintrunner

RUFF/F811 Warning

Redefinition of unused Function from line 76.
See https://docs.astral.sh/ruff/rules/redefined-while-unused

Check failure

Code scanning / lintrunner

PYLINT/E0102 Error

class already defined line 59 (function-redefined)
See function-redefined. To disable, use # pylint: disable=function-redefined
def __init__(self, function_proto: onnx.FunctionProto):
self.original_function_proto = function_proto
self.nodes = deque()
self.values = {}

Check warning on line 341 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L339-L341

Added lines #L339 - L341 were not covered by tests

@property
def id(self) -> FunctionId:
return (self.domain, self.name, self.overload)

Check warning on line 345 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L345

Added line #L345 was not covered by tests

@property
def domain(self) -> str:
return self.original_function_proto.domain

Check warning on line 349 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L349

Added line #L349 was not covered by tests

@property
def name(self) -> str:
return self.original_function_proto.name

Check warning on line 353 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L353

Added line #L353 was not covered by tests

@property
def overload(self) -> str:
return getattr(self.original_function_proto, "overload", "")

Check warning on line 357 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L357

Added line #L357 was not covered by tests

def __str__(self) -> str:
return "\n".join(
[
"Function",
f"Nodes: {[str(n) for n in self.nodes]}",
f"Values: {[str(v) for v in self.values]}",
]
)


class RefAttr:

Check warning

Code scanning / lintrunner

RUFF/F811 Warning

Redefinition of unused RefAttr from line 82.
See https://docs.astral.sh/ruff/rules/redefined-while-unused

Check failure

Code scanning / lintrunner

PYLINT/E0102 Error

class already defined line 59 (function-redefined)
See function-redefined. To disable, use # pylint: disable=function-redefined
def __init__(self, name: str, ref_attr_name: str, type) -> None:
self.name = name
self.ref_attr_name = ref_attr_name
self.type = type

Check warning on line 373 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L371-L373

Added lines #L371 - L373 were not covered by tests

def to_proto(self) -> onnx.AttributeProto:
attr_proto = onnx.AttributeProto()
attr_proto.name = self.name
attr_proto.ref_attr_name = self.ref_attr_name
attr_proto.type = self.type
return attr_proto

Check warning on line 380 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L376-L380

Added lines #L376 - L380 were not covered by tests


class Node:

Check warning

Code scanning / lintrunner

RUFF/F811 Warning

Redefinition of unused Node from line 80.
See https://docs.astral.sh/ruff/rules/redefined-while-unused

Check failure

Code scanning / lintrunner

PYLINT/E0102 Error

class already defined line 59 (function-redefined)
See function-redefined. To disable, use # pylint: disable=function-redefined
def __init__(self, node_proto: onnx.NodeProto) -> None:
self.original_node_proto = node_proto
self.domain: str = node_proto.domain
self.version: int | None = None
self.op_type: str = node_proto.op_type
self.inputs: list[Value | None] = []
self.outputs: list[Value | None] = []
self.attributes: dict[str, int | float | RefAttr | Graph | list[Graph]] = {}

Check warning on line 391 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L385-L391

Added lines #L385 - L391 were not covered by tests

def get_attribute(self, name: str) -> int | float | None:
return self.attributes.get(name, None)

Check warning on line 394 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L394

Added line #L394 was not covered by tests

def __str__(self) -> str:
return "\n".join(

Check warning on line 397 in onnxscript/ir/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/__init__.py#L397

Added line #L397 was not covered by tests
[
"Node",
f"OpType: {self.op_type}",
f"Inputs: {self.inputs}",
f"Outputs: {self.outputs}",
f"Attributes: {self.attributes}",
]
)
File renamed without changes.
File renamed without changes.
Loading
Loading