This repository has been archived by the owner on Nov 25, 2022. It is now read-only.
forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TIR] Add tir::builtin::undef (apache#12266)
* [UnitTest] RemoveStoreUndef, simplest behavior * [RemoveStoreUndef] First implementation * [UnitTest] RemoveStoreUndef, stores that depend through LetStmt * [UnitTest] RemoveStoreUndef, LetStmt handling, error on illegal usage * [RemoveStoreUndef] Added error checking for illegal T.undef() usage * Fix lint error * Use const ref for list of stores to remove * Verify that removed expression has no other side effects * Fix lint error
- Loading branch information
1 parent
d699670
commit 15e110e
Showing
5 changed files
with
296 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file remove_store_undef.cc | ||
* \brief Remove stores of tir::builtin::undef | ||
*/ | ||
#include <tvm/runtime/registry.h> | ||
#include <tvm/tir/analysis.h> | ||
#include <tvm/tir/builtin.h> | ||
#include <tvm/tir/op.h> | ||
#include <tvm/tir/stmt.h> | ||
#include <tvm/tir/stmt_functor.h> | ||
#include <tvm/tir/transform.h> | ||
|
||
namespace tvm { | ||
namespace tir { | ||
|
||
class StoreUndefLocator : public StmtExprVisitor { | ||
public: | ||
static std::unordered_set<const BufferStoreNode*> Locate(Stmt stmt) { | ||
StoreUndefLocator locator; | ||
locator(std::move(stmt)); | ||
return locator.undef_stores_; | ||
} | ||
|
||
private: | ||
StoreUndefLocator() = default; | ||
|
||
void VisitStmt_(const BufferStoreNode* op) final { | ||
bool stash_undef = false; | ||
std::swap(has_undef_, stash_undef); | ||
StmtExprVisitor::VisitExpr(op->value); | ||
std::swap(has_undef_, stash_undef); | ||
if (stash_undef) { | ||
ICHECK(SideEffect(op->value) <= CallEffectKind::kReadState) | ||
<< "Error: T.undef() used in BufferStore expressions " | ||
<< "must not have other side effects"; | ||
undef_stores_.insert(op); | ||
} | ||
} | ||
|
||
void VisitExpr_(const BufferLoadNode* op) final { | ||
// This function left deliberately empty. builtin::undef() | ||
// shouldn't occur in the indices of BufferLoad. Avoiding | ||
// visiting the indices catches the builtin::undef in | ||
// ValidateAllUndefRemoved. | ||
} | ||
|
||
void VisitStmt_(const LetStmtNode* op) final { | ||
bool stash_undef = false; | ||
std::swap(has_undef_, stash_undef); | ||
StmtExprVisitor::VisitExpr(op->value); | ||
std::swap(has_undef_, stash_undef); | ||
if (stash_undef) { | ||
ICHECK(SideEffect(op->value) <= CallEffectKind::kReadState) | ||
<< "Error: T.undef() used in Let expressions " | ||
<< "must not have other side effects"; | ||
var_bindings_with_undef_.insert(op->var.get()); | ||
} | ||
|
||
StmtExprVisitor::VisitStmt(op->body); | ||
} | ||
|
||
void VisitExpr_(const VarNode* op) final { | ||
if (var_bindings_with_undef_.count(op)) { | ||
has_undef_ = true; | ||
} | ||
} | ||
|
||
void VisitExpr_(const CallNode* op) final { | ||
if (op->op.same_as(builtin::undef())) { | ||
has_undef_ = true; | ||
} | ||
StmtExprVisitor::VisitExpr_(op); | ||
} | ||
|
||
bool has_undef_{false}; | ||
|
||
std::unordered_set<const VarNode*> var_bindings_with_undef_; | ||
std::unordered_set<const BufferStoreNode*> undef_stores_; | ||
}; | ||
|
||
// Remove any BufferStores whose value depends on T.undef | ||
class StoreUndefRemover : public StmtExprMutator { | ||
public: | ||
static Stmt Apply(Stmt stmt) { | ||
auto to_remove = StoreUndefLocator::Locate(stmt); | ||
StoreUndefRemover mutator(to_remove); | ||
return mutator(std::move(stmt)); | ||
} | ||
|
||
private: | ||
using Parent = StmtExprMutator; | ||
|
||
explicit StoreUndefRemover(const std::unordered_set<const BufferStoreNode*>& to_remove) | ||
: to_remove_(to_remove) {} | ||
|
||
Stmt VisitStmt_(const BufferStoreNode* op) final { | ||
if (to_remove_.count(op)) { | ||
return Evaluate(0); | ||
} else { | ||
return Parent::VisitStmt_(op); | ||
} | ||
} | ||
|
||
const std::unordered_set<const BufferStoreNode*>& to_remove_; | ||
}; | ||
|
||
// Remove any BufferStores whose value depends on T.undef | ||
class ContainsUndefChecker : public StmtExprVisitor { | ||
public: | ||
static bool Check(const Stmt& stmt) { | ||
ContainsUndefChecker checker; | ||
checker(stmt); | ||
return checker.contains_undef; | ||
} | ||
|
||
private: | ||
void VisitExpr_(const CallNode* op) final { | ||
if (op->op.same_as(builtin::undef())) { | ||
contains_undef = true; | ||
} | ||
StmtExprVisitor::VisitExpr_(op); | ||
} | ||
|
||
bool contains_undef{false}; | ||
}; | ||
|
||
namespace transform { | ||
Pass RemoveStoreUndefInternal() { | ||
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { | ||
auto* n = f.CopyOnWrite(); | ||
n->body = StoreUndefRemover::Apply(std::move(n->body)); | ||
return f; | ||
}; | ||
return CreatePrimFuncPass(pass_func, 0, "tir.RemoveStoreUndefInternal", {}); | ||
} | ||
|
||
Pass ValidateAllUndefRemoved() { | ||
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { | ||
bool contains_undef = ContainsUndefChecker::Check(f->body); | ||
ICHECK(!contains_undef) << "Expected removal of BufferStore containing builtin::undef() " | ||
<< "to remove all instances of builtin::undef(). " | ||
<< "Instead, result was" | ||
<< "\n" | ||
<< f; | ||
return f; | ||
}; | ||
return CreatePrimFuncPass(pass_func, 0, "tir.ValidateAllUndefRemoved", {}); | ||
} | ||
|
||
Pass RemoveStoreUndef() { | ||
return Sequential({RemoveStoreUndefInternal(), RemoveNoOp(), ValidateAllUndefRemoved()}, | ||
"tir.RemoveStoreUndef"); | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("tir.transform.RemoveStoreUndef").set_body_typed(RemoveStoreUndef); | ||
|
||
} // namespace transform | ||
|
||
} // namespace tir | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
import tvm | ||
import tvm.testing | ||
from tvm.script import tir as T | ||
from tvm import TVMError | ||
|
||
|
||
class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): | ||
@tvm.testing.fixture | ||
def transform(self): | ||
return tvm.tir.transform.RemoveStoreUndef() | ||
|
||
|
||
class TestRemoveStoreUndef(BaseBeforeAfter): | ||
"""Remove a store whose value is T.undef()""" | ||
|
||
def before(A: T.Buffer[1, "int32"]): | ||
A[0] = T.undef(dtype="int32") | ||
|
||
def expected(A: T.Buffer[1, "int32"]): | ||
T.evaluate(0) | ||
|
||
|
||
class TestRemoveStoreUndefExpression(BaseBeforeAfter): | ||
"""Expressions containing T.undef() are removed""" | ||
|
||
def before(A: T.Buffer[1, "int32"]): | ||
A[0] = 1 + T.undef(dtype="int32") | ||
|
||
def expected(A: T.Buffer[1, "int32"]): | ||
T.evaluate(0) | ||
|
||
|
||
class TestKeepOtherCallNodes(BaseBeforeAfter): | ||
"""Expressions containing other CallNodes are not removed""" | ||
|
||
def before(A: T.Buffer[1, "int32"], n: T.int32): | ||
A[0] = T.shift_left(n, 1, dtype="int32") | ||
|
||
expected = before | ||
|
||
|
||
class TestRemoveLetUndef(BaseBeforeAfter): | ||
"""Remove a store whose value is bound to T.undef()""" | ||
|
||
def before(A: T.Buffer[1, "int32"]): | ||
val = T.undef(dtype="int32") | ||
A[0] = val | ||
|
||
def expected(A: T.Buffer[1, "int32"]): | ||
T.evaluate(0) | ||
|
||
|
||
class TestRaiseErrorForUndefAsStoreIndices(BaseBeforeAfter): | ||
"""Use of T.undef() as buffer indices is an error""" | ||
|
||
def before(A: T.Buffer[1, "int32"]): | ||
val = T.undef(dtype="int32") | ||
A[val] = 5 | ||
|
||
expected = TVMError | ||
|
||
|
||
class TestRaiseErrorForUndefAsLoadIndices(BaseBeforeAfter): | ||
"""Use of T.undef() as buffer indices is an error | ||
Even though this occurs as part of the BufferStore's value, the | ||
T.undef() may not appear in a buffer's indices. | ||
""" | ||
|
||
def before(A: T.Buffer[1, "int32"], B: T.Buffer[1, "int32"]): | ||
B[0] = A[T.undef(dtype="int32")] | ||
|
||
expected = TVMError | ||
|
||
|
||
if __name__ == "__main__": | ||
tvm.testing.main() |