[Migration][DO NOT MERGE] Support submodule functions in pattern-rewrite
From https://github.com/microsoft/onnx-rewriter/commit/d0e2876f2e6765738a1d9ad60b1d55000ffb50f7

Co-authored-by: Ti-Tai Wang <titaiwang@microsoft.com>

[ghstack-poisoned]
BowenBao committed Apr 4, 2024
1 parent fb3e327 commit c0d5b19
Showing 9 changed files with 254 additions and 19 deletions.
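
In short, this change lets a rewrite-rule pattern match a call to an exporter-generated submodule function by the torch module class it came from, rather than by its exact (uniquely suffixed) function name. A minimal sketch of the new pattern style, based on the torch_module_op.submodule helper used in group_normalization_merge_silu.py below; the wrapper function itself is illustrative:

from onnxscript.rewriter import pattern

op = pattern.onnxop
torch_module_op = pattern.torch_module_op

def transpose_then_silu(x):
    # Match a Transpose followed by a call to any function exported from a
    # torch.nn.modules.activation.SiLU submodule, regardless of the unique
    # suffix the exporter appends (e.g. "..._time_embedding_act_19").
    transposed = op.Transpose(x, perm=[0, 3, 1, 2])
    return torch_module_op.submodule("torch_nn_modules_activation_SiLU")(transposed)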
4 changes: 4 additions & 0 deletions onnxscript/_legacy_ir/__init__.py
@@ -319,6 +319,10 @@ def output_names(self):
    def attribute(self):
        return self.original_node_proto.attribute

+    def set_version_if_custom_op(self, version_map: dict[str, int]) -> None:
+        if self.domain != "" and self.domain in version_map:
+            self.version = version_map[self.domain]
+
    def get_attribute(self, name: str) -> int | float | None:
        return self.attributes.get(name, None)

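For context, a sketch of how this helper is driven (mirroring the irbuilder.py change below); model_proto and node_ir stand in for the builder's local state:

# Build the domain -> version map once from the model's opset imports...
version_map = {x.domain: x.version for x in model_proto.opset_import}
# ...then stamp each node that lives in a custom domain (e.g. "com.microsoft")
# with the declared opset version; default-domain ("") nodes are left untouched.
node_ir.set_version_if_custom_op(version_map)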
5 changes: 3 additions & 2 deletions onnxscript/_legacy_ir/irbuilder.py
@@ -35,12 +35,12 @@ def visit_model(self, model_proto: onnx.ModelProto) -> ir.Model:
        self._function_shape_env = visitor.FunctionShapeEnv()
        self._function_shape_env.load_from_model_proto(model_proto)
        self._ir_version = model_proto.ir_version
-        version_map = {x.domain: x.version for x in model_proto.opset_import}
+        self.version_map = {x.domain: x.version for x in model_proto.opset_import}
        functions = [self.visit_function(function) for function in model_proto.functions]
        self.functions = {function.id: function for function in functions}
        graph = self.visit_graph(model_proto.graph)
        model = ir.Model()
-        model.set(model_proto, graph, functions, version_map)
+        model.set(model_proto, graph, functions, self.version_map)
        return model

    def visit_graph(self, graph: onnx.GraphProto) -> ir.Graph:
@@ -122,6 +122,7 @@ def process_initializer(self, init: onnx.TensorProto):

    def process_node(self, node):
        node_ir = ir.Node(node)
+        node_ir.set_version_if_custom_op(self.version_map)
        self.current_graph_or_function.nodes.append(node_ir)
        for name in node.input:
            value = self.lookup(name)
1 change: 1 addition & 0 deletions onnxscript/_legacy_ir/protobuilder.py
@@ -70,6 +70,7 @@ def visit_ir_function(
        # function_proto.metadata_props = ir_function.original_function_proto.metadata_props)

        for node in ir_function.nodes:
+            # TODO: deduplicate the opset imports of the function?
            operator_setid_proto = function_proto.opset_import.add()
            if node.domain in self.opset_imports:
                operator_setid_proto.domain = self.opset_imports[node.domain].domain
2 changes: 1 addition & 1 deletion onnxscript/rewriter/broadcast_to_matmul_test.py
@@ -36,7 +36,7 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable_in_nest
            <ir_version: 7, opset_import: [ "" : 17, "pkg.custom": 1]>
            agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output)
            {
-                output = afunction (input_x, input_y)
+                output = pkg.custom.afunction (input_x, input_y)
            }
            <domain: "pkg.custom", opset_import: [ "" : 17]>
            afunction (input_x, input_y) => (output)
8 changes: 7 additions & 1 deletion onnxscript/rewriter/onnxruntime/__init__.py
@@ -3,9 +3,10 @@
import onnx

from onnxscript._legacy_ir import irbuilder, protobuilder
-from onnxscript.optimizer import remove_unused
+from onnxscript.optimizer import remove_unused, remove_unused_function
from onnxscript.rewriter import function_rule, pattern
from onnxscript.rewriter.onnxruntime import (
+    group_normalization_merge_silu,
    instance_to_group_normalization,
    softmax,
    transformers,
@@ -16,6 +17,8 @@
ORT_PATTERN_REWRITE_RULES = [
    *softmax.rules.rules,
    *instance_to_group_normalization.rules.rules,
+    # NOTE: the GroupNorm+SiLU merge must be applied after the
+    # instance-norm -> group-norm rewrite.
+    *group_normalization_merge_silu.rules.rules,
]


@@ -49,5 +52,8 @@ def rewrite(
    count = pattern.RewriteRuleSet(pattern_rules).apply_to_model(model_ir)
    print(f"Applied {count} pattern rewrite rules.")
    model = protobuilder.build_model_proto(model_ir)
+    # TODO: Would it make more sense to run DCE after each rewrite rule is applied?
+    # If so, the IR needs to support DCE.
    remove_unused.remove_unused_nodes(model)
+    remove_unused_function.remove_unused_functions(model)
    return model
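
Nothing changes for callers at the entry point; a sketch of invoking the extended rewrite, assuming the default rule sets and a hypothetical model path:

import onnx
from onnxscript.rewriter import onnxruntime as ort_rewriter

model_proto = onnx.load("model.onnx")  # hypothetical input path
# Runs function rewrite rules, then pattern rules (now including the
# GroupNorm+SiLU merge), then drops unused nodes and unused functions.
rewritten = ort_rewriter.rewrite(model_proto)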
58 changes: 58 additions & 0 deletions onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py
@@ -0,0 +1,58 @@
from __future__ import annotations

import logging

from onnxscript.rewriter import pattern

op = pattern.onnxop
msft_op = pattern.msft_op
torch_module_op = pattern.torch_module_op

logger = logging.getLogger(__name__)


def group_normalization_and_silu_submodule(
    input,
    weight,
    bias,
    epsilon,
    groups,
):
    # Target pattern: GroupNorm without a fused activation, followed by a
    # Transpose and a call to an exported torch.nn SiLU submodule function.
    group_norm = msft_op.GroupNorm(
        input,
        weight,
        bias,
        activation=0,
        channels_last=1,
        epsilon=epsilon,
        groups=groups,
    )
    transposed = op.Transpose(group_norm, perm=[0, 3, 1, 2])
    return torch_module_op.submodule("torch_nn_modules_activation_SiLU")(transposed)


def group_normalization_with_silu(
    input,
    weight,
    bias,
    epsilon,
    groups,
):
    # Replacement: GroupNorm with SiLU fused in (activation=1), keeping the Transpose.
    group_norm = msft_op.GroupNorm(
        input,
        weight,
        bias,
        activation=1,
        channels_last=1,
        epsilon=epsilon,
        groups=groups,
    )
    return op.Transpose(group_norm, perm=[0, 3, 1, 2])


group_normalization_merge_silu_submodule_rule = pattern.RewriteRule(
    group_normalization_and_silu_submodule,
    group_normalization_with_silu,
)

rules = pattern.RewriteRuleSet([group_normalization_merge_silu_submodule_rule])
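
As the tests below exercise it, applying the new rule set is a two-step sketch: build the legacy IR from a model proto, then run the rules over it.

from onnxscript._legacy_ir import irbuilder
from onnxscript.rewriter.onnxruntime import group_normalization_merge_silu

ir = irbuilder.build_ir(model_proto)  # model_proto: an onnx.ModelProto containing the pattern
count = group_normalization_merge_silu.rules.apply_to_model(ir)  # number of sites rewritten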
125 changes: 125 additions & 0 deletions onnxscript/rewriter/onnxruntime/group_normalization_merge_silu_test.py
@@ -0,0 +1,125 @@
import unittest

import numpy as np
import onnx.parser

from onnxscript._legacy_ir import irbuilder
from onnxscript.rewriter.onnxruntime import (
    group_normalization_merge_silu,
    instance_to_group_normalization,
)


class ReplaceInstanceNormWithGroupNormTest(unittest.TestCase):
    def test_group_norm_with_silu_submodule_is_replaced_by_group_norm(self):
        model = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: ["" : 17, "pkg.torch230a0git77ef9d4" : 1, "com.microsoft" : 1]>
            agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output)
            {
                group_norm = com.microsoft.GroupNorm <activation=0, channels_last=1, epsilon=0.000001, groups=32>(image, weight, bias)
                transposed = Transpose <perm=[0, 3, 1, 2]>(group_norm)
                output = pkg.torch230a0git77ef9d4.torch_nn_modules_activation_SiLU_time_embedding_act_19 (transposed)
            }
            <domain: "pkg.torch230a0git77ef9d4", opset_import: ["" : 17]>
            torch_nn_modules_activation_SiLU_time_embedding_act_19 (transposed) => (output)
            {
                _to_copy_38 = Cast <to: int = 1> (transposed)
                sigmoid_18 = Sigmoid (_to_copy_38)
                mul_26 = Mul (_to_copy_38, sigmoid_18)
                output = Cast <to: int = 10> (mul_26)
            }
            """
        )
        # Use inserted initializers to avoid manually coding the large constants
        weight_value = np.random.rand(320, 1, 1).astype(np.float16)
        bias_value = np.random.rand(320, 1, 1).astype(np.float16)
        model.graph.initializer.extend(
            [
                onnx.helper.make_tensor(
                    "weight",
                    onnx.TensorProto.FLOAT16,
                    weight_value.shape,
                    weight_value,
                ),
                onnx.helper.make_tensor(
                    "bias",
                    onnx.TensorProto.FLOAT16,
                    bias_value.shape,
                    bias_value,
                ),
            ]
        )

        ir = irbuilder.build_ir(model)
        count = group_normalization_merge_silu.rules.apply_to_model(ir)
        self.assertEqual(count, 1)
        # plus 2 in model constants
        self.assertEqual(len(ir.graph.nodes), 2)

    def test_simulated_instance_norm_is_replaced_by_group_norm_silu(self):
        model = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 17, "pkg.torch230a0git77ef9d4" : 1]>
            agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output)
            {
                adjusted_input_shape = Constant<value: tensor = int64[3] {0, 32, -1}>()
                image_reshape = Reshape (image, adjusted_input_shape)
                instance_norm = InstanceNormalization <epsilon=0.000001>(image_reshape, weight_for_norm, bias_for_norm)
                original_input_shape = Constant<value: tensor = int64[4] {1, 320, 128, 128}>()
                instance_norm_reshape = Reshape (instance_norm, original_input_shape)
                mul_output = Mul (instance_norm_reshape, weight_full)
                add_output = Add (mul_output, bias_full)
                output = pkg.torch230a0git77ef9d4.torch_nn_modules_activation_SiLU_time_embedding_act_19 (add_output)
            }
            <domain: "pkg.torch230a0git77ef9d4", opset_import: ["" : 17]>
            torch_nn_modules_activation_SiLU_time_embedding_act_19 (add_output) => (output)
            {
                _to_copy_38 = Cast <to: int = 1> (add_output)
                sigmoid_18 = Sigmoid (_to_copy_38)
                mul_26 = Mul (_to_copy_38, sigmoid_18)
                output = Cast <to: int = 10> (mul_26)
            }
            """
        )
        # Use inserted initializers to avoid manually coding the large constants
        weight_full_value = np.random.rand(320, 1, 1).astype(np.float16)
        bias_full_value = np.random.rand(320, 1, 1).astype(np.float16)
        weight_for_norm_value = np.ones(32, dtype=np.float16)
        bias_for_norm_value = np.zeros(32, dtype=np.float16)

        model.graph.initializer.extend(
            [
                onnx.helper.make_tensor(
                    "weight_for_norm",
                    onnx.TensorProto.FLOAT16,
                    weight_for_norm_value.shape,
                    weight_for_norm_value,
                ),
                onnx.helper.make_tensor(
                    "bias_for_norm",
                    onnx.TensorProto.FLOAT16,
                    bias_for_norm_value.shape,
                    bias_for_norm_value,
                ),
                onnx.helper.make_tensor(
                    "weight_full",
                    onnx.TensorProto.FLOAT16,
                    weight_full_value.shape,
                    weight_full_value,
                ),
                onnx.helper.make_tensor(
                    "bias_full",
                    onnx.TensorProto.FLOAT16,
                    bias_full_value.shape,
                    bias_full_value,
                ),
            ]
        )

        ir = irbuilder.build_ir(model)
        count = instance_to_group_normalization.rules.apply_to_model(ir)
        count += group_normalization_merge_silu.rules.apply_to_model(ir)
        self.assertEqual(count, 2)
        # plus 2 in model constants
        self.assertEqual(len(ir.graph.nodes), 10)
onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py
@@ -11,6 +11,7 @@

op = pattern.onnxop
msft_op = pattern.msft_op
+torch_module_op = pattern.torch_module_op

logger = logging.getLogger(__name__)

@@ -146,4 +147,6 @@ def group_normalization(
    check_if_simulated_instance_norm_is_used,
)

+# NOTE: instance_norm_to_group_norm_rule is a subset of instance_norm_to_group_norm_with_silu_rule,
+# so instance_norm_to_group_norm_with_silu_rule needs to run first.
rules = pattern.RewriteRuleSet([instance_norm_to_group_norm_rule])