Skip to content

Commit

Permalink
[RemoveStoreUndef] Added error checking for illegal T.undef() usage
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Aug 1, 2022
1 parent 6734181 commit 7311263
Showing 1 changed file with 109 additions and 22 deletions.
131 changes: 109 additions & 22 deletions src/tir/transforms/remove_store_undef.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,52 +32,139 @@
namespace tvm {
namespace tir {

class StoreUndefLocator : public StmtExprVisitor {
public:
static std::unordered_set<const BufferStoreNode*> Locate(Stmt stmt) {
StoreUndefLocator locator;
locator(std::move(stmt));
return locator.undef_stores_;
}

private:
StoreUndefLocator() = default;

void VisitStmt_(const BufferStoreNode* op) final {
bool stash_undef = false;
std::swap(has_undef_, stash_undef);
StmtExprVisitor::VisitExpr(op->value);
std::swap(has_undef_, stash_undef);
if (stash_undef) {
undef_stores_.insert(op);
}
}

void VisitExpr_(const BufferLoadNode* op) final {
// This function left deliberately empty. builtin::undef()
// shouldn't occur in the indices of BufferLoad. Avoiding
// visiting the indices catches the builtin::undef in
// ValidateAllUndefRemoved.
}

void VisitStmt_(const LetStmtNode* op) final {
bool stash_undef = false;
std::swap(has_undef_, stash_undef);
StmtExprVisitor::VisitExpr(op->value);
std::swap(has_undef_, stash_undef);
if (stash_undef) {
var_bindings_with_undef_.insert(op->var.get());
}

StmtExprVisitor::VisitStmt(op->body);
}

void VisitExpr_(const VarNode* op) final {
if (var_bindings_with_undef_.count(op)) {
has_undef_ = true;
}
}

void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::undef())) {
has_undef_ = true;
}
StmtExprVisitor::VisitExpr_(op);
}

bool has_undef_{false};

std::unordered_set<const VarNode*> var_bindings_with_undef_;
std::unordered_set<const BufferStoreNode*> undef_stores_;
};

// Remove any BufferStores whose value depends on T.undef
class StoreUndefRemover : public StmtExprMutator {
public:
static Stmt Apply(Stmt stmt) {
StoreUndefRemover visitor;
return visitor(std::move(stmt));
auto to_remove = StoreUndefLocator::Locate(stmt);
StoreUndefRemover mutator(std::move(to_remove));
return mutator(std::move(stmt));
}

private:
using Parent = StmtExprMutator;
using Parent::Parent;
using Parent::VisitStmt;
using Parent::VisitStmt_;

StoreUndefRemover(std::unordered_set<const BufferStoreNode*> to_remove) : to_remove_(to_remove) {}

Stmt VisitStmt_(const BufferStoreNode* op) final {
has_undef = false;
Parent::VisitExpr(op->value);
if (has_undef) {
if (to_remove_.count(op)) {
return Evaluate(0);
} else {
return GetRef<Stmt>(op);
return Parent::VisitStmt_(op);
}
}

PrimExpr VisitExpr_(const CallNode* op) final {
std::unordered_set<const BufferStoreNode*> to_remove_;
};

// Remove any BufferStores whose value depends on T.undef
class ContainsUndefChecker : public StmtExprVisitor {
public:
static bool Check(const Stmt& stmt) {
ContainsUndefChecker checker;
checker(stmt);
return checker.contains_undef;
}

private:
void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::undef())) {
has_undef = true;
contains_undef = true;
}
return GetRef<PrimExpr>(op);
StmtExprVisitor::VisitExpr_(op);
}

bool has_undef{false};
bool contains_undef{false};
};

namespace transform {
Pass RemoveStoreUndefInternal() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = StoreUndefRemover::Apply(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.RemoveStoreUndefInternal", {});
}

Pass ValidateAllUndefRemoved() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
bool contains_undef = ContainsUndefChecker::Check(f->body);
ICHECK(!contains_undef) << "Expected removal of BufferStore containing builtin::undef() "
<< "to remove all instances of builtin::undef(). "
<< "Instead, result was"
<< "\n"
<< f;
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.ValidateAllUndefRemoved", {});
}

Pass RemoveStoreUndef() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = StoreUndefRemover::Apply(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.RemoveStoreUndef", {});
}
Pass RemoveStoreUndef() {
return Sequential({RemoveStoreUndefInternal(), RemoveNoOp(), ValidateAllUndefRemoved()},
"tir.RemoveStoreUndef");
}

TVM_REGISTER_GLOBAL("tir.transform.RemoveStoreUndef").set_body_typed(RemoveStoreUndef);
TVM_REGISTER_GLOBAL("tir.transform.RemoveStoreUndef").set_body_typed(RemoveStoreUndef);

} // namespace transform

Expand Down

0 comments on commit 7311263

Please sign in to comment.