From 802fead7e6d03c41f6455505527eabc51effe707 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 12 Dec 2023 15:27:53 +0000 Subject: [PATCH] [Unity][Transform] Implement relax.transform.ExpandMatmulOfSum An optimization pass that rewrites `x*(A+B)` into `x*A + x*B`, where `x`, `A`, and `B` are relax tensors. --- python/tvm/relax/transform/__init__.py | 1 + python/tvm/relax/transform/transform.py | 20 +++ src/relax/transform/expand_matmul_of_sum.cc | 96 +++++++++++++ .../test_transform_expand_matmul_of_sum.py | 127 ++++++++++++++++++ 4 files changed, 244 insertions(+) create mode 100644 src/relax/transform/expand_matmul_of_sum.cc create mode 100644 tests/python/relax/test_transform_expand_matmul_of_sum.py diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 19316c76b83dd..0d992ed7147fd 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -34,6 +34,7 @@ DecomposeOpsForInference, DecomposeOpsForTraining, EliminateCommonSubexpr, + ExpandMatmulOfSum, ExpandTupleArguments, FewShotTuning, FoldConstant, diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 9589f661d79ec..36edba33fa60e 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -1224,6 +1224,26 @@ def SplitCallTIRByPattern(patterns: List[PrimFunc], fcodegen: Callable) -> tvm.i return _ffi_api.SplitCallTIRByPattern(patterns, fcodegen) # type: ignore +def ExpandMatmulOfSum(): + """Expand `matmul(x, A+B)` to `matmul(x,A) + matmul(x,B)` + + If both operands of the sum can be fully computed at compile-time (i.e. they + depend only on function parameters after kNumInput), this expansion is + suppressed. + + Useful for optimizing LoRA computations, where `matmul(x, Base + + LoraA*LoraB)` may be expanded to `matmul(x, Base) + matmul(x, + LoraA*LoraB)`, allowing it to be optimized with `CombineParallelMatmul`. 
+ + Returns + ------- + ret : tvm.transform.Pass + The corresponding pass. + """ + + return _ffi_api.ExpandMatmulOfSum() # type: ignore + + def CombineParallelMatmul(check=None): """Combine multiple matmul operators sharing the same LHS matrix into one, followed by slicing. When all matmul branches in a tree have the same set of fused ops, diff --git a/src/relax/transform/expand_matmul_of_sum.cc b/src/relax/transform/expand_matmul_of_sum.cc new file mode 100644 index 0000000000000..76ebae94982d1 --- /dev/null +++ b/src/relax/transform/expand_matmul_of_sum.cc @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
 * \file tvm/relax/transform/expand_matmul_of_sum.cc * \brief Expand `matmul(x, A+B)` to `matmul(x, A) + matmul(x,B)` */ #include <tvm/relax/analysis.h> #include <tvm/relax/dataflow_matcher.h> #include <tvm/relax/expr.h> #include <tvm/relax/expr_functor.h> #include <tvm/relax/transform.h> #include <optional> #include <unordered_set> #include <vector> #include "../op/tensor/binary.h" #include "../op/tensor/linear_algebra.h" namespace tvm { namespace relax { namespace { std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>> CreatePatterns( const Function& func) { auto compile_time_arr = ComputableAtCompileTime(func); std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> compile_time_lookup( compile_time_arr.begin(), compile_time_arr.end()); auto pat_lhs = WildcardPattern(); auto pat_rhs_a = WildcardPattern(); auto pat_rhs_b = WildcardPattern(); auto pat_rhs = IsOp("relax.add")(pat_rhs_a, pat_rhs_b); auto pat_matmul = IsOp("relax.matmul")(pat_lhs, pat_rhs); auto rewriter = [=](Expr expr, Map<DFPattern, Expr> matches) -> Expr { auto lhs = matches[pat_lhs]; auto rhs_a = matches[pat_rhs_a]; auto rhs_b = matches[pat_rhs_b]; // Suppress the rewrite if `(A+B)` can be computed at // compile-time. 
 auto is_compile_time = [&compile_time_lookup](Expr arg) -> bool { if (auto as_var = arg.as<Var>()) { return compile_time_lookup.count(as_var.value()); } else { return false; } }; if (is_compile_time(rhs_a) && is_compile_time(rhs_b)) { return expr; } return add(matmul(lhs, rhs_a, DataType::Void()), matmul(lhs, rhs_b, DataType::Void())); }; return {pat_matmul, rewriter}; } } // namespace namespace transform { Pass ExpandMatmulOfSum() { auto pass_func = [=](Function func, IRModule mod, PassContext pc) { auto [pattern, rewriter] = CreatePatterns(func); return RewriteCall(pattern, rewriter, func); }; return CreateFunctionPass(pass_func, 1, "ExpandMatmulOfSum", {}); } TVM_REGISTER_GLOBAL("relax.transform.ExpandMatmulOfSum").set_body_typed(ExpandMatmulOfSum); } // namespace transform } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_transform_expand_matmul_of_sum.py b/tests/python/relax/test_transform_expand_matmul_of_sum.py new file mode 100644 index 0000000000000..67e59225c5ed6 --- /dev/null +++ b/tests/python/relax/test_transform_expand_matmul_of_sum.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import inspect + +import pytest + +import tvm.testing +from tvm import relax +from tvm.script import ir as I, relax as R + + +class Base: + def test_compare(self): + transform = relax.transform.ExpandMatmulOfSum() + + if inspect.isclass(self.Expected) and issubclass(self.Expected, Exception): + with pytest.raises(self.Expected): + transform(self.Before) + else: + after = transform(self.Before) + tvm.ir.assert_structural_equal(self.Expected, after) + + +class TestSimple(Base): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([16], "float32"), + A: R.Tensor([16, 32], "float32"), + B: R.Tensor([16, 32], "float32"), + ) -> R.Tensor([32], "float32"): + weight = R.add(A, B) + out = R.matmul(x, weight) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([16], "float32"), + A: R.Tensor([16, 32], "float32"), + B: R.Tensor([16, 32], "float32"), + ) -> R.Tensor([32], "float32"): + lhs = R.matmul(x, A) + rhs = R.matmul(x, B) + out = R.add(lhs, rhs) + return out + + +class TestNoExpansionOfCompileTimeAddition(Base): + """Do not expand compile-time parameters + + This expansion is primarily to prepare the function for a later + use of `CombineParallelMatmul`. If the addition can be performed + at compile-time, this is preferable. + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([16], "float32"), + A: R.Tensor([16, 32], "float32"), + B: R.Tensor([16, 32], "float32"), + ) -> R.Tensor([32], "float32"): + R.func_attr({"num_input": 1}) + weight = R.add(A, B) + out = R.matmul(x, weight) + return out + + Expected = Before + + +class TestExpansionOfRuntimeAddition(Base): + """Expand runtime addition + + This expansion is primarily to prepare the function for a later + use of `CombineParallelMatmul`. The expansion to `x*A + x*B` + should occur iff `A+B` is not computable at compile-time. 
+ """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([16], "float32"), + A: R.Tensor([16, 32], "float32"), + B: R.Tensor([16, 32], "float32"), + ) -> R.Tensor([32], "float32"): + R.func_attr({"num_input": 2}) + weight = R.add(A, B) + out = R.matmul(x, weight) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([16], "float32"), + A: R.Tensor([16, 32], "float32"), + B: R.Tensor([16, 32], "float32"), + ) -> R.Tensor([32], "float32"): + R.func_attr({"num_input": 2}) + lhs = R.matmul(x, A) + rhs = R.matmul(x, B) + out = R.add(lhs, rhs) + return out + + +if __name__ == "__main__": + tvm.testing.main()