[Unity][Transform] Implement relax.transform.ExpandMatmulOfSum
An optimization pass that rewrites `x*(A+B)` into `x*A + x*B`, where
`x`, `A`, and `B` are relax tensors.
Lunderberg committed Dec 29, 2023
1 parent 703aa15 commit 802fead
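
For reference, the rewrite relies on distributivity of matrix
multiplication over addition. A quick numpy sketch of the identity
(illustrative only, not part of the commit):

import numpy as np

x = np.random.rand(16).astype("float32")
A = np.random.rand(16, 32).astype("float32")
B = np.random.rand(16, 32).astype("float32")

# x @ (A + B) agrees with x @ A + x @ B up to floating-point rounding.
np.testing.assert_allclose(x @ (A + B), x @ A + x @ B, rtol=1e-5)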
Showing 4 changed files with 244 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/relax/transform/__init__.py
@@ -34,6 +34,7 @@
    DecomposeOpsForInference,
    DecomposeOpsForTraining,
    EliminateCommonSubexpr,
    ExpandMatmulOfSum,
    ExpandTupleArguments,
    FewShotTuning,
    FoldConstant,
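
With this one-line addition the pass is exposed through the public
`tvm.relax.transform` namespace. A minimal sketch of constructing it
(assuming a TVM build with Relax enabled):

from tvm import relax

# A function pass; apply it by calling it on an IRModule,
# e.g. `expanded_mod = expand(mod)`.
expand = relax.transform.ExpandMatmulOfSum()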
20 changes: 20 additions & 0 deletions python/tvm/relax/transform/transform.py
@@ -1224,6 +1224,26 @@ def SplitCallTIRByPattern(patterns: List[PrimFunc], fcodegen: Callable) -> tvm.i
    return _ffi_api.SplitCallTIRByPattern(patterns, fcodegen)  # type: ignore


def ExpandMatmulOfSum():
    """Expand `matmul(x, A+B)` to `matmul(x, A) + matmul(x, B)`

    If the sum `A+B` can be computed entirely at compile-time (i.e.
    both operands depend only on function parameters occurring after
    `kNumInput`), the expansion is suppressed, since folding the
    addition at compile-time is preferable.

    Useful for optimizing LoRA computations, where `matmul(x, Base +
    LoraA*LoraB)` may be expanded to `matmul(x, Base) + matmul(x,
    LoraA*LoraB)`, allowing it to be optimized with
    `CombineParallelMatmul`.

    Returns
    -------
    ret : tvm.transform.Pass
        The corresponding pass.
    """

    return _ffi_api.ExpandMatmulOfSum()  # type: ignore
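
As a sketch of the LoRA use-case described in the docstring above
(the module, parameter names, and shapes are illustrative assumptions,
not part of this commit):

from tvm import relax
from tvm.script import ir as I, relax as R

@I.ir_module
class LoraModule:
    @R.function
    def main(
        x: R.Tensor([16], "float32"),
        lora_a: R.Tensor([16, 4], "float32"),
        lora_b: R.Tensor([4, 32], "float32"),
        base: R.Tensor([16, 32], "float32"),
    ) -> R.Tensor([32], "float32"):
        # With num_input=3, only `base` is a compile-time parameter.
        # Since `lora_a` and `lora_b` arrive at runtime, the sum
        # `base + lora_a @ lora_b` is not compile-time computable,
        # so the expansion applies.
        R.func_attr({"num_input": 3})
        delta = R.matmul(lora_a, lora_b)
        weight = R.add(base, delta)
        out = R.matmul(x, weight)
        return out

expanded = relax.transform.ExpandMatmulOfSum()(LoraModule)
# main now computes matmul(x, base) + matmul(x, delta), which a later
# CombineParallelMatmul may fuse into a single matmul over `x`.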


def CombineParallelMatmul(check=None):
    """Combine multiple matmul operators sharing the same LHS matrix into one,
    followed by slicing. When all matmul branches in a tree have the same set of fused ops,
96 changes: 96 additions & 0 deletions src/relax/transform/expand_matmul_of_sum.cc
@@ -0,0 +1,96 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relax/transform/expand_matmul_of_sum.cc
 * \brief Expand `matmul(x, A+B)` to `matmul(x, A) + matmul(x, B)`
 */

#include <tvm/relax/analysis.h>
#include <tvm/relax/dataflow_matcher.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>

#include <optional>
#include <unordered_set>
#include <vector>

#include "../op/tensor/binary.h"
#include "../op/tensor/linear_algebra.h"

namespace tvm {
namespace relax {

namespace {
std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>> CreatePatterns(
    const Function& func) {
  auto compile_time_arr = ComputableAtCompileTime(func);
  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> compile_time_lookup(
      compile_time_arr.begin(), compile_time_arr.end());

  auto pat_lhs = WildcardPattern();

  auto pat_rhs_a = WildcardPattern();
  auto pat_rhs_b = WildcardPattern();
  auto pat_rhs = IsOp("relax.add")(pat_rhs_a, pat_rhs_b);

  auto pat_matmul = IsOp("relax.matmul")(pat_lhs, pat_rhs);

  auto rewriter = [=](Expr expr, Map<DFPattern, Expr> matches) -> Expr {
    auto lhs = matches[pat_lhs];
    auto rhs_a = matches[pat_rhs_a];
    auto rhs_b = matches[pat_rhs_b];

    // Suppress the rewrite if `(A+B)` can be computed at
    // compile-time.
    auto is_compile_time = [&compile_time_lookup](Expr arg) -> bool {
      if (auto as_var = arg.as<Var>()) {
        return compile_time_lookup.count(as_var.value());
      } else {
        return false;
      }
    };

    if (is_compile_time(rhs_a) && is_compile_time(rhs_b)) {
      return expr;
    }

    return add(matmul(lhs, rhs_a, DataType::Void()), matmul(lhs, rhs_b, DataType::Void()));
  };

  return {pat_matmul, rewriter};
}

}  // namespace

namespace transform {
Pass ExpandMatmulOfSum() {
  auto pass_func = [=](Function func, IRModule mod, PassContext pc) {
    auto [pattern, rewriter] = CreatePatterns(func);
    return RewriteCall(pattern, rewriter, func);
  };
  return CreateFunctionPass(pass_func, 1, "ExpandMatmulOfSum", {});
}

TVM_REGISTER_GLOBAL("relax.transform.ExpandMatmulOfSum").set_body_typed(ExpandMatmulOfSum);

}  // namespace transform
}  // namespace relax
}  // namespace tvm
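
The `TVM_REGISTER_GLOBAL` call above backs the Python-side
`_ffi_api.ExpandMatmulOfSum` used in `transform.py`. As a sketch, the
registered packed function can also be fetched directly (illustrative,
not part of the commit):

import tvm

# Resolves the global registered in expand_matmul_of_sum.cc; calling
# it constructs the tvm.transform.Pass.
make_pass = tvm.get_global_func("relax.transform.ExpandMatmulOfSum")
expand = make_pass()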
127 changes: 127 additions & 0 deletions tests/python/relax/test_transform_expand_matmul_of_sum.py
@@ -0,0 +1,127 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import inspect

import pytest

import tvm.testing
from tvm import relax
from tvm.script import ir as I, relax as R


class Base:
    def test_compare(self):
        transform = relax.transform.ExpandMatmulOfSum()

        if inspect.isclass(self.Expected) and issubclass(self.Expected, Exception):
            with pytest.raises(self.Expected):
                transform(self.Before)
        else:
            after = transform(self.Before)
            tvm.ir.assert_structural_equal(self.Expected, after)


class TestSimple(Base):
    @I.ir_module
    class Before:
        @R.function
        def main(
            x: R.Tensor([16], "float32"),
            A: R.Tensor([16, 32], "float32"),
            B: R.Tensor([16, 32], "float32"),
        ) -> R.Tensor([32], "float32"):
            weight = R.add(A, B)
            out = R.matmul(x, weight)
            return out

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor([16], "float32"),
            A: R.Tensor([16, 32], "float32"),
            B: R.Tensor([16, 32], "float32"),
        ) -> R.Tensor([32], "float32"):
            lhs = R.matmul(x, A)
            rhs = R.matmul(x, B)
            out = R.add(lhs, rhs)
            return out


class TestNoExpansionOfCompileTimeAddition(Base):
    """Do not expand compile-time parameters

    This expansion is primarily to prepare the function for a later
    use of `CombineParallelMatmul`. If the addition can be performed
    at compile-time, this is preferable.
    """

    @I.ir_module
    class Before:
        @R.function
        def main(
            x: R.Tensor([16], "float32"),
            A: R.Tensor([16, 32], "float32"),
            B: R.Tensor([16, 32], "float32"),
        ) -> R.Tensor([32], "float32"):
            R.func_attr({"num_input": 1})
            weight = R.add(A, B)
            out = R.matmul(x, weight)
            return out

    Expected = Before


class TestExpansionOfRuntimeAddition(Base):
    """Expand runtime addition

    This expansion is primarily to prepare the function for a later
    use of `CombineParallelMatmul`. The expansion to `x*A + x*B`
    should occur iff `A+B` is not computable at compile-time.
    """

    @I.ir_module
    class Before:
        @R.function
        def main(
            x: R.Tensor([16], "float32"),
            A: R.Tensor([16, 32], "float32"),
            B: R.Tensor([16, 32], "float32"),
        ) -> R.Tensor([32], "float32"):
            R.func_attr({"num_input": 2})
            weight = R.add(A, B)
            out = R.matmul(x, weight)
            return out

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor([16], "float32"),
            A: R.Tensor([16, 32], "float32"),
            B: R.Tensor([16, 32], "float32"),
        ) -> R.Tensor([32], "float32"):
            R.func_attr({"num_input": 2})
            lhs = R.matmul(x, A)
            rhs = R.matmul(x, B)
            out = R.add(lhs, rhs)
            return out


if __name__ == "__main__":
    tvm.testing.main()
