Skip to content

Commit

Permalink
[TIR] Add tir::builtin::assume (#12267)
Browse files Browse the repository at this point in the history
* [RemoveAssume] Implemented T.assume in TVMScript, RemoveAssume

* [UnitTest] RemoveAssume, initial functionality tests
  • Loading branch information
Lunderberg authored Aug 5, 2022
1 parent ca46f21 commit 8a0911c
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 0 deletions.
9 changes: 9 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,15 @@ TVM_DLL const Op& texture2d_load();
*/
TVM_DLL const Op& mem_copy();

/*!
* \brief Provide a true statement that can be used for simplifications
*
* Compile-time representation of known constraints about function
* inputs. This assumption is removed when lowering, and does not
* occur in codegen.
*/
TVM_DLL const Op& assume();

/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/script/tir/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,17 @@ def store(var, index, value, predicate=True, span=None):
super().__init__(store, stmt=True)


@register
class AssumeIntrin(Intrin):
def __init__(self):
def assume(constraint, span):
return tvm.tir.Evaluate(
tvm.tir.call_intrin("bool", "tir.assume", constraint, span=span)
)

super().__init__(assume, stmt=True)


@register
def comm_reducer(lambda_io, identities, span):
"""Create a CommReducer from lambda inputs/outputs and the identities"""
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,17 @@ def RemoveNoOp():
return _ffi_api.RemoveNoOp() # type: ignore


def RemoveAssume():
"""Remove all instances of builtin::assume
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.RemoveAssume() # type: ignore


def BF16Legalize():
"""Legalize bf16 typed Ops.
Runs BF16Promote, BF16CastElimination and BF16TypeLowering
Expand Down
8 changes: 8 additions & 0 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1181,6 +1181,14 @@ Doc TVMScriptPrinter::VisitStmt_(const SeqStmtNode* op) {
}

Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) {
if (auto* call = op->value.as<CallNode>()) {
if (call->op.same_as(builtin::assume())) {
Doc doc;
doc << tir_prefix_ << ".assume(" << Print(call->args[0]) << ")";
return doc;
}
}

Doc doc;
doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")";
return doc;
Expand Down
4 changes: 4 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,10 @@ TIR_DEFINE_BUILTIN_FUNC(texture2d_load)
TIR_DEFINE_BUILTIN_FUNC(mem_copy).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(assume)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kEmbedInfo))
.set_num_inputs(1);

} // namespace builtin
} // namespace tir
} // namespace tvm
69 changes: 69 additions & 0 deletions src/tir/transforms/remove_assume.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file remove_store_undef.cc
* \brief Remove stores of tir::builtin::undef
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace tir {

// Remove any builtin::assume calls
class AssumeRemover : public StmtExprMutator {
public:
using Parent = StmtExprMutator;

Stmt VisitStmt_(const EvaluateNode* op) final {
if (auto* call = op->value.as<CallNode>()) {
if (call->op.same_as(builtin::assume())) {
return Evaluate(0);
}
}
return StmtExprMutator::VisitStmt_(op);
}
};

namespace transform {
Pass RemoveAssumeInternal() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = AssumeRemover()(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.RemoveAssumeInternal", {});
}

Pass RemoveAssume() {
return Sequential({RemoveAssumeInternal(), RemoveNoOp()}, "tir.RemoveAssume");
}

TVM_REGISTER_GLOBAL("tir.transform.RemoveAssume").set_body_typed(RemoveAssume);

} // namespace transform

} // namespace tir
} // namespace tvm
57 changes: 57 additions & 0 deletions tests/python/unittest/test_tir_transform_remove_assume.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
import tvm.testing
from tvm.script import tir as T
from tvm import TVMError


class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
@tvm.testing.fixture
def transform(self):
return tvm.tir.transform.RemoveAssume()


class TestRemoveAssume(BaseBeforeAfter):
"""Remove any instance of T.assume"""

def before(A: T.Buffer[1, "int32"]):
T.assume(A[0] == 5)
A[0] = 10

def expected(A: T.Buffer[1, "int32"]):
A[0] = 10


class TestRemoveAssumeLoop(BaseBeforeAfter):
"""Loops containing only T.assume should be removed"""

def before(A: T.Buffer[16, "int32"]):
for i in T.serial(16):
T.assume(A[i] == 0)

for i in T.serial(16):
A[i] = 10

def expected(A: T.Buffer[16, "int32"]):
for i in T.serial(16):
A[i] = 10


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 8a0911c

Please sign in to comment.