-
Notifications
You must be signed in to change notification settings - Fork 56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Draft - WIP] Add rotary embedding fusion rule (part 1) #1981
base: main
Are you sure you want to change the base?
Conversation
❌ 43 Tests Failed:
View the top 3 failed tests by shortest run time
To view more test analytics, go to the Test Analytics Dashboard |
scalar = get_singleton_value(val) | ||
if scalar is None: | ||
return False | ||
if isinstance(expected, Callable): |
Check failure
Code scanning / lintrunner
MYPY/arg-type Error
if scalar is None: | ||
return False | ||
if isinstance(expected, Callable): | ||
return expected(scalar) |
Check failure
Code scanning / lintrunner
MYPY/operator Error
if scalar is None: | ||
return False | ||
if isinstance(expected, Callable): | ||
return expected(scalar) |
Check failure
Code scanning / lintrunner
MYPY/operator Error
return expected == scalar | ||
# rtol must be specified for float comparison | ||
assert rtol is not None | ||
return math.isclose(scalar, expected, rtol=rtol) |
Check failure
Code scanning / lintrunner
MYPY/call-arg Error
return expected == scalar | ||
# rtol must be specified for float comparison | ||
assert rtol is not None | ||
return math.isclose(scalar, expected, rtol=rtol) |
Check failure
Code scanning / lintrunner
MYPY/arg-type Error
return match.fail("Matched nodes have other uses preventing replacement.") | ||
|
||
match.outputs.extend(output_values) | ||
return match | ||
|
||
def _multi_match(self, candidate: Iterable[ir.Node]) -> MatchResult: | ||
def _multi_match(self, candidate: Iterable[ir.Node], check_removable: bool) -> MatchResult: |
Check warning
Code scanning / lintrunner
RUFF/D417 Warning
See https://docs.astral.sh/ruff/rules/undocumented-param
def __init__(self, name: str, max_pos_id: int): | ||
super().__init__(name) | ||
self._max_pos_id = max_pos_id | ||
self.remove_nodes = False |
Check warning
Code scanning / CodeQL
Overwriting attribute in super-class or sub-class Warning
Initial version of fusion for rotary embedding.
Limitations and TODO: