From 15e110e2aacd04515478b540d6534b704e238135 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 5 Aug 2022 19:47:41 -0500 Subject: [PATCH] [TIR] Add tir::builtin::undef (#12266) * [UnitTest] RemoveStoreUndef, simplest behavior * [RemoveStoreUndef] First implementation * [UnitTest] RemoveStoreUndef, stores that depend through LetStmt * [UnitTest] RemoveStoreUndef, LetStmt handling, error on illegal usage * [RemoveStoreUndef] Added error checking for illegal T.undef() usage * Fix lint error * Use const ref for list of stores to remove * Verify that removed expression has no other side effects * Fix lint error --- include/tvm/tir/builtin.h | 8 + python/tvm/tir/transform/transform.py | 11 ++ src/tir/op/builtin.cc | 4 + src/tir/transforms/remove_store_undef.cc | 179 ++++++++++++++++++ .../test_tir_transform_remove_undef.py | 94 +++++++++ 5 files changed, 296 insertions(+) create mode 100644 src/tir/transforms/remove_store_undef.cc create mode 100644 tests/python/unittest/test_tir_transform_remove_undef.py diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index fc326c18730e..12290a97c840 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -729,6 +729,14 @@ TVM_DLL const Op& mem_copy(); */ TVM_DLL const Op& assume(); +/*! + * \brief Returns an initialized but arbitrary value + * + * Compile-time representation of memory locations whose values may be + * altered as a result of optimizations. + */ +TVM_DLL const Op& undef(); + /*! \brief The kind of structure field info used in intrinsic */ enum TVMStructFieldKind : int { // array head address diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index d63c65dfddde..eb2cff641ca3 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -264,6 +264,17 @@ def RemoveAssume(): return _ffi_api.RemoveAssume() # type: ignore +def RemoveStoreUndef(): + """Remove stores of undefined values from the Stmt. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.RemoveStoreUndef() # type: ignore + + def BF16Legalize(): """Legalize bf16 typed Ops. Runs BF16Promote, BF16CastElimination and BF16TypeLowering diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 860f98dd1430..9642f8e39f39 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -292,6 +292,10 @@ TIR_DEFINE_BUILTIN_FUNC(assume) .set_attr("TCallEffectKind", Integer(CallEffectKind::kEmbedInfo)) .set_num_inputs(1); +TIR_DEFINE_BUILTIN_FUNC(undef) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kReadState)) + .set_num_inputs(0); + } // namespace builtin } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/remove_store_undef.cc b/src/tir/transforms/remove_store_undef.cc new file mode 100644 index 000000000000..6b28cb165aa9 --- /dev/null +++ b/src/tir/transforms/remove_store_undef.cc @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file remove_store_undef.cc + * \brief Remove stores of tir::builtin::undef + */ +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +class StoreUndefLocator : public StmtExprVisitor { + public: + static std::unordered_set Locate(Stmt stmt) { + StoreUndefLocator locator; + locator(std::move(stmt)); + return locator.undef_stores_; + } + + private: + StoreUndefLocator() = default; + + void VisitStmt_(const BufferStoreNode* op) final { + bool stash_undef = false; + std::swap(has_undef_, stash_undef); + StmtExprVisitor::VisitExpr(op->value); + std::swap(has_undef_, stash_undef); + if (stash_undef) { + ICHECK(SideEffect(op->value) <= CallEffectKind::kReadState) + << "Error: T.undef() used in BufferStore expressions " + << "must not have other side effects"; + undef_stores_.insert(op); + } + } + + void VisitExpr_(const BufferLoadNode* op) final { + // This function left deliberately empty. builtin::undef() + // shouldn't occur in the indices of BufferLoad. Avoiding + // visiting the indices catches the builtin::undef in + // ValidateAllUndefRemoved. + } + + void VisitStmt_(const LetStmtNode* op) final { + bool stash_undef = false; + std::swap(has_undef_, stash_undef); + StmtExprVisitor::VisitExpr(op->value); + std::swap(has_undef_, stash_undef); + if (stash_undef) { + ICHECK(SideEffect(op->value) <= CallEffectKind::kReadState) + << "Error: T.undef() used in Let expressions " + << "must not have other side effects"; + var_bindings_with_undef_.insert(op->var.get()); + } + + StmtExprVisitor::VisitStmt(op->body); + } + + void VisitExpr_(const VarNode* op) final { + if (var_bindings_with_undef_.count(op)) { + has_undef_ = true; + } + } + + void VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::undef())) { + has_undef_ = true; + } + StmtExprVisitor::VisitExpr_(op); + } + + bool has_undef_{false}; + + std::unordered_set var_bindings_with_undef_; + std::unordered_set undef_stores_; +}; + +// Remove any BufferStores whose value depends on T.undef +class StoreUndefRemover : public StmtExprMutator { + public: + static Stmt Apply(Stmt stmt) { + auto to_remove = StoreUndefLocator::Locate(stmt); + StoreUndefRemover mutator(to_remove); + return mutator(std::move(stmt)); + } + + private: + using Parent = StmtExprMutator; + + explicit StoreUndefRemover(const std::unordered_set& to_remove) + : to_remove_(to_remove) {} + + Stmt VisitStmt_(const BufferStoreNode* op) final { + if (to_remove_.count(op)) { + return Evaluate(0); + } else { + return Parent::VisitStmt_(op); + } + } + + const std::unordered_set& to_remove_; +}; + +// Remove any BufferStores whose value depends on T.undef +class ContainsUndefChecker : public StmtExprVisitor { + public: + static bool Check(const Stmt& stmt) { + ContainsUndefChecker checker; + checker(stmt); + return checker.contains_undef; + } + + private: + void VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::undef())) { + contains_undef = true; + } + StmtExprVisitor::VisitExpr_(op); + } + + bool contains_undef{false}; +}; + +namespace transform { +Pass RemoveStoreUndefInternal() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = StoreUndefRemover::Apply(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.RemoveStoreUndefInternal", {}); +} + +Pass ValidateAllUndefRemoved() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + bool contains_undef = ContainsUndefChecker::Check(f->body); + ICHECK(!contains_undef) << "Expected removal of BufferStore containing builtin::undef() " + << "to remove all instances of builtin::undef(). " + << "Instead, result was" + << "\n" + << f; + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.ValidateAllUndefRemoved", {}); +} + +Pass RemoveStoreUndef() { + return Sequential({RemoveStoreUndefInternal(), RemoveNoOp(), ValidateAllUndefRemoved()}, + "tir.RemoveStoreUndef"); +} + +TVM_REGISTER_GLOBAL("tir.transform.RemoveStoreUndef").set_body_typed(RemoveStoreUndef); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_remove_undef.py b/tests/python/unittest/test_tir_transform_remove_undef.py new file mode 100644 index 000000000000..c634bf5e9da8 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_remove_undef.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm.script import tir as T +from tvm import TVMError + + +class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): + @tvm.testing.fixture + def transform(self): + return tvm.tir.transform.RemoveStoreUndef() + + +class TestRemoveStoreUndef(BaseBeforeAfter): + """Remove a store whose value is T.undef()""" + + def before(A: T.Buffer[1, "int32"]): + A[0] = T.undef(dtype="int32") + + def expected(A: T.Buffer[1, "int32"]): + T.evaluate(0) + + +class TestRemoveStoreUndefExpression(BaseBeforeAfter): + """Expressions containing T.undef() are removed""" + + def before(A: T.Buffer[1, "int32"]): + A[0] = 1 + T.undef(dtype="int32") + + def expected(A: T.Buffer[1, "int32"]): + T.evaluate(0) + + +class TestKeepOtherCallNodes(BaseBeforeAfter): + """Expressions containing other CallNodes are not removed""" + + def before(A: T.Buffer[1, "int32"], n: T.int32): + A[0] = T.shift_left(n, 1, dtype="int32") + + expected = before + + +class TestRemoveLetUndef(BaseBeforeAfter): + """Remove a store whose value is bound to T.undef()""" + + def before(A: T.Buffer[1, "int32"]): + val = T.undef(dtype="int32") + A[0] = val + + def expected(A: T.Buffer[1, "int32"]): + T.evaluate(0) + + +class TestRaiseErrorForUndefAsStoreIndices(BaseBeforeAfter): + """Use of T.undef() as buffer indices is an error""" + + def before(A: T.Buffer[1, "int32"]): + val = T.undef(dtype="int32") + A[val] = 5 + + expected = TVMError + + +class TestRaiseErrorForUndefAsLoadIndices(BaseBeforeAfter): + """Use of T.undef() as buffer indices is an error + + Even though this occurs as part of the BufferStore's value, the + T.undef() may not appear in a buffer's indices. + """ + + def before(A: T.Buffer[1, "int32"], B: T.Buffer[1, "int32"]): + B[0] = A[T.undef(dtype="int32")] + + expected = TVMError + + +if __name__ == "__main__": + tvm.testing.main()