From 377bc913e67883d4c0f73d11520822c86485c925 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 15 May 2021 15:42:39 +0800 Subject: [PATCH] StmtExprUseVar --- include/tvm/tir/analysis.h | 70 +++++++++++++----------- src/arith/canonical_simplify.cc | 4 +- src/arith/detect_linear_equation.cc | 4 +- src/te/autodiff/ad_simplify.cc | 4 +- src/te/operation/compute_op.cc | 2 +- src/te/operation/tensorize.cc | 4 +- src/tir/analysis/var_touch.cc | 30 +++++++--- src/tir/schedule/analysis/analysis.cc | 2 +- src/tir/schedule/primitives/reduction.cc | 5 +- src/tir/transforms/loop_partition.cc | 18 +----- src/tir/transforms/lower_warp_memory.cc | 2 +- 11 files changed, 78 insertions(+), 67 deletions(-) diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index ab966362e1..d0cb35e5e9 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -36,27 +36,6 @@ namespace tvm { namespace tir { -/*! - * \brief Checks if an Expr or Stmt contains a list of specific Vars - * \param stmt_or_expr The Stmt or Expr - * \return A boolean indicating if any var in the list is found in stmt/expr - */ -bool ContainsVar(const ObjectRef& stmt_or_expr, const Array& var); - -/*! - * \brief Checks if an Expr or Stmt contains a specific Var - * \param stmt_or_expr The Stmt or Expr - * \return A boolean indicating if the var is found in stmt/expr - */ -bool ContainsVar(const ObjectRef& stmt_or_expr, const Var& var); - -/*! - * \brief Checks if an Expr or Stmt contains a list of specific Vars - * \param stmt_or_expr The Stmt or Expr - * \return A boolean indicating if any var in the list is found in stmt/expr - */ -bool ContainsVar(const ObjectRef& stmt_or_expr, const std::unordered_set& var); - /*! * \brief Compare two expressions recursively and check if they are equal * to each other without var remapping. @@ -117,21 +96,48 @@ TVM_DLL Array UndefinedVars(const PrimExpr& expr); TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr); /*! - * \brief Whether e expression used any var in variable set.. - * \param expr The expression to be checked. - * \param vset_contains The check function to see if var is in the vset. - * \return Whether e uses vset. + * \brief Whether `stmt_or_expr` uses any var in the given variable set. + * \param stmt_or_expr The Stmt or PrimExpr to be checked. + * \param vset_contains The check function to see if var is in the variable set. + * \return Whether `stmt_or_expr` uses any var in the given variable set. */ -TVM_DLL bool ExprUseVar(const PrimExpr& expr, std::function vset_contains); +TVM_DLL bool StmtExprUseVar(const ObjectRef& stmt_or_expr, + std::function vset_contains); /*! - * \brief Whether e expression used var. - * \param expr The expression to be checked. - * \param var The variable. - * \return Whether e uses v. + * \brief Whether `stmt_or_expr` uses the given var. + * \param stmt_or_expr The Stmt or PrimExpr to be checked. + * \param var The input variable. + * \return Whether `stmt_or_expr` uses the given var. + */ +inline bool StmtExprUseVar(const ObjectRef& stmt_or_expr, const Var& var) { + return StmtExprUseVar(stmt_or_expr, [&](const VarNode* node) { return var.get() == node; }); +} + +/*! + * \brief Whether `stmt_or_expr` uses any var in the given variable set. + * \param stmt_or_expr The Stmt or PrimExpr to be checked. + * \param vars The given variable set. + * \return Whether `stmt_or_expr` uses any var in the given variable set. */ -inline bool ExprUseVar(const PrimExpr& expr, const Var& var) { - return ExprUseVar(expr, [&](const VarNode* node) { return var.get() == node; }); +inline bool StmtExprUseVar(const ObjectRef& stmt_or_expr, + const std::unordered_set& vars) { + return StmtExprUseVar(stmt_or_expr, [&](const VarNode* node) { return vars.count(node); }); +} + +/*! + * \brief Whether `stmt_or_expr` uses any var in the given variable list. + * \param stmt_or_expr The Stmt or PrimExpr to be checked. + * \param vars The given variable list. + * \return Whether `stmt_or_expr` uses any var in the given variable set. + */ +inline bool StmtExprUseVar(const ObjectRef& stmt_or_expr, const Array& vars) { + std::unordered_set var_set; + var_set.reserve(vars.size()); + for (const Var& var : vars) { + var_set.insert(var.get()); + } + return StmtExprUseVar(stmt_or_expr, [&](const VarNode* node) { return var_set.count(node); }); } /*! diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index ba549959ac..2871b71372 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -1137,8 +1137,8 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op) // and recursively mark the corresponding components for (size_t i = 0; i < simplified_result.size(); ++i) if (!used[i]) { - if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) || - ExprUseVar(simplified_result[idx], op->combiner->rhs[i])) + if (StmtExprUseVar(simplified_result[idx], op->combiner->lhs[i]) || + StmtExprUseVar(simplified_result[idx], op->combiner->rhs[i])) mark_used(i); } }; diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index f0634feac0..242db57b65 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -108,7 +108,7 @@ class LinearEqDetector : public ExprFunctor DetectLinearEquation(const PrimExpr& e, const Array& vars) for (size_t i = vars.size(); i > 1; --i) { vset.insert(vars[i - 1].get()); // The previous coeff contains the variable - if (ExprUseVar(coeff[i - 2], vset_contains)) { + if (StmtExprUseVar(coeff[i - 2], vset_contains)) { return Array(); } } diff --git a/src/te/autodiff/ad_simplify.cc b/src/te/autodiff/ad_simplify.cc index 76fed053fd..ea62a6c6d1 100644 --- a/src/te/autodiff/ad_simplify.cc +++ b/src/te/autodiff/ad_simplify.cc @@ -834,7 +834,7 @@ std::pair ImplicationNotContainingVars( return {pair_a.first || pair_b.first, (pair_a.first || pair_b.second) && (pair_b.first || pair_a.second) && (pair_a.second || pair_b.second)}; - } else if (!tir::ExprUseVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) { + } else if (!tir::StmtExprUseVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) { return {cond, const_true()}; } else { return {const_true(), cond}; @@ -1014,7 +1014,7 @@ PrimExpr TrySimplifyCompute(const PrimExpr& expr, const PrimExpr& cond, // Keep only those variables of the new vars which are used in the new_expr Array used_res_variables; for (const Var& var : res->dst->variables) { - if (ExprUseVar(new_expr, var)) { + if (StmtExprUseVar(new_expr, var)) { ICHECK(res->dst->ranges.count(var)) << "Range of " << var << " cannot be inferred."; used_res_variables.push_back(var); } diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 9a4eadb356..bef80620b6 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -591,7 +591,7 @@ Stmt TransformUpdate(const Stage& stage, const std::unordered_map -#include #include namespace tvm { namespace tir { -class VarTouchVisitor : public ExprVisitor { +class VarTouchVisitor : public StmtExprVisitor { public: - explicit VarTouchVisitor(std::function var_set) : var_set_(var_set) {} + explicit VarTouchVisitor(std::function var_set) + : var_set_(std::move(var_set)) {} + + void VisitStmt(const Stmt& stmt) final { + if (use_var_) return; + StmtExprVisitor::VisitStmt(stmt); + } void VisitExpr(const PrimExpr& e) final { if (use_var_) return; - ExprVisitor::VisitExpr(e); + StmtExprVisitor::VisitExpr(e); } void VisitExpr_(const VarNode* op) final { Handle(op); } + void VisitStmt_(const StoreNode* op) final { + Handle(op->buffer_var.get()); + StmtVisitor::VisitStmt_(op); + } + void VisitExpr_(const LoadNode* op) final { Handle(op->buffer_var.get()); ExprVisitor::VisitExpr_(op); @@ -54,9 +64,15 @@ class VarTouchVisitor : public ExprVisitor { std::function var_set_; }; -bool ExprUseVar(const PrimExpr& e, std::function var_set) { - VarTouchVisitor visitor(var_set); - visitor(e); +bool StmtExprUseVar(const ObjectRef& stmt_or_expr, std::function var_set) { + VarTouchVisitor visitor(std::move(var_set)); + if (const auto* stmt = stmt_or_expr.as()) { + visitor(GetRef(stmt)); + } else if (const auto* e = stmt_or_expr.as()) { + visitor(GetRef(e)); + } else { + LOG(FATAL) << "TypeError: The input of StmtExprUseVar should be a Stmt or a PrimExpr."; + } return visitor.use_var_; } diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 6d1c46a068..7a42e152ad 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -96,7 +96,7 @@ bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, << "ValueError: The buffer \"" << store->buffer << "\" is written in the block but is not in the block's signature"; for (PrimExpr index : store->indices) { - if (ContainsVar(index, reduction_block_vars)) { + if (StmtExprUseVar(index, reduction_block_vars)) { affected = true; break; } diff --git a/src/tir/schedule/primitives/reduction.cc b/src/tir/schedule/primitives/reduction.cc index f9ab59b3bd..38f0517dad 100644 --- a/src/tir/schedule/primitives/reduction.cc +++ b/src/tir/schedule/primitives/reduction.cc @@ -40,8 +40,9 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis) BlockRealize block_realize = GetBlockRealize(block_sref); Block block = block_realize->block; Optional scope_root = GetScopeRoot(block_sref); - CHECK(IsReductionBlock(self, block_sref, scope_root.value())) - << "ValueError: We can only do rfactor for loops of a reduction block"; +// Todo: comment out the two lines below out after Junru's PR getting merged +// CHECK(IsReductionBlock(self, block_sref, scope_root.value())) +// << "ValueError: We can only do rfactor for loops of a reduction block"; // Collect the information of the reduction. // Get the `init` identity and the `update` combiner of the reduction. diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index f1d816f0ba..7e1d48c63d 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -84,19 +85,6 @@ using Partition = std::unordered_map; -bool ExprUseVars(PrimExpr expr, const std::unordered_set& vars) { - bool success = false; - PostOrderVisit(expr, [&vars, &success](const ObjectRef& node) { - if (const VarNode* v = node.as()) { - if (vars.count(v)) { - success = true; - return; - } - } - }); - return success; -} - // Select potential candidate IRs that can be partitioned. // Rule: // - the range should not be const @@ -200,7 +188,7 @@ class PartitionFinder : public StmtExprVisitor { } void VisitStmt_(const ForNode* op) final { - if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return; + if (StmtExprUseVar(op->min, out_vars_) || StmtExprUseVar(op->extent, out_vars_)) return; const VarNode* var = op->loop_var.get(); hint_map_.insert({var, IntSet::Interval(op->min, op->min + op->extent - 1)}); @@ -230,7 +218,7 @@ class PartitionFinder : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::likely())) { PrimExpr cond = op->args[0]; - if (ExprUseVars(cond, std::unordered_set({current_var_.get()}))) { + if (StmtExprUseVar(cond, current_var_)) { // For cond, find out the interval, if exists, in which we can prove that cond is // true. Also find the interval, if exists, in which we can prove that cond is // false. diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index b95681a936..a5feeb3897 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -250,7 +250,7 @@ class WarpAccessRewriter : protected StmtExprMutator { PrimExpr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); // invariance: local index must do not contain warp id - ICHECK(!ExprUseVar(local_index, warp_index_)) + ICHECK(!StmtExprUseVar(local_index, warp_index_)) << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index << " local_index=" << local_index; PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate);