From 0aed232233bbac5bba0e59c3c4fd50349c7489ca Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Fri, 13 Dec 2024 21:24:56 -0800 Subject: [PATCH] RMS Normalization and Skip RMS Normalization fusion optimizations (#1974) Implements RMS Normalization and Skip RMS Normalization fusion optimizations (for use of onnxruntime custom fused ops for these). --- .lintrunner.toml | 1 + .../rewriter/onnxruntime/xformers/__init__.py | 3 + .../onnxruntime/xformers/_smollm_1layer.py | 253 ++++++++++++++++++ .../onnxruntime/xformers/_test_models.py | 122 +++++++++ .../onnxruntime/xformers/_test_utils.py | 42 +++ .../onnxruntime/xformers/rms_normalization.py | 99 +++++++ .../xformers/rms_normalization_test.py | 37 +++ .../xformers/skip_normalization.py | 46 ++++ .../xformers/skip_normalization_test.py | 28 ++ onnxscript/rewriter/pattern.py | 30 +++ 10 files changed, 661 insertions(+) create mode 100644 onnxscript/rewriter/onnxruntime/xformers/__init__.py create mode 100644 onnxscript/rewriter/onnxruntime/xformers/_smollm_1layer.py create mode 100644 onnxscript/rewriter/onnxruntime/xformers/_test_models.py create mode 100644 onnxscript/rewriter/onnxruntime/xformers/_test_utils.py create mode 100644 onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py create mode 100644 onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py create mode 100644 onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py create mode 100644 onnxscript/rewriter/onnxruntime/xformers/skip_normalization_test.py diff --git a/.lintrunner.toml b/.lintrunner.toml index 9b874e221..6679927e9 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -50,6 +50,7 @@ exclude_patterns = [ 'onnxscript/optimizer/_legacy/constant_folding.py', # FIXME 'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME 'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME + 'onnxscript/rewriter/onnxruntime/xformers/_smollm_1layer.py', # onnxscript code 'onnxscript/_legacy_ir/irbuilder.py', # FIXME 'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME 'onnxscript/tools/function_unittest_producer.py', # FIXME diff --git a/onnxscript/rewriter/onnxruntime/xformers/__init__.py b/onnxscript/rewriter/onnxruntime/xformers/__init__.py new file mode 100644 index 000000000..44b5591d8 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations diff --git a/onnxscript/rewriter/onnxruntime/xformers/_smollm_1layer.py b/onnxscript/rewriter/onnxruntime/xformers/_smollm_1layer.py new file mode 100644 index 000000000..c5bf35046 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/_smollm_1layer.py @@ -0,0 +1,253 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +A one-layer SmolLM model test case. +This is an onnxscript version of the model. +""" + +import numpy +from onnx.helper import make_tensor + +import onnxscript.ir as ir +from onnxscript import script +from onnxscript.onnx_opset import opset18 +from onnxscript.onnx_types import FLOAT, INT64 + + +def make_model( + input_layernorm_weight_0, + post_attention_layernorm_weight0, + norm_weight, + head_weight, + self_attn_q_proj_weight0, + self_attn_k_proj_weight0, + self_attn_v_proj_weight0, + self_attn_o_proj_weight0, + mlp_gate_proj_weight0, + mlp_up_proj_weight0, + mlp_down_proj_weight0, +): + @script() + def main_graph( + input0: INT64[1, 10], input1: FLOAT[1, 10], input2: INT64[1, 10] + ) -> (FLOAT[1, 10, 49152], FLOAT[1, 32, 10, 64], FLOAT[1, 32, 10, 64]): + model_layers_0_input_layernorm_weight = opset18.Constant( + value=input_layernorm_weight_0 + ) + model_layers_0_post_attention_layernorm_weight = opset18.Constant( + value=post_attention_layernorm_weight0 + ) + model_norm_weight = opset18.Constant(value=norm_weight) + lm_head_weight = opset18.Constant(value=head_weight) + model_layers_0_self_attn_q_proj_weight = opset18.Constant( + value=self_attn_q_proj_weight0 + ) + model_layers_0_self_attn_k_proj_weight = opset18.Constant( + value=self_attn_k_proj_weight0 + ) + model_layers_0_self_attn_v_proj_weight = opset18.Constant( + value=self_attn_v_proj_weight0 + ) + model_layers_0_self_attn_o_proj_weight = opset18.Constant( + value=self_attn_o_proj_weight0 + ) + model_layers_0_mlp_gate_proj_weight = opset18.Constant(value=mlp_gate_proj_weight0) + model_layers_0_mlp_up_proj_weight = opset18.Constant(value=mlp_up_proj_weight0) + model_layers_0_mlp_down_proj_weight = opset18.Constant(value=mlp_down_proj_weight0) + + embedding = opset18.Gather(lm_head_weight, input0, axis=0) + minus_inf_10x10 = opset18.ConstantOfShape([10, 10], [-3.4028234663852886e38]) + mask_10x10 = opset18.Trilu(minus_inf_10x10, 1) + slice_5 = opset18.Reshape(mask_10x10, [1, 1, 10, 10]) + unsqueeze_2 = opset18.Unsqueeze(input1, 1) + unsqueeze_3 = opset18.Unsqueeze(unsqueeze_2, 2) + add = slice_5 + unsqueeze_3 + eq = add == 0.0 + slice_10 = slice_5 + masked_fill = opset18.Where(eq, -3.4028235e38, slice_10) + val_179 = opset18.Transpose(masked_fill, perm=[2, 1, 0, 3]) + slice_scatter = opset18.Transpose(val_179, perm=[2, 1, 0, 3]) + val_191 = opset18.Transpose(slice_scatter, perm=[1, 0, 2, 3]) + slice_scatter_1 = opset18.Transpose(val_191, perm=[1, 0, 2, 3]) + unsqueeze_6 = opset18.Unsqueeze(input2, 1) + _to_copy_1 = opset18.Cast(unsqueeze_6, to=1) + view_1 = opset18.Constant( + value=make_tensor( + "value", + 1, + dims=[1, 32, 1], + vals=[ + 1.0, + 0.7498942017555237, + 0.5623413324356079, + 0.4216965138912201, + 0.3162277638912201, + 0.23713736236095428, + 0.17782793939113617, + 0.1333521455526352, + 0.10000000149011612, + 0.07498941570520401, + 0.05623412877321243, + 0.04216964915394783, + 0.03162277862429619, + 0.0237137358635664, + 0.017782794311642647, + 0.01333521492779255, + 0.009999999776482582, + 0.007498942315578461, + 0.005623413249850273, + 0.0042169648222625256, + 0.003162277862429619, + 0.0023713738191872835, + 0.0017782794311642647, + 0.0013335214462131262, + 0.0010000000474974513, + 0.0007498941849917173, + 0.000562341301701963, + 0.00042169648804701865, + 0.0003162277862429619, + 0.0002371373848291114, + 0.00017782794020604342, + 0.0001333521504420787, + ], + ) + ) + view_2 = opset18.Reshape(_to_copy_1, [1, 1, 10], allowzero=0) + bmm = view_1 @ view_2 + view_3 = opset18.Reshape(bmm, [1, 32, 10], allowzero=0) + transpose = opset18.Transpose(view_3, perm=[0, 2, 1]) + cat = opset18.Concat(transpose, transpose, axis=-1) + cos = opset18.Cos(cat) + sin = opset18.Sin(cat) + pow_1 = embedding**2.0 + mean = opset18.ReduceMean(pow_1, [-1], keepdims=1, noop_with_empty_axes=0) + add_1 = mean + 1e-05 + val_244 = opset18.Sqrt(add_1) + rsqrt = opset18.Reciprocal(val_244) + mul_3 = embedding * rsqrt + mul_4 = model_layers_0_input_layernorm_weight * mul_3 + t = opset18.Transpose(model_layers_0_self_attn_q_proj_weight, perm=[1, 0]) + view_5 = mul_4 @ t + t_1 = opset18.Transpose(model_layers_0_self_attn_k_proj_weight, perm=[1, 0]) + view_7 = mul_4 @ t_1 + t_2 = opset18.Transpose(model_layers_0_self_attn_v_proj_weight, perm=[1, 0]) + view_9 = mul_4 @ t_2 + view_10 = opset18.Reshape(view_5, [1, 10, 32, 64], allowzero=0) + transpose_1 = opset18.Transpose(view_10, perm=[0, 2, 1, 3]) + view_11 = opset18.Reshape(view_7, [1, 10, 32, 64], allowzero=0) + transpose_2 = opset18.Transpose(view_11, perm=[0, 2, 1, 3]) + view_12 = opset18.Reshape(view_9, [1, 10, 32, 64], allowzero=0) + transpose_3 = opset18.Transpose(view_12, perm=[0, 2, 1, 3]) + unsqueeze_7 = opset18.Unsqueeze(cos, 1) + unsqueeze_8 = opset18.Unsqueeze(sin, 1) + mul_5 = transpose_1 * unsqueeze_7 + val_267 = opset18.Constant(value_ints=[1]) + slice_19 = opset18.Slice(transpose_1, [0], [32], [3], val_267) + val_277 = opset18.Constant(value_ints=[1]) + slice_20 = opset18.Slice(transpose_1, [32], [9223372036854775807], [3], val_277) + neg = opset18.Neg(slice_20) + cat_1 = opset18.Concat(neg, slice_19, axis=-1) + mul_6 = cat_1 * unsqueeze_8 + add_2 = mul_5 + mul_6 + mul_7 = transpose_2 * unsqueeze_7 + val_287 = opset18.Constant(value_ints=[1]) + slice_21 = opset18.Slice(transpose_2, [0], [32], [3], val_287) + val_297 = opset18.Constant(value_ints=[1]) + slice_22 = opset18.Slice(transpose_2, [32], [9223372036854775807], [3], val_297) + neg_1 = opset18.Neg(slice_22) + cat_2 = opset18.Concat(neg_1, slice_21, axis=-1) + mul_8 = cat_2 * unsqueeze_8 + add_3 = mul_7 + mul_8 + val_346 = opset18.Reshape(add_3, [-1, 10, 64], allowzero=0) + val_347 = opset18.Transpose(val_346, perm=[0, 2, 1]) + val_349 = opset18.Reshape(val_347, [1, 32, 64, 10], allowzero=0) + val_351 = add_2 * [0.35355338] + val_353 = val_349 * [0.35355338] + val_354 = val_351 @ val_353 + val_355 = val_354 + slice_scatter_1 + val_356 = opset18.Softmax(val_355, axis=-1) + getitem = val_356 @ transpose_3 + transpose_4 = opset18.Transpose(getitem, perm=[0, 2, 1, 3]) + view_13 = opset18.Reshape(transpose_4, [1, 10, -1], allowzero=0) + t_3 = opset18.Transpose(model_layers_0_self_attn_o_proj_weight, perm=[1, 0]) + view_15 = view_13 @ t_3 + add_4 = embedding + view_15 + pow_2 = add_4**2.0 + mean_1 = opset18.ReduceMean(pow_2, [-1], keepdims=1, noop_with_empty_axes=0) + add_5 = mean_1 + 1e-05 + val_379 = opset18.Sqrt(add_5) + rsqrt_1 = opset18.Reciprocal(val_379) + mul_9 = add_4 * rsqrt_1 + mul_10 = model_layers_0_post_attention_layernorm_weight * mul_9 + t_4 = opset18.Transpose(model_layers_0_mlp_gate_proj_weight, perm=[1, 0]) + view_17 = mul_10 @ t_4 + val_383 = opset18.Sigmoid(view_17) + silu = view_17 * val_383 + t_5 = opset18.Transpose(model_layers_0_mlp_up_proj_weight, perm=[1, 0]) + view_19 = mul_10 @ t_5 + mul_11 = silu * view_19 + t_6 = opset18.Transpose(model_layers_0_mlp_down_proj_weight, perm=[1, 0]) + view_21 = mul_11 @ t_6 + add_6 = add_4 + view_21 + pow_3 = add_6**2.0 + mean_2 = opset18.ReduceMean(pow_3, [-1], keepdims=1, noop_with_empty_axes=0) + add_7 = mean_2 + 1e-05 + val_391 = opset18.Sqrt(add_7) + rsqrt_2 = opset18.Reciprocal(val_391) + mul_12 = add_6 * rsqrt_2 + mul_13 = model_norm_weight * mul_12 + t_7 = opset18.Transpose(lm_head_weight, perm=[1, 0]) + view_23 = mul_13 @ t_7 + _to_copy_12 = opset18.Identity(view_23) + return _to_copy_12, add_3, transpose_3 + + model = main_graph.to_model_proto() + return model + + +def make_model_with_random_weights(): + input_layernorm_weight_0 = numpy.random.rand(2048).astype(numpy.float32) + post_attention_layernorm_weight0 = numpy.random.rand(2048).astype(numpy.float32) + norm_weight = numpy.random.rand(2048).astype(numpy.float32) + head_weight = numpy.random.rand(49152, 2048).astype(numpy.float32) + self_attn_q_proj_weight0 = numpy.random.rand(2048, 2048).astype(numpy.float32) + self_attn_k_proj_weight0 = numpy.random.rand(2048, 2048).astype(numpy.float32) + self_attn_v_proj_weight0 = numpy.random.rand(2048, 2048).astype(numpy.float32) + self_attn_o_proj_weight0 = numpy.random.rand(2048, 2048).astype(numpy.float32) + mlp_gate_proj_weight0 = numpy.random.rand(8192, 2048).astype(numpy.float32) + mlp_up_proj_weight0 = numpy.random.rand(8192, 2048).astype(numpy.float32) + mlp_down_proj_weight0 = numpy.random.rand(2048, 8192).astype(numpy.float32) + model = make_model( + input_layernorm_weight_0, + post_attention_layernorm_weight0, + norm_weight, + head_weight, + self_attn_q_proj_weight0, + self_attn_k_proj_weight0, + self_attn_v_proj_weight0, + self_attn_o_proj_weight0, + mlp_gate_proj_weight0, + mlp_up_proj_weight0, + mlp_down_proj_weight0, + ) + return model + + +class _SmollmTestData: + def get_onnx_model(self): + if not hasattr(self, "_onnx_model"): + model_proto = make_model_with_random_weights() + model = ir.serde.deserialize_model(model_proto) + self._onnx_model = model + return self._onnx_model + + def get_ort_inputs(self): + if not hasattr(self, "_ort_inputs"): + inputs = { + "input0": numpy.random.randint(0, 49152, (1, 10)).astype(numpy.int64), + "input1": numpy.ones((1, 10), dtype=numpy.float32), + "input2": numpy.arange(10, dtype=numpy.int64).reshape(1, 10), + } + self._ort_inputs = inputs + return self._ort_inputs diff --git a/onnxscript/rewriter/onnxruntime/xformers/_test_models.py b/onnxscript/rewriter/onnxruntime/xformers/_test_models.py new file mode 100644 index 000000000..64f0c396d --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/_test_models.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import os +import tempfile + +import numpy as np +import onnxruntime +import torch +import transformers +from transformers import LlamaConfig + +import onnxscript.ir as ir +import onnxscript.ir._io as io +import onnxscript.optimizer + +# Create a LlamaConfig object with the desired parameters +_config = LlamaConfig( + _name_or_path="HuggingFaceTB/SmolLM-1.7B", + architectures=["LlamaForCausalLM"], + attention_bias=False, + attention_dropout=0.0, + bos_token_id=0, + eos_token_id=0, + hidden_act="silu", + hidden_size=2048, + initializer_range=0.02, + intermediate_size=8192, + max_position_embeddings=2048, + model_type="llama", + num_attention_heads=32, + num_hidden_layers=1, + num_key_value_heads=32, + pretraining_tp=1, + rms_norm_eps=1e-05, + rope_scaling=None, + rope_theta=10000.0, + tie_word_embeddings=True, + torch_dtype="float32", + transformers_version="4.37.2", + use_cache=True, + vocab_size=49152, +) + +# Dimensions for inputs: +_batch_size = 1 +_seq_len = 10 +_hidden_size = _config.hidden_size +_num_attention_heads = _config.num_attention_heads +dim = _hidden_size // _num_attention_heads +_vocab_size = _config.vocab_size + + +class _SmollmTestData: + def __init__(self): + pass + + def get_torch_model(self): + if not hasattr(self, "_torch_model"): + model = transformers.LlamaForCausalLM(_config) + model.eval() + self._torch_model = model + return self._torch_model + + def get_onnx_model(self) -> ir.Model: + model = self.get_torch_model() + inputs = self.get_inputs() + input_names = ["input" + str(i) for i in range(len(inputs)) if inputs[i] is not None] + exported = torch.onnx.export( + model, inputs, input_names=input_names, dynamo=True, fallback=True + ) + # ORT Transformer optimizations are applied after basic optimization. + exported_model = exported.model # type: ignore[union-attr] + onnxscript.optimizer.optimize(exported_model) + return exported_model + + def get_inputs(self): + if not hasattr(self, "_inputs"): + input_ids = torch.randint(0, _vocab_size, (_batch_size, _seq_len)).to(torch.int64) + attention_mask = torch.ones(input_ids.shape) + position_ids = torch.arange(0, input_ids.size(-1)).unsqueeze(0) + self._inputs = (input_ids, attention_mask, position_ids) + return self._inputs + + def get_torch_outputs(self): + output = self.get_torch_model()(*self.get_inputs()) + logits = output.logits + past_key_value = output.past_key_values[0] + key = past_key_value[0] + value = past_key_value[1] + return (logits.detach().numpy(), key.detach().numpy(), value.detach().numpy()) + + def get_ort_inputs(self): + inputs = self.get_inputs() + return { + f"input{i}": input.numpy() for i, input in enumerate(inputs) if input is not None + } + + +def _ort_check(model_name: str, model, inputs, expected_outputs, rtol=1e-2, atol=1e-2): + providers = ["CPUExecutionProvider"] + with tempfile.TemporaryDirectory() as temp_dir: + model_path = os.path.join(temp_dir, f"{model_name}.onnx") + io.save(model, model_path) + # Run model + session = onnxruntime.InferenceSession(model_path, providers=providers) + ort_outputs = session.run(None, inputs) + + for i, (baseline_output, optimized_output) in enumerate( + zip(expected_outputs, ort_outputs) + ): + try: + np.testing.assert_equal(baseline_output.shape, optimized_output.shape) + np.testing.assert_allclose( + baseline_output, optimized_output, rtol=rtol, atol=atol + ) + except AssertionError as e: + print( + f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}" + ) + raise diff --git a/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py b/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py new file mode 100644 index 000000000..0b4e2c55f --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py @@ -0,0 +1,42 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import os +import tempfile + +import numpy as np +import onnx +import onnxruntime + +import onnxscript.ir as ir +import onnxscript.ir._io as io + + +def _save(model, modelpath): + if isinstance(model, onnx.ModelProto): + onnx.save(model, modelpath) + else: + assert isinstance(model, ir.Model) + io.save(model, modelpath) + + +def ort_run(model_name: str, model, inputs): + providers = ["CPUExecutionProvider"] + with tempfile.TemporaryDirectory() as temp_dir: + model_path = os.path.join(temp_dir, f"{model_name}.onnx") + io.save(model, model_path) + # Run model + session = onnxruntime.InferenceSession(model_path, providers=providers) + ort_outputs = session.run(None, inputs) + return ort_outputs + + +def assert_allclose(outputs, expected_outputs, rtol=1e-2, atol=1e-2): + for i, (baseline_output, optimized_output) in enumerate(zip(expected_outputs, outputs)): + try: + np.testing.assert_equal(baseline_output.shape, optimized_output.shape) + np.testing.assert_allclose(baseline_output, optimized_output, rtol=rtol, atol=atol) + except AssertionError as e: + print(f"Failed for output {i} with rtol={rtol} and atol={atol}\n{e}") + raise diff --git a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py new file mode 100644 index 000000000..1f7a96df1 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import onnxscript.ir as ir +from onnxscript.rewriter import _ir_utils, pattern + +""" +RMS Normalization: This is referred to as SimplifiedLayerNormalization in the ORT codebase. +See https://github.com/microsoft/onnxruntime/blob/6d9636f07cccdb6e4ac453087ad54c3bc9854d50/onnxruntime/core/graph/contrib_ops/contrib_defs.cc#L2981 + +Key points for the fusion optimization: +* Input and scale are allowed to be of different types. +* The normalization of the input can be done in a different precision than the input type, +which is also the precision of reciprocal_rms returned by operation. +* Input (x) must be: float or double or float16 or bfloat16 +* Scale must be: float or double or float16 or bfloat16 +* Normalization precision must be float or double +""" + +float_types = [ + ir.DataType.FLOAT, + ir.DataType.FLOAT16, + ir.DataType.BFLOAT16, + ir.DataType.DOUBLE, +] +fp_float_types = [ir.DataType.FLOAT, ir.DataType.DOUBLE] + + +class RmsNormFusion(pattern.RewriteRuleClassBase): + def __init__(self, name: str, *, cast_input: bool, cast_normalized: bool): + """ + Args: + name: Name of the rule. + cast_input: Whether to cast input to do the normalization in a different precision. + cast_normalized: Whether to cast the normalized output to the target dtype (same as scale). + """ + self._name = name + self._cast_input = cast_input + self._cast_normalized = cast_normalized + + @property + def name(self): + return self._name + + def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype): + if self._cast_input: + x = op.Cast(x, to=compute_dtype) + x_square = op.Pow(x, 2.0) + mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0) + mean_square_plus_epsilon = op.Add(mean_square, epsilon) + rms = op.Sqrt(mean_square_plus_epsilon) + reciprocal_rms = op.Reciprocal(rms) + normalized = op.Mul(x, reciprocal_rms) + if self._cast_normalized: + normalized = op.Cast(normalized, to=target_dtype) + return op.Mul(scale, normalized) + + def check(self, op, x, scale, epsilon, compute_dtype, target_dtype): + """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op.""" + # epsilon must be a scalar + epsilon_value = _ir_utils.get_singleton_value(epsilon) + if not isinstance(epsilon_value, float): # TODO: support other types + return False + # input and output must be same dtype + if x.dtype not in float_types: + return False + if scale.dtype not in float_types: + return False + stash_dtype = compute_dtype.value if self._cast_input else x.dtype + if stash_dtype not in fp_float_types: + return False + return True + + def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype): + stash_dtype = compute_dtype.value if self._cast_input else x.dtype + # Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake. + # No need to use com.microsoft domain here. + return op.SimplifiedLayerNormalization( + x, + scale, + axis=-1, + epsilon=_ir_utils.get_singleton_value(epsilon), + stash_type=stash_dtype, + ) + + +_rule_0 = RmsNormFusion.rule("RmsNorm-0", cast_input=True, cast_normalized=True) +_rule_1 = RmsNormFusion.rule("RmsNorm-1", cast_input=False, cast_normalized=True) +_rule_2 = RmsNormFusion.rule("RmsNorm-2", cast_input=True, cast_normalized=False) +_rule_3 = RmsNormFusion.rule("RmsNorm-3", cast_input=False, cast_normalized=False) + +rms_normalization_rules = [_rule_0, _rule_1, _rule_2, _rule_3] +rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules) + + +def fuse_rms_normalization(model: ir.Model) -> None: + count = rms_normalization_ruleset.apply_to_model(model, verbose=5) + print(f"RMS Normalization count: {count}") diff --git a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py new file mode 100644 index 000000000..79a966838 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import onnx + +import onnxscript.optimizer +from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData +from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run +from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization + + +def model_repr(self): + return f"Model({self.graph.name})" + + +onnx.ModelProto.__repr__ = model_repr + + +class TestRmsNormalization(unittest.TestCase): + def test_smollm(self): + smollm_test = _SmollmTestData() + model = smollm_test.get_onnx_model() + onnxscript.optimizer.optimize(model) + inputs = smollm_test.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + fuse_rms_normalization(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("SimplifiedLayerNormalization", op_types) + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py b/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py new file mode 100644 index 000000000..c298a0aaf --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from onnxscript.rewriter import pattern +from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import rms_normalization_rules + + +def _skip_norm_pattern(op, input, skip, gamma, epsilon, stash_type): + skip_sum = op.Add(input, skip) + normalized = op.SimplifiedLayerNormalization( + skip_sum, + gamma, + axis=-1, + epsilon=epsilon, + stash_type=stash_type, + ) + return normalized, skip_sum + + +def _skip_normalization(op, input, skip, gamma, epsilon, stash_type): + if stash_type.value != 1: # FLOAT type + return None + normalized, _mean, _inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization( + input, + skip, + gamma, + epsilon=epsilon, + _outputs=4, + _domain="com.microsoft", + ) + return normalized, skip_sum + + +_rule = pattern.RewriteRule( + _skip_norm_pattern, _skip_normalization, matcher=pattern.SimplePatternMatcher +) + +skip_normalization_rules = [_rule] +normalization_rules = rms_normalization_rules + skip_normalization_rules +normalization_ruleset = pattern.RewriteRuleSet(normalization_rules) + + +def fuse_normalization(model): + count = normalization_ruleset.apply_to_model(model) + print(f"Normalization count: {count}") diff --git a/onnxscript/rewriter/onnxruntime/xformers/skip_normalization_test.py b/onnxscript/rewriter/onnxruntime/xformers/skip_normalization_test.py new file mode 100644 index 000000000..3873ccfc8 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/skip_normalization_test.py @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import onnxscript.optimizer +from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData +from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run +from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization + + +class TestSkipNormalization(unittest.TestCase): + def test_smollm(self): + smollm_test = _SmollmTestData() + model = smollm_test.get_onnx_model() + onnxscript.optimizer.optimize(model) + inputs = smollm_test.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + fuse_normalization(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("SkipSimplifiedLayerNormalization", op_types) + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 66d9b3196..b9d5d002a 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1320,6 +1320,10 @@ def try_rewrite( match = self._matcher.match(model, graph_or_function, node, verbose=verbose) if match: context = None # TODO(rama) + for var in self._target_pattern.inputs: + if var.name is not None: + if var.name not in match.bindings: + match.bindings[var.name] = None if not self._condition_function(context, **match.bindings): return None replacement_subgraph = self._replacement_pattern.get_replacement(match) @@ -1428,6 +1432,32 @@ def rewrite(cls, op, x: ir.Value, perm: ir.Attr | None = None): ) +# Variation of RewriteRuleAsClass that is based on instance methods instead of class methods. +# Useful to implement a family of rules to support pattern variations. +# TODO: cleanup the naming conventions for these inter-related classes. +class RewriteRuleClassBase: + @classmethod + def rule(cls, *args, **kwargs): + instance = cls(*args, **kwargs) + return RewriteRule( + instance.pattern, instance.rewrite, instance.check, name=instance.name + ) + + @property + def name(self): + """Default implementation of name property.""" + return self.__class__.__name__ + + def pattern(self, op, *args, **kwargs): + raise NotImplementedError("Method 'pattern' must be implemented by derived class.") + + def check(self, op, *args, **kwargs): + raise NotImplementedError("Method 'check' must be implemented by derived class.") + + def rewrite(self, op, *args, **kwargs): + raise NotImplementedError("Method 'rewrite' must be implemented by derived class.") + + class RewriteRuleSet: def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None: if commute: