Skip to content

Commit

Permalink
Update on "[Migration][DO NOT MERGE] Fix linting"
Browse files Browse the repository at this point in the history
Cutoff from 991f82bcc99706cd8ea6b7f2b70ccec75db2dbd6.
Changes after the above commit but before the cutoff date will be migrated in follow-up within this stack.


[ghstack-poisoned]
  • Loading branch information
BowenBao committed Apr 3, 2024
2 parents 3032521 + 5adffb3 commit 695a6e2
Showing 1 changed file with 6 additions and 18 deletions.
24 changes: 6 additions & 18 deletions onnxscript/optimizer/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ def is_non_deterministic_op(node: onnx.NodeProto) -> bool:


def is_constant_op(node: onnx.NodeProto) -> bool:
    """Return True if *node* is a ``Constant`` or ``ConstantOfShape`` op in the default ONNX domain.

    Args:
        node: The ONNX node to inspect.

    Returns:
        True when the node's op_type is one of the constant-producing ops and
        its domain is the standard ONNX domain (per ``is_onnx_domain``).
    """
    return node.op_type in {"Constant", "ConstantOfShape"} and is_onnx_domain(node.domain)


class ConstantFolder(visitor.FunctionCallsiteProtoTransformer):
Expand Down Expand Up @@ -119,14 +117,10 @@ def new_constant(self, name, value):
info.type = onnx.helper.make_tensor_type_proto(
onnx.helper.np_dtype_to_tensor_dtype(value.dtype), value.shape
)
node = onnx.helper.make_node(
"Constant", inputs=[], outputs=[name], value=tensor
)
node = onnx.helper.make_node("Constant", inputs=[], outputs=[name], value=tensor)
return [node]

def convert_attributes(
self, attributes: Sequence[onnx.AttributeProto]
) -> dict[str, Any]:
def convert_attributes(self, attributes: Sequence[onnx.AttributeProto]) -> dict[str, Any]:
if self.scopes.current_scope().current_function_scope():
# Need to resolve ref_attr_name if inside a function.
attr_dict = {}
Expand All @@ -138,9 +132,7 @@ def convert_attributes(
)
if concrete_attribute is None:
continue
attr_dict[attribute.name] = onnx.helper.get_attribute_value(
concrete_attribute
)
attr_dict[attribute.name] = onnx.helper.get_attribute_value(concrete_attribute)
return attr_dict
return {attr.name: onnx.helper.get_attribute_value(attr) for attr in attributes}

Expand Down Expand Up @@ -226,9 +218,7 @@ def process_node(self, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None:
self.add_count(op, outputs.size)
return replacement
else:
logger.warning(
"Skipping constant folding for op %s with multiple outputs.", op
)
logger.warning("Skipping constant folding for op %s with multiple outputs.", op)

Check warning on line 221 in onnxscript/optimizer/constant_folding.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/optimizer/constant_folding.py#L221

Added line #L221 was not covered by tests
return None

def process_function_node(
Expand All @@ -241,9 +231,7 @@ def process_function_node(
# Replace function node with Constant if all outputs are constants
ir_values = [self.lookup(output_name) for output_name in node.output]
tensors = [
self.foldable_value(
output_name, ir_value.value if ir_value is not None else None
)
self.foldable_value(output_name, ir_value.value if ir_value is not None else None)
for output_name, ir_value in zip(node.output, ir_values)
]
if all(tensor is not None for tensor in tensors):
Expand Down

0 comments on commit 695a6e2

Please sign in to comment.