Skip to content

Commit

Permalink
Update lint rules; Migrate groupnorm commit on "Skip full model shape…
Browse files Browse the repository at this point in the history
… inference if model > 2GB | feat(optimizer)"

[ghstack-poisoned]
  • Loading branch information
BowenBao committed Apr 4, 2024
2 parents 683d968 + e132e46 commit fc01bfa
Show file tree
Hide file tree
Showing 11 changed files with 450 additions and 19 deletions.
191 changes: 191 additions & 0 deletions examples/pattern_rewriting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""Onnx Pattern Rewriting.
This script shows how to define a rewriting rule based on patterns.
The objective is to replace some nodes in an onnx model into another
sequence of nodes but more efficient.
First a dummy model
===================
"""

import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh

import onnxscript
import onnxscript._legacy_ir as oir
import onnxscript.rewriter.generic_pattern as org


def get_rotary_model(bad_model=False):
inputs = [
oh.make_tensor_value_info("x", onnx.TensorProto.INT64, shape=[]),
oh.make_tensor_value_info("pos_ids", onnx.TensorProto.FLOAT, shape=[]),
oh.make_tensor_value_info("axis", onnx.TensorProto.INT64, shape=[]),
]
nodes = [
oh.make_node("Unsqueeze", ["x", "axis"], ["_onx_unsqueeze0"]),
oh.make_node("Cast", ["_onx_unsqueeze0"], ["_onx_cast0"], to=1),
oh.make_node("MatMul", ["pos_ids", "_onx_cast0"], ["_onx_matmul0"]),
oh.make_node("Transpose", ["_onx_matmul0"], ["_onx_transpose0"]),
oh.make_node(
"ConcatTrainingBad" if bad_model else "ConcatTraining",
["_onx_transpose0", "_onx_transpose0"],
["_onx_concattraining0", "_onx_concattraining1"],
domain="com.microsoft",
),
oh.make_node("Sin", ["_onx_concattraining0"], ["_onx_sin0"]),
oh.make_node("Cast", ["_onx_sin0"], ["_onx_cast02"], to=1),
oh.make_node("Cos", ["_onx_concattraining0"], ["_onx_cos0"]),
oh.make_node("Cast", ["_onx_cos0"], ["_onx_cast03"], to=1),
]
outputs = [
oh.make_tensor_value_info("_onx_cast02", onnx.TensorProto.UNDEFINED, []),
oh.make_tensor_value_info("_onx_cast03", onnx.TensorProto.UNDEFINED, []),
]
model = oh.make_model(
oh.make_graph(
nodes,
"experiment",
inputs,
outputs,
),
opset_imports=[
oh.make_opsetid("", 18),
oh.make_opsetid("com.microsoft", 18),
],
)
return model


model = get_rotary_model()
ir_model = oir.irbuilder.build_ir(model)


####################################
# The rewriting pattern
# =====================

op = onnxscript.opset18
msft_op = onnxscript.values.Opset("com.microsoft", 1)


def rotary_match_pattern(x, pos_ids, axis):
"""The pattern to match."""
unsqueeze = op.Unsqueeze(x, axis)
cast = op.Cast(unsqueeze, to=onnx.TensorProto.FLOAT)

matmul = op.MatMul(pos_ids, cast)
transpose = op.Transpose(matmul)
output, length = msft_op.ConcatTraining(transpose, transpose)

sin = op.Sin(output)
cast1 = op.Cast(sin, to=onnx.TensorProto.FLOAT)
cos = op.Cos(output)
cast2 = op.Cast(cos, to=onnx.TensorProto.FLOAT)
return cast1, cast2


def validate_rotary_mapping(g, matched_nodes, added_nodes) -> bool:
"""The validation post matching.
Returns True to validate the replacement,
False not to apply it.
:param g: model
:param matched_nodes: matched nodes
:param added_nodes: nodes replacing the matched nodes
"""
del g
del matched_nodes
del added_nodes
return True


def rotary_apply_pattern(x, pos_ids, axis):
"""The replacement pattern."""
cos_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))
sin_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))
part1, part2 = msft_op.RotaryEmbedding(x, pos_ids, cos_cache, sin_cache)
return part1, part2


###########################
# The rule
# ========
#
# The rule is easy to create.


rule = org.make_pattern_rule(
rotary_match_pattern,
rotary_apply_pattern,
validate_rotary_mapping,
)

################################
# ``validate_rotary_mapping`` always return True.
# This argument can be ignored in that case.

rule = org.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern)

##########################
# Let's apply it.
rule.apply_to_model(ir_model)


########################
# And finally, we can generate the model.

opt_onx = oir.protobuilder.build_model_proto(ir_model)

########################
# Let's see what it looks like.

for node in opt_onx.graph.node:
print(f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}")

#############################
# What if it fails?
# =================


model = get_rotary_model(True)
ir_model = oir.irbuilder.build_ir(model)

rule.apply_to_model(ir_model)
opt_onx = oir.protobuilder.build_model_proto(ir_model)

print([n.op_type for n in opt_onx.graph.node])

################################
# The match did not happen.
# Let's increase the verbosity.

rule = org.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern, verbose=10)

rule.apply_to_model(ir_model)

######################################
# The logs shows every time the algorithm rejected a pattern.
# We can see the following:
#
# ::
#
# [OnnxGenericPattern.match] NONE - line: 673:onnxscript.rewriter.generic_pattern, op_type=Cast
# --hint--: BACKWARD: different node types
# --pattern
# ConcatTraining(transpose, transpose) -> (output, length)
# -- model
# ConcatTrainingBad(_onx_transpose0, _onx_transpose0) -> (_onx_concattraining0, _onx_concattraining1)
# iteration=1
# --marked-- #2
# Cast(_onx_cos0) ~ Cast(cos) [140186194226496-140186194222320]
# Cos(_onx_concattraining0) ~ Cos(output) [140186194230816-140186194223472]
# len(stacked)=0:[]
#
# Line 673 in file `generic_pattern.py`, the match was rejected.
# It says while comparing two nodes in the backward direction,
# node types do not match.
# It also says that two nodes were actually matched.
4 changes: 4 additions & 0 deletions onnxscript/_legacy_ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,10 @@ def output_names(self):
def attribute(self):
return self.original_node_proto.attribute

def set_version_if_custom_op(self, version_map: dict[str, int]) -> None:
if self.domain != "" and self.domain in version_map:
self.version = version_map[self.domain]

def get_attribute(self, name: str) -> int | float | None:
return self.attributes.get(name, None)

Expand Down
5 changes: 3 additions & 2 deletions onnxscript/_legacy_ir/irbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ def visit_model(self, model_proto: onnx.ModelProto) -> ir.Model:
self._function_shape_env = visitor.FunctionShapeEnv()
self._function_shape_env.load_from_model_proto(model_proto)
self._ir_version = model_proto.ir_version
version_map = {x.domain: x.version for x in model_proto.opset_import}
self.version_map = {x.domain: x.version for x in model_proto.opset_import}
functions = [self.visit_function(function) for function in model_proto.functions]
self.functions = {function.id: function for function in functions}
graph = self.visit_graph(model_proto.graph)
model = ir.Model()
model.set(model_proto, graph, functions, version_map)
model.set(model_proto, graph, functions, self.version_map)
return model

def visit_graph(self, graph: onnx.GraphProto) -> ir.Graph:
Expand Down Expand Up @@ -122,6 +122,7 @@ def process_initializer(self, init: onnx.TensorProto):

def process_node(self, node):
node_ir = ir.Node(node)
node_ir.set_version_if_custom_op(self.version_map)
self.current_graph_or_function.nodes.append(node_ir)
for name in node.input:
value = self.lookup(name)
Expand Down
1 change: 1 addition & 0 deletions onnxscript/_legacy_ir/protobuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def visit_ir_function(
# function_proto.metadata_props = ir_function.original_function_proto.metadata_props)

for node in ir_function.nodes:
# TODO: deduplicate the opset import of function?
operator_setid_proto = function_proto.opset_import.add()
if node.domain in self.opset_imports:
operator_setid_proto.domain = self.opset_imports[node.domain].domain
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/rewriter/broadcast_to_matmul_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable_in_nest
<ir_version: 7, opset_import: [ "" : 17, "pkg.custom": 1]>
agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output)
{
output = afunction (input_x, input_y)
output = pkg.custom.afunction (input_x, input_y)
}
<domain: "pkg.custom", opset_import: [ "" : 17]>
afunction (input_x, input_y) => (output)
Expand Down
8 changes: 7 additions & 1 deletion onnxscript/rewriter/onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import onnx

from onnxscript._legacy_ir import irbuilder, protobuilder
from onnxscript.optimizer import remove_unused
from onnxscript.optimizer import remove_unused, remove_unused_function
from onnxscript.rewriter import function_rule, pattern
from onnxscript.rewriter.onnxruntime import (
group_normalization_merge_silu,
instance_to_group_normalization,
softmax,
transformers,
Expand All @@ -16,6 +17,8 @@
ORT_PATTERN_REWRITE_RULES = [
*softmax.rules.rules,
*instance_to_group_normalization.rules.rules,
# NOTE: group normalization merge silu should be applied after instance to group normalization
*group_normalization_merge_silu.rules.rules,
]


Expand Down Expand Up @@ -49,5 +52,8 @@ def rewrite(
count = pattern.RewriteRuleSet(pattern_rules).apply_to_model(model_ir)
print(f"Applied {count} pattern rewrite rules.")
model = protobuilder.build_model_proto(model_ir)
# TODO: Does it make more sense we run DCE after each rewrite rule applied?
# If so, we need IR to support DCE.
remove_unused.remove_unused_nodes(model)
remove_unused_function.remove_unused_functions(model)
return model
58 changes: 58 additions & 0 deletions onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from __future__ import annotations

import logging

from onnxscript.rewriter import pattern

op = pattern.onnxop
msft_op = pattern.msft_op
torch_module_op = pattern.torch_module_op

logger = logging.getLogger(__name__)


def group_normalization_and_silu_submodule(
input,
weight,
bias,
epsilon,
groups,
):
group_norm = msft_op.GroupNorm(
input,
weight,
bias,
activation=0,
channels_last=1,
epsilon=epsilon,
groups=groups,
)
transposed = op.Transpose(group_norm, perm=[0, 3, 1, 2])
return torch_module_op.submodule("torch_nn_modules_activation_SiLU")(transposed)


def group_normalization_with_silu(
input,
weight,
bias,
epsilon,
groups,
):
group_norm = msft_op.GroupNorm(
input,
weight,
bias,
activation=1,
channels_last=1,
epsilon=epsilon,
groups=groups,
)
return op.Transpose(group_norm, perm=[0, 3, 1, 2])


group_normalization_merge_silu_submodule_rule = pattern.RewriteRule(
group_normalization_and_silu_submodule,
group_normalization_with_silu,
)

rules = pattern.RewriteRuleSet([group_normalization_merge_silu_submodule_rule])
Loading

0 comments on commit fc01bfa

Please sign in to comment.