[Draft - WIP] Add rotary embedding fusion rule (part 1) #1981
Open: gramalingam wants to merge 12 commits into main from rama/fuse-attn
Changes from 11 commits
Commits
8de7231 First version (gramalingam)
a20b903 Add rotary embedding (gramalingam)
b8f7a08 Remove SDPA (gramalingam)
315c94e Add comment (gramalingam)
2219fd3 Remove MHA (gramalingam)
f77f0e7 Merge branch 'main' into rama/fuse-attn (gramalingam)
5ec9d1e Add rewrite for cos-sin computation (gramalingam)
90f0b7b Merge branch 'rama/fuse-attn' of https://github.com/microsoft/onnx-sc… (gramalingam)
1fdc19b Run lint (gramalingam)
eb916b8 Add cos sin test (gramalingam)
d874dbc Extend rewriter to support node reuse (gramalingam)
a745039 Minor fixes (gramalingam)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,15 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache | ||
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization | ||
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding | ||
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization | ||
|
||
__all__ = [ | ||
"fuse_rms_normalization", | ||
"fuse_normalization", | ||
"fuse_rotary_embedding", | ||
"fuse_cos_sin_cache", | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
|
||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
|
||
import onnxscript.ir as ir | ||
from onnxscript.rewriter import _ir_utils, pattern | ||
|
||
# Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops. | ||
|
||
# Original code (from transformers) for computing cos/sin cache for RoPE: | ||
# https://github.com/huggingface/transformers/blob/0ade1caa356dce6b70ef8293addeb0898f177206/src/transformers/models/llama/modeling_llama.py#L135 | ||
# inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) | ||
# position_ids_expanded = position_ids[:, None, :].float() | ||
# freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) | ||
# emb = torch.cat((freqs, freqs), dim=-1) | ||
# cos = emb.cos() | ||
# sin = emb.sin() | ||
|
||
|
||
class CosSinCacheFusion(pattern.RewriteRuleClassBase): | ||
def __init__(self, name: str, max_pos_id: int): | ||
super().__init__(name) | ||
self._max_pos_id = max_pos_id | ||
self.remove_nodes = False | ||
|
||
|
||
def pattern(self, op, x, inv_freq, position_ids): | ||
position_ids_expanded = op.Unsqueeze(position_ids, 1) | ||
position_ids_expanded = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT) | ||
freqs = op.MatMul(inv_freq, position_ids_expanded) | ||
freqs = op.Transpose(freqs, perm=[0, 2, 1]) | ||
emb = op.Concat(freqs, freqs, axis=-1) | ||
cos = op.Cos(emb) | ||
sin = op.Sin(emb) | ||
cos_4d = op.Unsqueeze(cos, 1)  # convert 3D [B, S, H] to 4D [B, 1, S, H] | |
sin_4d = op.Unsqueeze(sin, 1) | ||
return op.RotaryEmbedding(x, cos_4d, sin_4d, interleaved=0, _domain="ai.onnxruntime.fusion") | ||
|
||
def check(self, context, inv_freq, position_ids, **_): | ||
if not _ir_utils.has_rank(position_ids, 2): | ||
return False | ||
if not _ir_utils.has_rank(inv_freq, 3): | ||
return False | ||
inv_freq_shape = inv_freq.shape | ||
if inv_freq.const_value is None: | ||
return False | ||
return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1 | ||
|
||
def rewrite(self, op, x, inv_freq, position_ids, **_): | ||
inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1) | ||
pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1) | ||
angles = np.matmul(pos_id_range, inv_freq_values) | ||
cos_value = np.cos(angles) | ||
cos_value = np.concatenate([cos_value, cos_value], axis=-1) | ||
sin_value = np.sin(angles) | ||
sin_value = np.concatenate([sin_value, sin_value], axis=-1) | ||
cos_2d = op.Constant(value=ir.tensor(cos_value)) | ||
# cos = op.Gather(cos_2d, position_ids, axis=0) | ||
sin_2d = op.Constant(value=ir.tensor(sin_value)) | ||
# sin = op.Gather(sin_2d, position_ids, axis=0) | ||
return op.RotaryEmbedding(x, cos_2d, sin_2d, position_ids, interleaved=0, _domain="ai.onnxruntime.fusion") | ||
|
||
|
||
_rule = CosSinCacheFusion.rule("CosSinCache", 2048) | ||
|
||
cos_sin_cache_rules = pattern.RewriteRuleSet([_rule]) | ||
|
||
|
||
def fuse_cos_sin_cache(model: ir.Model) -> int: | ||
count = cos_sin_cache_rules.apply_to_model(model) | ||
print(f"CosSinCache count: {count}") | ||
return count | ||
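The rewrite above replaces the per-call MatMul/Cos/Sin computation with tables precomputed for every position below `max_pos_id`, indexed by `position_ids`. The following is a minimal NumPy sketch of why the two forms agree. It is not part of the PR: the function names `original_cos_sin` and `cached_cos_sin`, the toy sizes, and the `10000` base are assumptions made for illustration, and the explicit indexing stands in for the lookup that the fused op is expected to perform.

```python
import numpy as np


def original_cos_sin(inv_freq, position_ids):
    """Mirror the matched pattern: Unsqueeze, Cast, MatMul, Transpose, Concat, Cos/Sin."""
    # inv_freq: [1, D/2, 1] float32; position_ids: [B, S] integers
    pos = position_ids[:, None, :].astype(np.float32)       # [B, 1, S]
    freqs = np.matmul(inv_freq, pos).transpose(0, 2, 1)     # [B, S, D/2]
    emb = np.concatenate([freqs, freqs], axis=-1)           # [B, S, D]
    return np.cos(emb), np.sin(emb)


def cached_cos_sin(inv_freq, position_ids, max_pos_id):
    """Mirror the rewrite: build [max_pos_id, D] tables once, then index by position."""
    inv_freq_values = inv_freq.reshape(1, -1)                               # [1, D/2]
    pos_id_range = np.arange(max_pos_id, dtype=np.float32).reshape(-1, 1)   # [max, 1]
    angles = np.matmul(pos_id_range, inv_freq_values)                       # [max, D/2]
    cos_table = np.concatenate([np.cos(angles), np.cos(angles)], axis=-1)   # [max, D]
    sin_table = np.concatenate([np.sin(angles), np.sin(angles)], axis=-1)   # [max, D]
    # Row lookup by position id; every id must be < max_pos_id.
    return cos_table[position_ids], sin_table[position_ids]                 # [B, S, D]


# Toy inputs (hypothetical sizes): head size D = 8, batch of 2, sequence length 4.
head_dim = 8
exponents = np.arange(0, head_dim, 2, dtype=np.float32) / head_dim
inv_freq = (1.0 / (10000.0 ** exponents)).astype(np.float32).reshape(1, -1, 1)
position_ids = np.array([[0, 1, 2, 5], [3, 3, 7, 0]], dtype=np.int64)

cos_a, sin_a = original_cos_sin(inv_freq, position_ids)
cos_b, sin_b = cached_cos_sin(inv_freq, position_ids, max_pos_id=8)
print(np.allclose(cos_a, cos_b, atol=1e-6), np.allclose(sin_a, sin_b, atol=1e-6))
```

The equivalence only holds while every position id stays below `max_pos_id`, which the rule registered above fixes at 2048.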
29 changes: 29 additions & 0 deletions
onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
|
||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
import unittest | ||
|
||
import onnxscript.optimizer | ||
from onnxscript.rewriter.onnxruntime.xformers import fuse_cos_sin_cache, fuse_rotary_embedding | ||
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData | ||
from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run | ||
|
||
|
||
class TestCosSinCacheTransform(unittest.TestCase): | ||
def test_smollm(self): | ||
smollm_test = _SmollmTestData() | ||
model = smollm_test.get_onnx_model() | ||
onnxscript.optimizer.optimize(model) | ||
inputs = smollm_test.get_ort_inputs() | ||
original_outputs = ort_run("original", model, inputs) | ||
count = fuse_rotary_embedding(model) | ||
|
||
self.assertGreater(count, 0) | ||
count = fuse_cos_sin_cache(model) | ||
self.assertGreater(count, 0) | ||
new_outputs = ort_run("optimized", model, inputs) | ||
assert_allclose(new_outputs, original_outputs) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() | ||
58 changes: 58 additions & 0 deletions
onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
import onnxscript.ir as ir | ||
from onnxscript.rewriter import _ir_utils, pattern | ||
|
||
# First version of the RotaryEmbeddingFusion rule. It handles only one simple pattern: | |
# full rotation without interleaving. | |
# TODO(rama): Add pattern variations to handle other cases. | |
|
||
# Note: This targets the new op being proposed to ONNX. That version does not exist in ORT yet, | |
# so the result can't be tested by running it against ORT. Unfortunately, this is the pattern | |
# produced by the current version of transformers (not yet supported by ORT). | |
|
||
|
||
def _rotate_half_pattern(op, x, start1, end1, start2, end2): | ||
# Slice(input, starts, ends, axes, steps) | ||
x1 = op.Slice(x, start1, end1, [3], [1]) | ||
x2 = op.Slice(x, start2, end2, [3], [1]) | ||
minus_x2 = op.Neg(x2) | ||
rotated_x = op.Concat(minus_x2, x1, axis=-1) | ||
return rotated_x | ||
|
||
|
||
class RotaryEmbeddingFusion(pattern.RewriteRuleClassBase): | ||
def pattern(self, op, x, cos, sin, start1, end1, start2, end2): | ||
return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin | ||
|
||
def check(self, op, x, start1, end1, start2, end2, **_): | ||
# x needs to be a 4D tensor with known last dimension size (== head_size) | ||
if x is None or x.shape is None or len(x.shape) != 4: | ||
return False | ||
head_size = x.shape[3] | ||
if not isinstance(head_size, int): | ||
return False | ||
half_head_size = head_size // 2 | ||
|
||
# Check that x is being split into two equal halves of size half_head_size | ||
return ( | ||
_ir_utils.is_singleton_value(start1, 0) | ||
and _ir_utils.is_singleton_value(end1, half_head_size) | ||
and _ir_utils.is_singleton_value(start2, half_head_size) | ||
and _ir_utils.is_singleton_value(end2, lambda x: x >= head_size) | ||
) | ||
|
||
def rewrite(self, op, x, cos, sin, **_): | ||
return op.RotaryEmbedding(x, cos, sin, interleaved=0, _domain="ai.onnxruntime.fusion") | ||
|
||
|
||
_rule = RotaryEmbeddingFusion.rule() | ||
|
||
rotary_embedding_rules = pattern.RewriteRuleSet([_rule]) | ||
|
||
|
||
def fuse_rotary_embedding(model: ir.Model) -> int: | |
count = rotary_embedding_rules.apply_to_model(model) | |
print(f"Rotary Embedding count: {count}") | |
return count |
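The pattern matched by `RotaryEmbeddingFusion` is the non-interleaved "rotate half" form of RoPE: `x * cos + rotate_half(x) * sin`. The NumPy sketch below is an illustration of that arithmetic under assumed toy shapes, not code from the PR; `rotate_half` and `apply_rotary` are hypothetical helper names. It also shows why `check` accepts any `end2 >= head_size`: like Python slicing, ONNX `Slice` clamps an out-of-range end to the dimension size.

```python
import numpy as np


def rotate_half(x):
    """Split the last axis into two equal halves and return concat(-x2, x1)."""
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return np.concatenate([-x2, x1], axis=-1)


def apply_rotary(x, cos, sin):
    """Full rotation without interleaving, as in the matched pattern."""
    return x * cos + rotate_half(x) * sin


rng = np.random.default_rng(0)
# Assumed layout: [batch, num_heads, seq_len, head_size] with head_size = 8.
x = rng.standard_normal((1, 2, 3, 8)).astype(np.float32)
angles = rng.standard_normal((1, 1, 3, 4)).astype(np.float32)
emb = np.concatenate([angles, angles], axis=-1)  # same angle for both halves
cos, sin = np.cos(emb), np.sin(emb)

rotated = apply_rotary(x, cos, sin)

# An end index far past the axis length selects the same second half.
same_slice = np.array_equal(x[..., 4:2**62], x[..., 4:])
print(rotated.shape, same_slice)
```

Because element `i` and element `i + head_size // 2` share one angle, each such pair undergoes a plain 2D rotation, so the length of every head vector is preserved.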
23 changes: 23 additions & 0 deletions
onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
import unittest | ||
|
||
import onnxscript.optimizer | ||
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData | ||
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding | ||
|
||
|
||
class TestRotaryEmbedding(unittest.TestCase): | ||
def test_smollm(self): | ||
smollm_test = _SmollmTestData() | ||
model = smollm_test.get_onnx_model() | ||
onnxscript.optimizer.optimize(model) | ||
fuse_rotary_embedding(model) | ||
op_types = [n.op_type for n in model.graph] | ||
self.assertIn("RotaryEmbedding", op_types) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() | ||
Check failure
Code scanning / lintrunner
MYPY/arg-type Error