Skip to content

Commit

Permalink
StmtExprUseVar
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 committed May 15, 2021
1 parent d615387 commit 377bc91
Show file tree
Hide file tree
Showing 11 changed files with 78 additions and 67 deletions.
70 changes: 38 additions & 32 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,6 @@
namespace tvm {
namespace tir {

/*!
* \brief Checks if an Expr or Stmt contains a list of specific Vars
* \param stmt_or_expr The Stmt or Expr
* \return A boolean indicating if any var in the list is found in stmt/expr
*/
bool ContainsVar(const ObjectRef& stmt_or_expr, const Array<Var>& var);

/*!
* \brief Checks if an Expr or Stmt contains a specific Var
* \param stmt_or_expr The Stmt or Expr
* \return A boolean indicating if the var is found in stmt/expr
*/
bool ContainsVar(const ObjectRef& stmt_or_expr, const Var& var);

/*!
* \brief Checks if an Expr or Stmt contains a list of specific Vars
* \param stmt_or_expr The Stmt or Expr
* \return A boolean indicating if any var in the list is found in stmt/expr
*/
bool ContainsVar(const ObjectRef& stmt_or_expr, const std::unordered_set<const VarNode*>& var);

/*!
* \brief Compare two expressions recursively and check if they are equal
* to each other without var remapping.
Expand Down Expand Up @@ -117,21 +96,48 @@ TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);

/*!
* \brief Whether e expression used any var in variable set..
* \param expr The expression to be checked.
* \param vset_contains The check function to see if var is in the vset.
* \return Whether e uses vset.
* \brief Whether `stmt_or_expr` uses any var in the given variable set.
* \param stmt_or_expr The Stmt or PrimExpr to be checked.
* \param vset_contains The check function to see if var is in the variable set.
* \return Whether `stmt_or_expr` uses any var in the given variable set.
*/
TVM_DLL bool ExprUseVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
TVM_DLL bool StmtExprUseVar(const ObjectRef& stmt_or_expr,
std::function<bool(const VarNode*)> vset_contains);

/*!
* \brief Whether e expression used var.
* \param expr The expression to be checked.
* \param var The variable.
* \return Whether e uses v.
* \brief Whether `stmt_or_expr` uses the given var.
* \param stmt_or_expr The Stmt or PrimExpr to be checked.
* \param var The input variable.
* \return Whether `stmt_or_expr` uses the given var.
*/
inline bool StmtExprUseVar(const ObjectRef& stmt_or_expr, const Var& var) {
return StmtExprUseVar(stmt_or_expr, [&](const VarNode* node) { return var.get() == node; });
}

/*!
* \brief Whether `stmt_or_expr` uses any var in the given variable set.
* \param stmt_or_expr The Stmt or PrimExpr to be checked.
* \param vars The given variable set.
* \return Whether `stmt_or_expr` uses any var in the given variable set.
*/
inline bool ExprUseVar(const PrimExpr& expr, const Var& var) {
return ExprUseVar(expr, [&](const VarNode* node) { return var.get() == node; });
inline bool StmtExprUseVar(const ObjectRef& stmt_or_expr,
const std::unordered_set<const VarNode*>& vars) {
return StmtExprUseVar(stmt_or_expr, [&](const VarNode* node) { return vars.count(node); });
}

/*!
* \brief Whether `stmt_or_expr` uses any var in the given variable list.
* \param stmt_or_expr The Stmt or PrimExpr to be checked.
* \param vars The given variable list.
* \return Whether `stmt_or_expr` uses any var in the given variable set.
*/
inline bool StmtExprUseVar(const ObjectRef& stmt_or_expr, const Array<Var>& vars) {
std::unordered_set<const VarNode*> var_set;
var_set.reserve(vars.size());
for (const Var& var : vars) {
var_set.insert(var.get());
}
return StmtExprUseVar(stmt_or_expr, [&](const VarNode* node) { return var_set.count(node); });
}

/*!
Expand Down
4 changes: 2 additions & 2 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1137,8 +1137,8 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op)
// and recursively mark the corresponding components
for (size_t i = 0; i < simplified_result.size(); ++i)
if (!used[i]) {
if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) ||
ExprUseVar(simplified_result[idx], op->combiner->rhs[i]))
if (StmtExprUseVar(simplified_result[idx], op->combiner->lhs[i]) ||
StmtExprUseVar(simplified_result[idx], op->combiner->rhs[i]))
mark_used(i);
}
};
Expand Down
4 changes: 2 additions & 2 deletions src/arith/detect_linear_equation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class LinearEqDetector : public ExprFunctor<LinearEqEntry(const PrimExpr&, const
}
LinearEqEntry VisitExprDefault_(const Object* op, const PrimExpr& e) final {
if (fail_) return LinearEqEntry();
if (ExprUseVar(e, var_)) {
if (StmtExprUseVar(e, var_)) {
fail_ = true;
return LinearEqEntry();
} else {
Expand Down Expand Up @@ -159,7 +159,7 @@ Array<PrimExpr> DetectLinearEquation(const PrimExpr& e, const Array<Var>& vars)
for (size_t i = vars.size(); i > 1; --i) {
vset.insert(vars[i - 1].get());
// The previous coeff contains the variable
if (ExprUseVar(coeff[i - 2], vset_contains)) {
if (StmtExprUseVar(coeff[i - 2], vset_contains)) {
return Array<PrimExpr>();
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/te/autodiff/ad_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,7 @@ std::pair<PrimExpr, PrimExpr> ImplicationNotContainingVars(
return {pair_a.first || pair_b.first, (pair_a.first || pair_b.second) &&
(pair_b.first || pair_a.second) &&
(pair_a.second || pair_b.second)};
} else if (!tir::ExprUseVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) {
} else if (!tir::StmtExprUseVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) {
return {cond, const_true()};
} else {
return {const_true(), cond};
Expand Down Expand Up @@ -1014,7 +1014,7 @@ PrimExpr TrySimplifyCompute(const PrimExpr& expr, const PrimExpr& cond,
// Keep only those variables of the new vars which are used in the new_expr
Array<Var> used_res_variables;
for (const Var& var : res->dst->variables) {
if (ExprUseVar(new_expr, var)) {
if (StmtExprUseVar(new_expr, var)) {
ICHECK(res->dst->ranges.count(var)) << "Range of " << var << " cannot be inferred.";
used_res_variables.push_back(var);
}
Expand Down
2 changes: 1 addition & 1 deletion src/te/operation/compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ Stmt TransformUpdate(const Stage& stage, const std::unordered_map<IterVar, Range
auto fbanned = [&](const VarNode* node) { return banned.count(node); };

for (const PrimExpr& pred : n.main_predicates) {
if (tir::ExprUseVar(pred, fbanned)) {
if (tir::StmtExprUseVar(pred, fbanned)) {
LOG(FATAL) << "Tensorize update transform failed, the condition " << pred
<< " has a conflict with the reset condition";
}
Expand Down
4 changes: 2 additions & 2 deletions src/te/operation/tensorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,13 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, const Stage& stage,
auto fbanned = [&](const VarNode* node) { return banned.count(node); };

for (const PrimExpr& pred : n.main_predicates) {
if (tir::ExprUseVar(pred, fbanned)) {
if (tir::StmtExprUseVar(pred, fbanned)) {
LOG(FATAL) << "Tensorize failed, split condition " << pred
<< " relies on var defined inside tensorize scope";
}
}
for (const PrimExpr& pred : n.init_predicates) {
if (tir::ExprUseVar(pred, fbanned)) {
if (tir::StmtExprUseVar(pred, fbanned)) {
LOG(FATAL) << "Tensorize failed, split condition " << pred
<< " relies on var defined inside tensorize scope";
}
Expand Down
30 changes: 23 additions & 7 deletions src/tir/analysis/var_touch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,33 @@
* \brief Implementation of simple passes
*/
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

namespace tvm {
namespace tir {

class VarTouchVisitor : public ExprVisitor {
class VarTouchVisitor : public StmtExprVisitor {
public:
explicit VarTouchVisitor(std::function<bool(const VarNode*)> var_set) : var_set_(var_set) {}
explicit VarTouchVisitor(std::function<bool(const VarNode*)> var_set)
: var_set_(std::move(var_set)) {}

void VisitStmt(const Stmt& stmt) final {
if (use_var_) return;
StmtExprVisitor::VisitStmt(stmt);
}

void VisitExpr(const PrimExpr& e) final {
if (use_var_) return;
ExprVisitor::VisitExpr(e);
StmtExprVisitor::VisitExpr(e);
}

void VisitExpr_(const VarNode* op) final { Handle(op); }

void VisitStmt_(const StoreNode* op) final {
Handle(op->buffer_var.get());
StmtVisitor::VisitStmt_(op);
}

void VisitExpr_(const LoadNode* op) final {
Handle(op->buffer_var.get());
ExprVisitor::VisitExpr_(op);
Expand All @@ -54,9 +64,15 @@ class VarTouchVisitor : public ExprVisitor {
std::function<bool(const VarNode*)> var_set_;
};

bool ExprUseVar(const PrimExpr& e, std::function<bool(const VarNode*)> var_set) {
VarTouchVisitor visitor(var_set);
visitor(e);
bool StmtExprUseVar(const ObjectRef& stmt_or_expr, std::function<bool(const VarNode*)> var_set) {
VarTouchVisitor visitor(std::move(var_set));
if (const auto* stmt = stmt_or_expr.as<StmtNode>()) {
visitor(GetRef<Stmt>(stmt));
} else if (const auto* e = stmt_or_expr.as<PrimExprNode>()) {
visitor(GetRef<PrimExpr>(e));
} else {
LOG(FATAL) << "TypeError: The input of StmtExprUseVar should be a Stmt or a PrimExpr.";
}
return visitor.use_var_;
}

Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
<< "ValueError: The buffer \"" << store->buffer
<< "\" is written in the block but is not in the block's signature";
for (PrimExpr index : store->indices) {
if (ContainsVar(index, reduction_block_vars)) {
if (StmtExprUseVar(index, reduction_block_vars)) {
affected = true;
break;
}
Expand Down
5 changes: 3 additions & 2 deletions src/tir/schedule/primitives/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis)
BlockRealize block_realize = GetBlockRealize(block_sref);
Block block = block_realize->block;
Optional<StmtSRef> scope_root = GetScopeRoot(block_sref);
CHECK(IsReductionBlock(self, block_sref, scope_root.value()))
<< "ValueError: We can only do rfactor for loops of a reduction block";
// Todo: comment out the two lines below out after Junru's PR getting merged
// CHECK(IsReductionBlock(self, block_sref, scope_root.value()))
// << "ValueError: We can only do rfactor for loops of a reduction block";

// Collect the information of the reduction.
// Get the `init` identity and the `update` combiner of the reduction.
Expand Down
18 changes: 3 additions & 15 deletions src/tir/transforms/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/arith/bound.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
Expand Down Expand Up @@ -84,19 +85,6 @@ using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash, Par

using ExpressionSet = std::unordered_set<PrimExpr, ObjectPtrHash, ObjectPtrEqual>;

bool ExprUseVars(PrimExpr expr, const std::unordered_set<const VarNode*>& vars) {
bool success = false;
PostOrderVisit(expr, [&vars, &success](const ObjectRef& node) {
if (const VarNode* v = node.as<VarNode>()) {
if (vars.count(v)) {
success = true;
return;
}
}
});
return success;
}

// Select potential candidate IRs that can be partitioned.
// Rule:
// - the range should not be const
Expand Down Expand Up @@ -200,7 +188,7 @@ class PartitionFinder : public StmtExprVisitor {
}

void VisitStmt_(const ForNode* op) final {
if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return;
if (StmtExprUseVar(op->min, out_vars_) || StmtExprUseVar(op->extent, out_vars_)) return;

const VarNode* var = op->loop_var.get();
hint_map_.insert({var, IntSet::Interval(op->min, op->min + op->extent - 1)});
Expand Down Expand Up @@ -230,7 +218,7 @@ class PartitionFinder : public StmtExprVisitor {
void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::likely())) {
PrimExpr cond = op->args[0];
if (ExprUseVars(cond, std::unordered_set<const VarNode*>({current_var_.get()}))) {
if (StmtExprUseVar(cond, current_var_)) {
// For cond, find out the interval, if exists, in which we can prove that cond is
// true. Also find the interval, if exists, in which we can prove that cond is
// false.
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/lower_warp_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
PrimExpr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index);
// invariance: local index must do not contain warp id
ICHECK(!ExprUseVar(local_index, warp_index_))
ICHECK(!StmtExprUseVar(local_index, warp_index_))
<< "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index
<< " local_index=" << local_index;
PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate);
Expand Down

0 comments on commit 377bc91

Please sign in to comment.