-
Notifications
You must be signed in to change notification settings - Fork 56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DRAFT] First version of fusion optimizations for transformers #1938
Conversation
❌ 18 Tests Failed:
View the top 1 failed tests by shortest run time
View the full list of 2 ❄️ flaky tests
To view more test analytics, go to the Test Analytics Dashboard |
The last two axes of the key-embedding are then swapped (using a Reshape/Transpose/Reshape sequence) | ||
|
||
The dot-product attention is then computed using SDPA | ||
|
Check warning
Code scanning / lintrunner
EDITORCONFIG-CHECKER/editorconfig Warning
The last two axes of the key-embedding are then swapped (using a Reshape/Transpose/Reshape sequence) | ||
|
||
The dot-product attention is then computed using SDPA | ||
|
Check warning
Code scanning / lintrunner
RUFF/W293 Warning
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace
|
||
|
||
def _skip_normalization(op, input, skip, gamma, epsilon, stash_type): | ||
normalized, mean, inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization( |
Check warning
Code scanning / lintrunner
RUFF/F841 Warning
See https://docs.astral.sh/ruff/rules/unused-variable
|
||
|
||
def _skip_normalization(op, input, skip, gamma, epsilon, stash_type): | ||
normalized, mean, inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization( |
Check warning
Code scanning / lintrunner
RUFF/F841 Warning
See https://docs.astral.sh/ruff/rules/unused-variable
@@ -0,0 +1,38 @@ | |||
# Copyright (c) Microsoft Corporation. |
Check warning
Code scanning / lintrunner
RUFF-FORMAT/format Warning
@@ -0,0 +1,152 @@ | |||
# Copyright (c) Microsoft Corporation. |
Check warning
Code scanning / lintrunner
RUFF-FORMAT/format Warning
Extract the independent optimization/refinements from [the fusion PR](#1938) as a separate PR, ready to be reviewed/merged. (The fusion work is still WIP.) * Replace Expand by Identity when applicable (in core optimization) * Cleanup Dropout Identity replacement in the case when Dropout has mask output * Make repeated (redundant) call to inliner efficient
visit(x, 0) | ||
elif isinstance(x, ir.Value): | ||
if backward and x.producer() is not None: | ||
visit(x.producer(), 0) |
Check failure
Code scanning / lintrunner
MYPY/arg-type Error
@@ -0,0 +1,73 @@ | |||
# Copyright (c) Microsoft Corporation. |
Check warning
Code scanning / lintrunner
RUFF-FORMAT/format Warning
op, input, query_weight, key_weight, value_weight, cos, sin | ||
): | ||
"""Variation of first pattern with Reshape omitted.""" | ||
query = _project_transpose_head(op, input, query_weight) |
Check failure
Code scanning / lintrunner
MYPY/call-arg Error
"""Variation of first pattern with Reshape omitted.""" | ||
query = _project_transpose_head(op, input, query_weight) | ||
query_rope = op.RotaryEmbedding(query, cos, sin, _domain="local") | ||
key = _project_transpose_head(op, input, key_weight) |
Check failure
Code scanning / lintrunner
MYPY/call-arg Error
# Reshape omitted here. | ||
key_transposed = op.Transpose(key_rope) | ||
# Reshape omitted here | ||
value = _project_transpose_head(op, input, value_weight) |
Check failure
Code scanning / lintrunner
MYPY/call-arg Error
@@ -0,0 +1,141 @@ | |||
# Copyright (c) Microsoft Corporation. |
Check warning
Code scanning / lintrunner
RUFF/format Warning
@@ -0,0 +1,141 @@ | |||
# Copyright (c) Microsoft Corporation. |
Check warning
Code scanning / lintrunner
RUFF-FORMAT/format Warning
@@ -0,0 +1,141 @@ | |||
# Copyright (c) Microsoft Corporation. | |||
# Licensed under the MIT License. | |||
from __future__ import annotations |
Check warning
Code scanning / lintrunner
RUFF/I001 Warning
See https://docs.astral.sh/ruff/rules/unsorted-imports
def _check_shape(bindings: dict[str, int], val: ir.Value, shape: Iterable[str]) -> bool: | ||
if val.shape is None: | ||
return False | ||
if val.shape.rank() != len(shape): |
Check failure
Code scanning / lintrunner
MYPY/arg-type Error
return False | ||
for actual, expected in zip(val.shape, shape): | ||
if expected not in bindings: | ||
bindings[expected] = actual |
Check failure
Code scanning / lintrunner
MYPY/assignment Error
Closing this for now. Will recreate as multiple separate PRs. |
Still TODO: