
[DRAFT] First version of fusion optimizations for transformers #1938

Closed
wants to merge 34 commits

Conversation

gramalingam
Collaborator

  • Introduce fusion rules for SdpaAttention, RMS Normalization, Skip Normalization, Rotary Embedding, and Multi Head Attention
  • Replace Expand with Identity when applicable (in core optimization); a sketch of such a rule follows this list
  • Clean up the Dropout-to-Identity replacement for the case where Dropout has a mask output
  • Make repeated (redundant) calls to the inliner efficient
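Of these, the Expand-to-Identity rewrite is the simplest to illustrate. Below is a minimal sketch of what such a rule could look like with the onnxscript pattern-rewriter API; it is not the PR's actual code, and the `_expand_is_noop` condition (including its use of `const_value`) is an illustrative assumption:

from onnxscript.rewriter import pattern


def _expand(op, x, shape):
    # Target pattern: any Expand node.
    return op.Expand(x, shape)


def _identity(op, x, shape):
    # Replacement: forward the input unchanged.
    return op.Identity(x)


def _expand_is_noop(context, x, shape, **_) -> bool:
    # Illustrative condition: Expand is removable only when the constant
    # `shape` operand equals the static shape of `x`.
    if shape.const_value is None or x.shape is None:
        return False
    return list(shape.const_value.numpy()) == list(x.shape)


expand_identity_rule = pattern.RewriteRule(_expand, _identity, _expand_is_noop)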

Still TODO:

  • Multi Head Attention requires extra validation conditions
  • Need to clean up the use of "local" sub-patterns

@gramalingam gramalingam marked this pull request as draft November 9, 2024 01:23

codecov bot commented Nov 9, 2024

❌ 18 Tests Failed:

Tests completed  Failed  Passed  Skipped
10050            18      10032   3781
Top 1 failed test by shortest run time:
::onnxscript.rewriter.onnxruntime.xformers._optimize_transformers_test
Stack Traces | 0s run time
No failure message available
Full list of 2 ❄️ flaky tests:
tests.eager_mode_test.TestEagerModeArguments_0_reference_runtime::test_function_attribute_by_positional_args

Flake rate in main: 39.74% (Passed 11239 times, Failed 7413 times)

Stack Traces | 0.003s run time
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\ops\_op.py:91: in run
    res = self._run(x, y)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\ops\_op.py:139: in _run
    res = (convert_from_ml_dtypes(res[0]),)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\custom_element_types.py:50: in convert_from_ml_dtypes
    return array.view(dtype=dtype)
E   ValueError: Changing the dtype of a 0d array is only supported if the itemsize is unchanged

The above exception was the direct cause of the following exception:
tests\eager_mode_test.py:112: in test_function_attribute_by_positional_args
    self.assertEqual(add_with_alpha(1.0, 2.0, 3.0), 7.0)
onnxscript\values.py:576: in __call__
    return evaluator.default().eval_function(self, args, kwargs)
onnxscript\evaluator.py:307: in eval_function
    result = function.function(*adapted_args, **adapted_kwargs)
tests\eager_mode_test.py:59: in add_with_alpha
    other = op.Mul(other, alpha)
onnxscript\onnx_opset\_impl\opset14.py:696: in Mul
    return op(*self._prepare_inputs(schema, A, B))
onnxscript\values.py:304: in __call__
    return evaluator.default().eval(schema, args, kwargs)
onnxscript\evaluator.py:194: in eval
    outputs = self._eval(schema, inputs, attributes, closure)
onnxscript\evaluator.py:524: in _eval
    result = session.run(None, session_run_input)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\reference_evaluator.py:599: in run
    outputs = node.run(*inputs, **linked_attributes)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\ops\_op.py:114: in run
    res = OpRunBinary.run(self, x, y)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\ops\_op.py:93: in run
    raise TypeError(
E   TypeError: Issues with types <class 'numpy.ndarray'>, <class 'numpy.ndarray'> (binary operator 'Mul').
tests.eager_mode_test.TestEagerModeArguments_0_reference_runtime::test_function_input_and_attribute_by_kwargs_out_of_order

Flake rate in main: 39.74% (Passed 11239 times, Failed 7413 times)

Stack Traces | 0.004s run time
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\ops\_op.py:91: in run
    res = self._run(x, y)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\ops\_op.py:139: in _run
    res = (convert_from_ml_dtypes(res[0]),)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\custom_element_types.py:50: in convert_from_ml_dtypes
    return array.view(dtype=dtype)
E   ValueError: Changing the dtype of a 0d array is only supported if the itemsize is unchanged

The above exception was the direct cause of the following exception:
tests\eager_mode_test.py:115: in test_function_input_and_attribute_by_kwargs_out_of_order
    self.assertEqual(add_with_alpha(alpha=3.0, other=2.0, this=1.0), 7.0)
onnxscript\values.py:576: in __call__
    return evaluator.default().eval_function(self, args, kwargs)
onnxscript\evaluator.py:307: in eval_function
    result = function.function(*adapted_args, **adapted_kwargs)
tests\eager_mode_test.py:59: in add_with_alpha
    other = op.Mul(other, alpha)
onnxscript\onnx_opset\_impl\opset14.py:696: in Mul
    return op(*self._prepare_inputs(schema, A, B))
onnxscript\values.py:304: in __call__
    return evaluator.default().eval(schema, args, kwargs)
onnxscript\evaluator.py:194: in eval
    outputs = self._eval(schema, inputs, attributes, closure)
onnxscript\evaluator.py:524: in _eval
    result = session.run(None, session_run_input)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\reference_evaluator.py:599: in run
    outputs = node.run(*inputs, **linked_attributes)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\ops\_op.py:114: in run
    res = OpRunBinary.run(self, x, y)
.nox\test_torch_nightly\Lib\site-packages\onnx\reference\ops\_op.py:93: in run
    raise TypeError(
E   TypeError: Issues with types <class 'numpy.ndarray'>, <class 'numpy.ndarray'> (binary operator 'Mul').
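Both flaky traces bottom out in the same inner error: numpy refuses to reinterpret a 0-d array under a dtype with a different itemsize, which appears to be what `convert_from_ml_dtypes` hits for scalar results. The numpy behavior itself is easy to reproduce:

import numpy as np

scalar = np.array(1.0)        # 0-d float64 array (itemsize 8)
print(scalar.view(np.int64))  # OK: int64 has the same itemsize
try:
    scalar.view(np.int32)     # itemsize 8 -> 4 on a 0-d array
except ValueError as e:
    # Changing the dtype of a 0d array is only supported if the
    # itemsize is unchanged
    print(e)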


The last two axes of the key-embedding are then swapped (using a Reshape/Transpose/Reshape sequence)

The dot-product attention is then computed using SDPA

Check warning (Code scanning / lintrunner)
EDITORCONFIG-CHECKER/editorconfig Warning: Trailing whitespace
Check warning (Code scanning / lintrunner)
RUFF/W293 Warning: Blank line contains whitespace (flagged on the same docstring lines as above)
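The flagged docstring lines describe the axis swap; in pattern form that sequence looks roughly like the sketch below (the shape operands and the perm value are illustrative placeholders, not the PR's constants):

def _swap_last_two_axes(op, key, shape_merged, shape_restored):
    # Reshape to a rank-3 view, swap the last two axes, reshape back.
    # shape_merged / shape_restored are hypothetical shape operands.
    merged = op.Reshape(key, shape_merged)
    swapped = op.Transpose(merged, perm=[0, 2, 1])
    return op.Reshape(swapped, shape_restored)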



def _skip_normalization(op, input, skip, gamma, epsilon, stash_type):
    normalized, mean, inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization(

Check warning (Code scanning / lintrunner)
RUFF/F841 Warning: Local variable `mean` is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable


Check warning (Code scanning / lintrunner)
RUFF/F841 Warning: Local variable `inv_std_var` is assigned to but never used (same lines as above).
See https://docs.astral.sh/ruff/rules/unused-variable
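Both F841 warnings stem from the pattern function binding outputs it never uses. A conventional fix is to underscore-prefix the unused names, which Ruff's default dummy-variable rule ignores. A sketch follows; the operand list and the `_outputs=4` pattern-builder keyword are assumptions, since the PR's actual call is truncated above:

def _skip_normalization(op, input, skip, gamma, epsilon, stash_type):
    # `_mean` and `_inv_std_var` are intentionally unused; the leading
    # underscore silences RUFF/F841.
    normalized, _mean, _inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization(
        input, skip, gamma, epsilon=epsilon, stash_type=stash_type, _outputs=4
    )
    return normalized, skip_sum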
@titaiwangms titaiwangms self-requested a review November 12, 2024 17:56
@@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation.

Check warning (Code scanning / lintrunner)
RUFF-FORMAT/format Warning: Run lintrunner -a to apply this patch.
@@ -0,0 +1,152 @@
# Copyright (c) Microsoft Corporation.

Check warning (Code scanning / lintrunner)
RUFF-FORMAT/format Warning: Run lintrunner -a to apply this patch.
gramalingam added a commit that referenced this pull request Nov 15, 2024
Extract the independent optimization/refinements from [the fusion
PR](#1938) as a separate PR,
ready to be reviewed/merged. (The fusion work is still WIP.)

* Replace Expand with Identity when applicable (in core optimization)
* Clean up the Dropout-to-Identity replacement for the case where Dropout has a mask output
* Make repeated (redundant) calls to the inliner efficient
@justinchuby
Collaborator

if backward:
    for inp in node.inputs:
        if inp is not None and inp.producer() is not None:
            visit(inp.producer(), depth + 1)

Check failure (Code scanning / lintrunner)
MYPY/arg-type Error: Argument 1 to "visit" has incompatible type "Node | None"; expected "Node". To disable, use # type: ignore[arg-type]
onnxscript/rewriter/_ir_utils.py (marked as fixed)
    visit(x, 0)
elif isinstance(x, ir.Value):
    if backward and x.producer() is not None:
        visit(x.producer(), 0)

Check failure (Code scanning / lintrunner)
MYPY/arg-type Error: Argument 1 to "visit" has incompatible type "Node | None"; expected "Node". To disable, use # type: ignore[arg-type]
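Both arg-type errors share a cause: `producer()` returns `Node | None`, and mypy does not carry the `is not None` check over to a second, separate call of the method. Binding the producer to a local variable first lets mypy narrow the type; a sketch of that fix for the first snippet:

if backward:
    for inp in node.inputs:
        if inp is None:
            continue
        producer = inp.producer()
        if producer is not None:
            visit(producer, depth + 1)  # `producer` is narrowed to Node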
@@ -0,0 +1,73 @@
# Copyright (c) Microsoft Corporation.

Check warning (Code scanning / lintrunner)
RUFF-FORMAT/format Warning: Run lintrunner -a to apply this patch.
    op, input, query_weight, key_weight, value_weight, cos, sin
):
    """Variation of first pattern with Reshape omitted."""
    query = _project_transpose_head(op, input, query_weight)

Check failure (Code scanning / lintrunner)
MYPY/call-arg Error: Missing positional argument "reshape_var" in call to "_project_transpose_head". To disable, use # type: ignore[call-arg]
    """Variation of first pattern with Reshape omitted."""
    query = _project_transpose_head(op, input, query_weight)
    query_rope = op.RotaryEmbedding(query, cos, sin, _domain="local")
    key = _project_transpose_head(op, input, key_weight)

Check failure (Code scanning / lintrunner)
MYPY/call-arg Error: Missing positional argument "reshape_var" in call to "_project_transpose_head". To disable, use # type: ignore[call-arg]
    # Reshape omitted here.
    key_transposed = op.Transpose(key_rope)
    # Reshape omitted here.
    value = _project_transpose_head(op, input, value_weight)

Check failure (Code scanning / lintrunner)
MYPY/call-arg Error: Missing positional argument "reshape_var" in call to "_project_transpose_head". To disable, use # type: ignore[call-arg]
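All three call-arg errors come from this Reshape-free variation calling `_project_transpose_head` without the `reshape_var` argument the helper was defined to require. One way to support both pattern variations is an optional parameter; the body below is a hypothetical reconstruction (the MatMul projection and the perm value are assumptions, since the helper's definition is not shown here):

def _project_transpose_head(op, input, weight, reshape_var=None):
    """Project `input` with `weight`, optionally reshape, then transpose.

    Sketch only: when `reshape_var` is None, the Reshape step is skipped,
    matching the "Reshape omitted" variation of the pattern.
    """
    projected = op.MatMul(input, weight)
    if reshape_var is not None:
        projected = op.Reshape(projected, reshape_var)
    return op.Transpose(projected, perm=[0, 2, 1, 3])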
@@ -0,0 +1,141 @@
# Copyright (c) Microsoft Corporation.

Check warning (Code scanning / lintrunner)
RUFF/format Warning: Run lintrunner -a to apply this patch.
Check warning (Code scanning / lintrunner)
RUFF-FORMAT/format Warning: Run lintrunner -a to apply this patch.
@@ -0,0 +1,141 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

Check warning (Code scanning / lintrunner)
RUFF/I001 Warning: Import block is un-sorted or un-formatted.
See https://docs.astral.sh/ruff/rules/unsorted-imports
def _check_shape(bindings: dict[str, int], val: ir.Value, shape: Iterable[str]) -> bool:
    if val.shape is None:
        return False
    if val.shape.rank() != len(shape):

Check failure (Code scanning / lintrunner)
MYPY/arg-type Error: Argument 1 to "len" has incompatible type "Iterable[str]"; expected "Sized". To disable, use # type: ignore[arg-type]
        return False
    for actual, expected in zip(val.shape, shape):
        if expected not in bindings:
            bindings[expected] = actual

Check failure (Code scanning / lintrunner)
MYPY/assignment Error: Incompatible types in assignment (expression has type "int | SymbolicDim", target has type "int"). To disable, use # type: ignore[assignment]
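Both errors can be addressed by typing the parameter as a `Sequence` (which is `Sized`, so `len` type-checks) and by rejecting symbolic dimensions before recording an `int` binding. A sketch; the mismatch check against an existing binding is an assumption about the intended semantics:

from __future__ import annotations

from collections.abc import Sequence

from onnxscript import ir


def _check_shape(bindings: dict[str, int], val: ir.Value, shape: Sequence[str]) -> bool:
    if val.shape is None:
        return False
    if val.shape.rank() != len(shape):
        return False
    for actual, expected in zip(val.shape, shape):
        if not isinstance(actual, int):
            return False  # symbolic dimension: cannot bind as an int
        if expected not in bindings:
            bindings[expected] = actual
        elif bindings[expected] != actual:
            return False  # assumed semantics: re-bound dims must match
    return True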
@gramalingam
Collaborator Author

Closing this for now. Will recreate as multiple separate PRs.

@gramalingam gramalingam closed this Dec 7, 2024