-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update lint rules; Migrate groupnorm commit on "Skip full model shape…
… inference if model > 2GB | feat(optimizer)" [ghstack-poisoned]
- Loading branch information
Showing
11 changed files
with
450 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
"""Onnx Pattern Rewriting. | ||
This script shows how to define a rewriting rule based on patterns. | ||
The objective is to replace some nodes in an onnx model into another | ||
sequence of nodes but more efficient. | ||
First a dummy model | ||
=================== | ||
""" | ||
|
||
import numpy as np | ||
import onnx | ||
import onnx.helper as oh | ||
import onnx.numpy_helper as onh | ||
|
||
import onnxscript | ||
import onnxscript._legacy_ir as oir | ||
import onnxscript.rewriter.generic_pattern as org | ||
|
||
|
||
def get_rotary_model(bad_model=False): | ||
inputs = [ | ||
oh.make_tensor_value_info("x", onnx.TensorProto.INT64, shape=[]), | ||
oh.make_tensor_value_info("pos_ids", onnx.TensorProto.FLOAT, shape=[]), | ||
oh.make_tensor_value_info("axis", onnx.TensorProto.INT64, shape=[]), | ||
] | ||
nodes = [ | ||
oh.make_node("Unsqueeze", ["x", "axis"], ["_onx_unsqueeze0"]), | ||
oh.make_node("Cast", ["_onx_unsqueeze0"], ["_onx_cast0"], to=1), | ||
oh.make_node("MatMul", ["pos_ids", "_onx_cast0"], ["_onx_matmul0"]), | ||
oh.make_node("Transpose", ["_onx_matmul0"], ["_onx_transpose0"]), | ||
oh.make_node( | ||
"ConcatTrainingBad" if bad_model else "ConcatTraining", | ||
["_onx_transpose0", "_onx_transpose0"], | ||
["_onx_concattraining0", "_onx_concattraining1"], | ||
domain="com.microsoft", | ||
), | ||
oh.make_node("Sin", ["_onx_concattraining0"], ["_onx_sin0"]), | ||
oh.make_node("Cast", ["_onx_sin0"], ["_onx_cast02"], to=1), | ||
oh.make_node("Cos", ["_onx_concattraining0"], ["_onx_cos0"]), | ||
oh.make_node("Cast", ["_onx_cos0"], ["_onx_cast03"], to=1), | ||
] | ||
outputs = [ | ||
oh.make_tensor_value_info("_onx_cast02", onnx.TensorProto.UNDEFINED, []), | ||
oh.make_tensor_value_info("_onx_cast03", onnx.TensorProto.UNDEFINED, []), | ||
] | ||
model = oh.make_model( | ||
oh.make_graph( | ||
nodes, | ||
"experiment", | ||
inputs, | ||
outputs, | ||
), | ||
opset_imports=[ | ||
oh.make_opsetid("", 18), | ||
oh.make_opsetid("com.microsoft", 18), | ||
], | ||
) | ||
return model | ||
|
||
|
||
model = get_rotary_model() | ||
ir_model = oir.irbuilder.build_ir(model) | ||
|
||
|
||
#################################### | ||
# The rewriting pattern | ||
# ===================== | ||
|
||
op = onnxscript.opset18 | ||
msft_op = onnxscript.values.Opset("com.microsoft", 1) | ||
|
||
|
||
def rotary_match_pattern(x, pos_ids, axis): | ||
"""The pattern to match.""" | ||
unsqueeze = op.Unsqueeze(x, axis) | ||
cast = op.Cast(unsqueeze, to=onnx.TensorProto.FLOAT) | ||
|
||
matmul = op.MatMul(pos_ids, cast) | ||
transpose = op.Transpose(matmul) | ||
output, length = msft_op.ConcatTraining(transpose, transpose) | ||
|
||
sin = op.Sin(output) | ||
cast1 = op.Cast(sin, to=onnx.TensorProto.FLOAT) | ||
cos = op.Cos(output) | ||
cast2 = op.Cast(cos, to=onnx.TensorProto.FLOAT) | ||
return cast1, cast2 | ||
|
||
|
||
def validate_rotary_mapping(g, matched_nodes, added_nodes) -> bool: | ||
"""The validation post matching. | ||
Returns True to validate the replacement, | ||
False not to apply it. | ||
:param g: model | ||
:param matched_nodes: matched nodes | ||
:param added_nodes: nodes replacing the matched nodes | ||
""" | ||
del g | ||
del matched_nodes | ||
del added_nodes | ||
return True | ||
|
||
|
||
def rotary_apply_pattern(x, pos_ids, axis): | ||
"""The replacement pattern.""" | ||
cos_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16))) | ||
sin_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16))) | ||
part1, part2 = msft_op.RotaryEmbedding(x, pos_ids, cos_cache, sin_cache) | ||
return part1, part2 | ||
|
||
|
||
########################### | ||
# The rule | ||
# ======== | ||
# | ||
# The rule is easy to create. | ||
|
||
|
||
rule = org.make_pattern_rule( | ||
rotary_match_pattern, | ||
rotary_apply_pattern, | ||
validate_rotary_mapping, | ||
) | ||
|
||
################################ | ||
# ``validate_rotary_mapping`` always return True. | ||
# This argument can be ignored in that case. | ||
|
||
rule = org.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern) | ||
|
||
########################## | ||
# Let's apply it. | ||
rule.apply_to_model(ir_model) | ||
|
||
|
||
######################## | ||
# And finally, we can generate the model. | ||
|
||
opt_onx = oir.protobuilder.build_model_proto(ir_model) | ||
|
||
######################## | ||
# Let's see what it looks like. | ||
|
||
for node in opt_onx.graph.node: | ||
print(f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}") | ||
|
||
############################# | ||
# What if it fails? | ||
# ================= | ||
|
||
|
||
model = get_rotary_model(True) | ||
ir_model = oir.irbuilder.build_ir(model) | ||
|
||
rule.apply_to_model(ir_model) | ||
opt_onx = oir.protobuilder.build_model_proto(ir_model) | ||
|
||
print([n.op_type for n in opt_onx.graph.node]) | ||
|
||
################################ | ||
# The match did not happen. | ||
# Let's increase the verbosity. | ||
|
||
rule = org.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern, verbose=10) | ||
|
||
rule.apply_to_model(ir_model) | ||
|
||
###################################### | ||
# The logs shows every time the algorithm rejected a pattern. | ||
# We can see the following: | ||
# | ||
# :: | ||
# | ||
# [OnnxGenericPattern.match] NONE - line: 673:onnxscript.rewriter.generic_pattern, op_type=Cast | ||
# --hint--: BACKWARD: different node types | ||
# --pattern | ||
# ConcatTraining(transpose, transpose) -> (output, length) | ||
# -- model | ||
# ConcatTrainingBad(_onx_transpose0, _onx_transpose0) -> (_onx_concattraining0, _onx_concattraining1) | ||
# iteration=1 | ||
# --marked-- #2 | ||
# Cast(_onx_cos0) ~ Cast(cos) [140186194226496-140186194222320] | ||
# Cos(_onx_concattraining0) ~ Cos(output) [140186194230816-140186194223472] | ||
# len(stacked)=0:[] | ||
# | ||
# Line 673 in file `generic_pattern.py`, the match was rejected. | ||
# It says while comparing two nodes in the backward direction, | ||
# node types do not match. | ||
# It also says that two nodes were actually matched. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
58 changes: 58 additions & 0 deletions
58
onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
|
||
from onnxscript.rewriter import pattern | ||
|
||
op = pattern.onnxop | ||
msft_op = pattern.msft_op | ||
torch_module_op = pattern.torch_module_op | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def group_normalization_and_silu_submodule( | ||
input, | ||
weight, | ||
bias, | ||
epsilon, | ||
groups, | ||
): | ||
group_norm = msft_op.GroupNorm( | ||
input, | ||
weight, | ||
bias, | ||
activation=0, | ||
channels_last=1, | ||
epsilon=epsilon, | ||
groups=groups, | ||
) | ||
transposed = op.Transpose(group_norm, perm=[0, 3, 1, 2]) | ||
return torch_module_op.submodule("torch_nn_modules_activation_SiLU")(transposed) | ||
|
||
|
||
def group_normalization_with_silu( | ||
input, | ||
weight, | ||
bias, | ||
epsilon, | ||
groups, | ||
): | ||
group_norm = msft_op.GroupNorm( | ||
input, | ||
weight, | ||
bias, | ||
activation=1, | ||
channels_last=1, | ||
epsilon=epsilon, | ||
groups=groups, | ||
) | ||
return op.Transpose(group_norm, perm=[0, 3, 1, 2]) | ||
|
||
|
||
group_normalization_merge_silu_submodule_rule = pattern.RewriteRule( | ||
group_normalization_and_silu_submodule, | ||
group_normalization_with_silu, | ||
) | ||
|
||
rules = pattern.RewriteRuleSet([group_normalization_merge_silu_submodule_rule]) |
Oops, something went wrong.