Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft - WIP] Add rotary embedding fusion rule (part 1) #1981

Open
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

gramalingam
Copy link
Collaborator

Initial version of fusion for rotary embedding.

Limitations and TODO:

  • More pattern variations to be added
  • Hard to test running fused model against ORT, since it requires new variation of RotaryEmbedding op.

Copy link

codecov bot commented Dec 18, 2024

❌ 43 Tests Failed:

Tests completed Failed Passed Skipped
8372 43 8329 3751
View the top 3 failed tests by shortest run time
::onnxscript.rewriter.onnxruntime.xformers._test_models
Stack Traces | 0s run time
No failure message available
onnxscript.rewriter.generic_pattern_test.GenericPatternTest_0::test_transpose_transpose_onnxscript
Stack Traces | 0.001s run time
onnxscript/rewriter/generic_pattern_test.py:594: in test_transpose_transpose_onnxscript
    rule.apply_to_model(ir_model)
onnxscript/rewriter/pattern.py:1358: in apply_to_model
    return RewriteRuleSet([self], commute=commute).apply_to_model(model, verbose=verbose)
onnxscript/rewriter/pattern.py:1517: in apply_to_model
    count = self._apply_to_graph_or_function(model, model.graph, verbose=verbose)
onnxscript/rewriter/pattern.py:1493: in _apply_to_graph_or_function
    delta = rule.try_rewrite(model, graph_or_function, node, verbose=verbose)
onnxscript/rewriter/pattern.py:1328: in try_rewrite
    match = self._matcher.match(
E   TypeError: GenericPatternMatcher.match() got an unexpected keyword argument 'remove_nodes'
onnxscript.rewriter.generic_pattern_test.GenericPatternTest_0::test_shared_root_value_test
Stack Traces | 0.002s run time
onnxscript/rewriter/generic_pattern_test.py:282: in test_shared_root_value_test
    rule.apply_to_model(ir_model)
onnxscript/rewriter/pattern.py:1358: in apply_to_model
    return RewriteRuleSet([self], commute=commute).apply_to_model(model, verbose=verbose)
onnxscript/rewriter/pattern.py:1517: in apply_to_model
    count = self._apply_to_graph_or_function(model, model.graph, verbose=verbose)
onnxscript/rewriter/pattern.py:1493: in _apply_to_graph_or_function
    delta = rule.try_rewrite(model, graph_or_function, node, verbose=verbose)
onnxscript/rewriter/pattern.py:1328: in try_rewrite
    match = self._matcher.match(
E   TypeError: GenericPatternMatcher.match() got an unexpected keyword argument 'remove_nodes'

To view more test analytics, go to the Test Analytics Dashboard
📢 Thoughts on this report? Let us know!

scalar = get_singleton_value(val)
if scalar is None:
return False
if isinstance(expected, Callable):

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 2 to "isinstance" has incompatible type ""; expected "_ClassInfo" To disable, use # type: ignore[arg-type]
if scalar is None:
return False
if isinstance(expected, Callable):
return expected(scalar)

Check failure

Code scanning / lintrunner

MYPY/operator Error

"float" not callable To disable, use # type: ignore[operator]
if scalar is None:
return False
if isinstance(expected, Callable):
return expected(scalar)

Check failure

Code scanning / lintrunner

MYPY/operator Error

"int" not callable To disable, use # type: ignore[operator]
onnxscript/rewriter/_ir_utils.py Fixed Show fixed Hide fixed
onnxscript/rewriter/_ir_utils.py Fixed Show fixed Hide fixed
@gramalingam gramalingam changed the title Add rotary embedding fusion rule (part 1) [Draft - WIP] Add rotary embedding fusion rule (part 1) Dec 20, 2024
return expected == scalar
# rtol must be specified for float comparison
assert rtol is not None
return math.isclose(scalar, expected, rtol=rtol)

Check failure

Code scanning / lintrunner

MYPY/call-arg Error

Unexpected keyword argument "rtol" for "isclose" To disable, use # type: ignore[call-arg]
return expected == scalar
# rtol must be specified for float comparison
assert rtol is not None
return math.isclose(scalar, expected, rtol=rtol)

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 2 to "isclose" has incompatible type "float | Callable[..., Any]"; expected "SupportsFloat | SupportsIndex" To disable, use # type: ignore[arg-type]
return match.fail("Matched nodes have other uses preventing replacement.")

match.outputs.extend(output_values)
return match

def _multi_match(self, candidate: Iterable[ir.Node]) -> MatchResult:
def _multi_match(self, candidate: Iterable[ir.Node], check_removable: bool) -> MatchResult:

Check warning

Code scanning / lintrunner

RUFF/D417 Warning

Missing argument description in the docstring for \_multi\_match: check\_removable.
See https://docs.astral.sh/ruff/rules/undocumented-param
def __init__(self, name: str, max_pos_id: int):
super().__init__(name)
self._max_pos_id = max_pos_id
self.remove_nodes = False

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute remove_nodes, which was previously defined in superclass
RewriteRuleClassBase
.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Development

Successfully merging this pull request may close these issues.

2 participants