Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft - WIP] Add rotary embedding fusion rule (part 1) #1981

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
22 changes: 22 additions & 0 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,28 @@
return default


@register("Reshape")
def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
input = _get_input(node, 0)
shape = _get_input(node, 1)
if input is None or shape is None:
return None

Check warning on line 314 in onnxscript/optimizer/_constant_folding.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/optimizer/_constant_folding.py#L314

Added line #L314 was not covered by tests
input_shape = input.shape
if input_shape is None:
return None

Check warning on line 317 in onnxscript/optimizer/_constant_folding.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/optimizer/_constant_folding.py#L317

Added line #L317 was not covered by tests
input_shape_dims = list(input_shape.dims)
if any(not isinstance(dim, int) for dim in input_shape_dims):
return None

Check warning on line 320 in onnxscript/optimizer/_constant_folding.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/optimizer/_constant_folding.py#L320

Added line #L320 was not covered by tests
shape_value = _get_numpy_value(shape)
if shape_value is None:
return None

Check warning on line 323 in onnxscript/optimizer/_constant_folding.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/optimizer/_constant_folding.py#L323

Added line #L323 was not covered by tests
target_shape_dims = shape_value.tolist()
if input_shape_dims == target_shape_dims:
# No need to check for special values like -1, 0, etc. here
return op.Identity(input)
return None


@register("Cast")
def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
input = _get_input(node, 0)
Expand Down
27 changes: 27 additions & 0 deletions onnxscript/rewriter/_ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# Licensed under the MIT License.
from __future__ import annotations

import math
from typing import Callable

import numpy as np

import onnxscript.ir as ir
Expand Down Expand Up @@ -77,3 +80,27 @@
if np_val is not None and np_val.size == 1:
return np_val.item()
return None


def is_singleton_value(
val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None
) -> bool:
"""Returns True if the value is a single element tensor with given value, and False otherwise."""
scalar = get_singleton_value(val)
if scalar is None:
return False

Check warning on line 91 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L91

Added line #L91 was not covered by tests
if isinstance(expected, Callable):

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 2 to "isinstance" has incompatible type ""; expected "_ClassInfo" To disable, use # type: ignore[arg-type]
return expected(scalar)

Check failure

Code scanning / lintrunner

MYPY/operator Error

"float" not callable To disable, use # type: ignore[operator]

Check failure

Code scanning / lintrunner

MYPY/operator Error

"int" not callable To disable, use # type: ignore[operator]
if isinstance(expected, int):
return expected == scalar
# rtol must be specified for float comparison
assert rtol is not None
return math.isclose(scalar, expected, rtol=rtol)

Check warning on line 98 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L97-L98

Added lines #L97 - L98 were not covered by tests
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed

Check failure

Code scanning / lintrunner

MYPY/call-arg Error

Unexpected keyword argument "rtol" for "isclose" To disable, use # type: ignore[call-arg]

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 2 to "isclose" has incompatible type "float | Callable[..., Any]"; expected "SupportsFloat | SupportsIndex" To disable, use # type: ignore[arg-type]


def has_rank(value: ir.Value | None, rank: int) -> bool:
"""Returns True if the value is statically known to have the given rank, and False otherwise."""
if value is None:
return False
shape = value.shape
return (shape is not None) and (shape.rank() == rank)

Check warning on line 106 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L104-L106

Added lines #L104 - L106 were not covered by tests
12 changes: 12 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization

__all__ = [
"fuse_rms_normalization",
"fuse_normalization",
"fuse_rotary_embedding",
"fuse_cos_sin_cache",
]
2 changes: 1 addition & 1 deletion onnxscript/rewriter/onnxruntime/xformers/_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def ort_run(model_name: str, model, inputs):
providers = ["CPUExecutionProvider"]
with tempfile.TemporaryDirectory() as temp_dir:
model_path = os.path.join(temp_dir, f"{model_name}.onnx")
io.save(model, model_path)
_save(model, model_path)
# Run model
session = onnxruntime.InferenceSession(model_path, providers=providers)
ort_outputs = session.run(None, inputs)
Expand Down
73 changes: 73 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) Microsoft Corporation.
Fixed Show fixed Hide fixed
# Licensed under the MIT License.
from __future__ import annotations

import numpy as np

import onnxscript.ir as ir
from onnxscript.rewriter import _ir_utils, pattern

# Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops.

# Original code (from transformers) for computing cos/sin cache for RoPE:
# https://github.com/huggingface/transformers/blob/0ade1caa356dce6b70ef8293addeb0898f177206/src/transformers/models/llama/modeling_llama.py#L135
# inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
# position_ids_expanded = position_ids[:, None, :].float()
# freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
# emb = torch.cat((freqs, freqs), dim=-1)
# cos = emb.cos()
# sin = emb.sin()


class CosSinCacheFusion(pattern.RewriteRuleClassBase):
def __init__(self, name: str, max_pos_id: int):
super().__init__(name)
self._max_pos_id = max_pos_id
self.remove_nodes = False
Fixed Show fixed Hide fixed

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute remove_nodes, which was previously defined in superclass
RewriteRuleClassBase
.

def pattern(self, op, x, inv_freq, position_ids):
position_ids_expanded = op.Unsqueeze(position_ids, 1)
position_ids_expanded = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT)
freqs = op.MatMul(inv_freq, position_ids_expanded)
freqs = op.Transpose(freqs, perm=[0, 2, 1])
emb = op.Concat(freqs, freqs, axis=-1)
cos = op.Cos(emb)
sin = op.Sin(emb)
cos_4d = op.Unsqueeze(cos, 1) # convert
sin_4d = op.Unsqueeze(sin, 1)
return op.RotaryEmbedding(x, cos_4d, sin_4d, interleaved=0, _domain="ai.onnxruntime.fusion")

def check(self, context, inv_freq, position_ids, **_):
if not _ir_utils.has_rank(position_ids, 2):
return False

Check warning on line 42 in onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py#L42

Added line #L42 was not covered by tests
if not _ir_utils.has_rank(inv_freq, 3):
return False
inv_freq_shape = inv_freq.shape

Check warning on line 45 in onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py#L44-L45

Added lines #L44 - L45 were not covered by tests
if inv_freq.const_value is None:
return False
return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1

Check warning on line 48 in onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py#L47-L48

Added lines #L47 - L48 were not covered by tests

def rewrite(self, op, x, inv_freq, position_ids, **_):
inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1)
pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1)
angles = np.matmul(pos_id_range, inv_freq_values)
cos_value = np.cos(angles)
cos_value = np.concatenate([cos_value, cos_value], axis=-1)
sin_value = np.sin(angles)
sin_value = np.concatenate([sin_value, sin_value], axis=-1)
cos_2d = op.Constant(value=ir.tensor(cos_value))

Check warning on line 58 in onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py#L51-L58

Added lines #L51 - L58 were not covered by tests
# cos = op.Gather(cos_2d, position_ids, axis=0)
sin_2d = op.Constant(value=ir.tensor(sin_value))

Check warning on line 60 in onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py#L60

Added line #L60 was not covered by tests
# sin = op.Gather(sin_2d, position_ids, axis=0)
return op.RotaryEmbedding(x, cos_2d, sin_2d, position_ids, interleaved=0, _domain="ai.onnxruntime.fusion")

Check warning on line 62 in onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py#L62

Added line #L62 was not covered by tests


_rule = CosSinCacheFusion.rule("CosSinCache", 2048)

cos_sin_cache_rules = pattern.RewriteRuleSet([_rule])


def fuse_cos_sin_cache(model: ir.Model) -> int:
count = cos_sin_cache_rules.apply_to_model(model)
print(f"CosSinCache count: {count}")
return count

Check warning on line 73 in onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py#L71-L73

Added lines #L71 - L73 were not covered by tests
29 changes: 29 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation.
Fixed Show fixed Hide fixed
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers import fuse_cos_sin_cache, fuse_rotary_embedding
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run


class TestCosSinCacheTransform(unittest.TestCase):
def test_smollm(self):
smollm_test = _SmollmTestData()
model = smollm_test.get_onnx_model()
onnxscript.optimizer.optimize(model)
inputs = smollm_test.get_ort_inputs()
original_outputs = ort_run("original", model, inputs)
count = fuse_rotary_embedding(model)
Fixed Show fixed Hide fixed
self.assertGreater(count, 0)
count = fuse_cos_sin_cache(model)
self.assertGreater(count, 0)
new_outputs = ort_run("optimized", model, inputs)
assert_allclose(new_outputs, original_outputs)

Check warning on line 25 in onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py#L22-L25

Added lines #L22 - L25 were not covered by tests


if __name__ == "__main__":
unittest.main()

Check warning on line 29 in onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py#L29

Added line #L29 was not covered by tests
8 changes: 2 additions & 6 deletions onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,10 @@ def __init__(self, name: str, *, cast_input: bool, cast_normalized: bool):
cast_input: Whether to cast input to do the normalization in a different precision.
cast_normalized: Whether to cast the normalized output to the target dtype (same as scale).
"""
self._name = name
super().__init__(name=name)
self._cast_input = cast_input
self._cast_normalized = cast_normalized

@property
def name(self):
return self._name

def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
if self._cast_input:
x = op.Cast(x, to=compute_dtype)
Expand Down Expand Up @@ -95,5 +91,5 @@ def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype):


def fuse_rms_normalization(model: ir.Model) -> None:
count = rms_normalization_ruleset.apply_to_model(model, verbose=5)
count = rms_normalization_ruleset.apply_to_model(model)
print(f"RMS Normalization count: {count}")
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,12 @@

import unittest

import onnx

import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization


def model_repr(self):
return f"Model({self.graph.name})"


onnx.ModelProto.__repr__ = model_repr


class TestRmsNormalization(unittest.TestCase):
def test_smollm(self):
smollm_test = _SmollmTestData()
Expand Down
58 changes: 58 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnxscript.ir as ir
from onnxscript.rewriter import _ir_utils, pattern

# Add first version of the RotaryEmbeddingFusion rule. This considers only one simple pattern
# for full rotation without interleaving.
# TODO(rama): Add pattern variations to handle other cases.

# Note: This targets the new op being proposed to ONNX. This version does not exist in ORT yet,
# so it can't be tested by running against ORT. Unfortunately, this is the new pattern out
# of current version of transformers (not yet supported by ORT).


def _rotate_half_pattern(op, x, start1, end1, start2, end2):
# Slice(input, starts, ends, axes, steps)
x1 = op.Slice(x, start1, end1, [3], [1])
x2 = op.Slice(x, start2, end2, [3], [1])
minus_x2 = op.Neg(x2)
rotated_x = op.Concat(minus_x2, x1, axis=-1)
return rotated_x


class RotaryEmbeddingFusion(pattern.RewriteRuleClassBase):
def pattern(self, op, x, cos, sin, start1, end1, start2, end2):
return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin

def check(self, op, x, start1, end1, start2, end2, **_):
# x needs to be a 4D tensor with known last dimension size (== head_size)
if x is None or x.shape is None or len(x.shape) != 4:
return False

Check warning on line 33 in onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py#L33

Added line #L33 was not covered by tests
head_size = x.shape[3]
if not isinstance(head_size, int):
return False

Check warning on line 36 in onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py#L36

Added line #L36 was not covered by tests
half_head_size = head_size // 2

# Check that x is being split into two equal halves of size half_head_size
return (
_ir_utils.is_singleton_value(start1, 0)
and _ir_utils.is_singleton_value(end1, half_head_size)
and _ir_utils.is_singleton_value(start2, half_head_size)
and _ir_utils.is_singleton_value(end2, lambda x: x >= head_size)
)

def rewrite(self, op, x, cos, sin, **_):
return op.RotaryEmbedding(x, cos, sin, interleaved=0, _domain="ai.onnxruntime.fusion")


_rule = RotaryEmbeddingFusion.rule()

rotary_embedding_rules = pattern.RewriteRuleSet([_rule])


def fuse_rotary_embedding(model: ir.Model) -> None:
count = rotary_embedding_rules.apply_to_model(model)
print(f"Rotary Embedding count: {count}")
23 changes: 23 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding


class TestRotaryEmbedding(unittest.TestCase):
def test_smollm(self):
smollm_test = _SmollmTestData()
model = smollm_test.get_onnx_model()
onnxscript.optimizer.optimize(model)
fuse_rotary_embedding(model)
op_types = [n.op_type for n in model.graph]
self.assertIn("RotaryEmbedding", op_types)


if __name__ == "__main__":
unittest.main()

Check warning on line 23 in onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py#L23

Added line #L23 was not covered by tests
Loading
Loading