From 47c8e47a489f71d3e5b8d9991631a9019dda69d5 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 24 May 2021 17:08:58 -0700 Subject: [PATCH] [TensorIR][M2a] Verification of cached flags (#8114) * [TensorIR][M2a] Verification of cached flags Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin * Address comments * Update src/tir/schedule/analysis/verify.cc Co-authored-by: Cody Yu Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Cody Yu --- include/tvm/arith/int_set.h | 48 +- include/tvm/tir/schedule/block_scope.h | 4 +- include/tvm/tir/schedule/schedule.h | 4 +- include/tvm/tir/schedule/state.h | 15 +- include/tvm/tir/stmt_functor.h | 8 + python/tvm/arith/__init__.py | 2 +- python/tvm/arith/int_set.py | 60 ++ python/tvm/tir/schedule/schedule.py | 4 +- python/tvm/tir/schedule/state.py | 45 +- src/arith/int_set.cc | 139 ++++ src/printer/text_printer.h | 2 + src/printer/tvmscript_printer.cc | 12 +- src/tir/ir/stmt_functor.cc | 11 + src/tir/schedule/analysis.h | 44 ++ src/tir/schedule/analysis/analysis.cc | 68 ++ src/tir/schedule/analysis/verify.cc | 97 +++ src/tir/schedule/concrete_schedule.h | 5 +- src/tir/schedule/state.cc | 274 ++++++- src/tir/schedule/utils.h | 60 ++ src/tir/transforms/compact_buffer_region.cc | 11 +- tests/python/unittest/test_arith_intset.py | 106 +++ ...pe.py => test_tir_schedule_block_scope.py} | 0 .../unittest/test_tir_schedule_state.py | 3 +- .../test_tir_schedule_state_cached_flags.py | 669 ++++++++++++++++++ 24 files changed, 1615 insertions(+), 76 deletions(-) rename tests/python/unittest/{test_tir_block_scope.py => test_tir_schedule_block_scope.py} (100%) create mode 100644 tests/python/unittest/test_tir_schedule_state_cached_flags.py diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index 515392db8612..b9e81c0a5533 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -36,6 +36,8 @@ using tir::IterVar; using tir::Var; using tir::VarNode; +class Analyzer; + //----------------------------------------------- // Integer set data structure. // @@ -190,6 +192,14 @@ IntSet EvalSet(IntSet s, const std::unordered_map& dom_m * \return An integer set that can cover all the possible values of e. */ IntSet EvalSet(Range r, const std::unordered_map& dom_map); +/*! + * \brief Same as EvalSet, but takes Array + * + * \param region The range to be evaluated. + * \param dom_map The domain of each variable. + * \return An array of integer sets that can cover all the possible values. + */ +Array EvalSet(const Array& region, const Map& dom_map); /*! \brief Map from Expr to IntSet */ using ExprIntSetMap = std::unordered_map; /*! @@ -204,12 +214,33 @@ ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e, const std::unordered_map& dom_map); /*! - * \brief Create an union set of all sets - * \param sets The sets to be unioned + * \brief Create a union set of all sets, possibly relaxed + * \param sets The sets to be combined * \return the set after union */ IntSet Union(const Array& sets); +/*! + * \brief The union of N-dimensional integer sets + * \param nd_int_sets A list of N-dimensional integer sets + * \return An N-dimensional integer set as the result of union + */ +Array UnionRegion(const Array>& nd_int_sets); + +/*! 
+ * \brief Create a lower-bound of union set, where some of the segments may be dropped + * \param sets The sets to be combined + * \return the set after union + */ +IntSet UnionLowerBound(const Array& sets); + +/*! + * \brief The union of N-dimensional integer sets + * \param nd_int_sets A list of N-dimensional integer sets + * \return An N-dimensional integer set as the result of union + */ +Array UnionRegionLowerBound(const Array>& nd_int_sets); + /*! * \brief Create an union set of all sets * \param sets The sets to be intersected @@ -217,6 +248,19 @@ IntSet Union(const Array& sets); */ IntSet Intersect(const Array& sets); +/*! + * \brief Analyze the region with affine map, given the domain of variables and their predicate + * \param region The region to be analyzed + * \param var_dom The ranges of the variables + * \param predicate The predicate for the affine map + * \param analyzer The analyzer used + * \return NullOpt if the detection fails, or an array of arith::IntSet as the result of analysis + */ +TVM_DLL Optional> EstimateRegionLowerBound(const Array& region, + const Map& var_dom, + const PrimExpr& predicate, + arith::Analyzer* analyzer); + } // namespace arith } // namespace tvm #endif // TVM_ARITH_INT_SET_H_ diff --git a/include/tvm/tir/schedule/block_scope.h b/include/tvm/tir/schedule/block_scope.h index 49d5e7f2c323..fb08583b7771 100644 --- a/include/tvm/tir/schedule/block_scope.h +++ b/include/tvm/tir/schedule/block_scope.h @@ -221,7 +221,9 @@ class BlockScopeNode : public Object { * equivalent to of a stage pipeline. Under the following conditions: * * 1) The region cover property holds for every of its child blocks - * 2) No write-after-read dependency + * 2) No write-after-read dependency or opaque dependency, only read-after-write and + * write-after-write are allowed + * 3) All the statements in the scope are schedulable statements, i.e. Block and For */ bool stage_pipeline{false}; diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index c0af375ad72d..b85fdec8cba9 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -213,9 +213,7 @@ class Schedule : public runtime::ObjectRef { * \sa ScheduleDebugMask * \note The checks performed includes: * 1) VerifySRefTree - * 2) VerifyAffineBinding - * 3) VerifyRegionCover - * 4) VerifyStagePipeline + * 2) VerifyCachedFlags */ TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode); diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h index 12b6fc18dc21..83ac7150543f 100644 --- a/include/tvm/tir/schedule/state.h +++ b/include/tvm/tir/schedule/state.h @@ -64,15 +64,11 @@ struct BlockInfo { * \brief The bitmask of the debug flag in the ScheduleStateNode. * \sa ScheduleStateNode */ -enum class ScheduleDebugMask : int32_t { +enum ScheduleDebugMask : uint32_t { /*! \brief Verify the correctness of the sref tree */ kVerifySRefTree = 1, - /*! \brief Verify the correctness of affine_binding */ - kVerifyAffineBinding = 2, - /*! \brief Verify the correctness of region_cover */ - kVerifyRegionCover = 4, - /*! \brief Verify the correctness of stage_pipeline */ - kVerifyStagePipeline = 8, + /*! \brief Verify the correctness of affine_binding, region_cover and stage_pipeline */ + kVerifyCachedFlags = 2, }; /*! @@ -135,9 +131,8 @@ class ScheduleStateNode : public Object { /*! * \brief Trigger the verification according to the `debug_mode` bitmask. 
* 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree. - * 2) If the bitmask `kVerifyAffineBinding` is on, verify the correctness of `affine_binding` - * 3) If the bitmask `kVerifyRegionCover` is on, verify the correctness of `region_cover` - * 4) If the bitmask `kVerifyStagePipeline` is on, verify the correctness of `stage_pipeline` + * 2) If the bitmask `kVerifyCachedFlags` is on, verify the correctness of `affine_binding`, + * `region_cover` and `stage_pipeline` */ TVM_DLL void DebugVerify() const; diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index c1c618f0c22f..8273f9912a57 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -352,6 +352,14 @@ TVM_DLL Stmt Substitute(Stmt stmt, std::function(const Var& v */ TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function(const Var& var)> vmap); +/*! + * \brief Substitute the var specified by vmap. + * \param region The object whose vars are to be substituted + * \param vmap The map of new values. + * \return The result. + */ +TVM_DLL Array Substitute(const Array& region, const Map& vmap); + /*! * \brief Sugar for substitute via a given map. * \param input The input to be updated. diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index a4cdb9839b22..d1e4431a2e0e 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -16,7 +16,7 @@ # under the License. """Integer bound analysis, simplification and pattern detection.""" -from .int_set import IntSet, IntervalSet +from .int_set import IntSet, IntervalSet, estimate_region_lower_bound from .analyzer import ModularSet, ConstIntBound, Analyzer from .bound import deduce_bound from .pattern import detect_linear_equation, detect_clip_bound diff --git a/python/tvm/arith/int_set.py b/python/tvm/arith/int_set.py index 255dbfda685b..b5f2100b7c7d 100644 --- a/python/tvm/arith/int_set.py +++ b/python/tvm/arith/int_set.py @@ -79,3 +79,63 @@ class IntervalSet(IntSet): def __init__(self, min_value, max_value): self.__init_handle_by_constructor__(_ffi_api.IntervalSet, min_value, max_value) + + +def estimate_region_lower_bound(region, var_dom, predicate): + """Analyze the region with affine map, given the domain of variables and their predicate + + Parameters + ---------- + region : List[Range] + The region to be analyzed. 
+
+    var_dom : Dict[Var, Range]
+        The ranges of the variables
+
+    predicate : PrimExpr
+        The predicate for the affine map
+
+    Returns
+    ----------
+    region_int_set : Optional[List[IntSet]]
+        None if the detection fails, or an array of IntSets as the result of analysis
+    """
+    return _ffi_api.EstimateRegionLowerBound(region, var_dom, predicate)
+
+
+def pos_inf():
+    """Returns the symbolic positive infinity
+
+    Returns
+    ----------
+    pos_inf : Var
+        A symbolic var that indicates positive infinity
+    """
+    return _ffi_api.PosInf()
+
+
+def neg_inf():
+    """Returns the symbolic negative infinity
+
+    Returns
+    ----------
+    neg_inf : Var
+        A symbolic var that indicates negative infinity
+    """
+    return _ffi_api.NegInf()
+
+
+def union_lower_bound(sets):
+    """Create a lower-bound of union set, where some of the segments may be dropped
+
+    Parameters
+    ----------
+    sets : List[IntSet]
+        The sets to be combined
+
+    Returns
+    ----------
+    union_lower_bound : IntSet
+        An interval set as the lower bound of the union
+    """
+    return _ffi_api.UnionLowerBound(sets)
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 0dbc66ae4ac4..f207fa274212 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -76,9 +76,7 @@ def __init__(
         ----------
         The checks performed includes:
         1) VerifySRefTree
-        2) VerifyAffineBinding
-        3) VerifyRegionCover
-        4) VerifyStagePipeline
+        2) VerifyCachedFlags
         """
         if isinstance(debug_mode, bool):
             if debug_mode:
diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py
index 180fede228e5..5a8e6cabe1e8 100644
--- a/python/tvm/tir/schedule/state.py
+++ b/python/tvm/tir/schedule/state.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """This file defines ScheduleState, the core data structure of TensorIR scheduling."""
+from collections import namedtuple
 from enum import IntEnum
 from typing import Dict, Optional, Union
@@ -26,6 +27,8 @@
 from . import _ffi_api_schedule
 from .block_scope import BlockScope, StmtSRef
 
+CachedFlags = namedtuple("CachedFlags", ["affine_binding", "region_cover", "stage_pipeline"])
+
 
 class ScheduleDebugMask(IntEnum):
     """The bitmask of the `debug_mode` flag in the ScheduleState class.
@@ -38,18 +41,12 @@ class ScheduleDebugMask(IntEnum):
     ----------
     VERIFY_SREF_TREE : int = 1
         Verify the correctness of the sref tree
-    VERIFY_AFFINE_BINDING : int = 2
-        Verify the correctness of affine_binding
-    VERIFY_REGION_COVER : int = 4
-        Verify the correctness of region_cover
-    VERIFY_STAGE_PIPELINE: int = 8
-        Verify the correctness of stage_pipeline
+    VERIFY_CACHED_FLAGS : int = 2
+        Verify the correctness of affine_binding, region_cover and stage_pipeline
     """
 
     VERIFY_SREF_TREE = 1
-    VERIFY_AFFINE_BINDING = 2
-    VERIFY_REGION_COVER = 4
-    VERIFY_STAGE_PIPELINE = 8
+    VERIFY_CACHED_FLAGS = 2
 
 
 @register_object("tir.ScheduleState")
@@ -140,6 +137,36 @@ def get_block_scope(self, block_sref: StmtSRef) -> BlockScope:
             self, block_sref
         )
 
+    def _get_cached_flags(self, block_sref: StmtSRef) -> CachedFlags:
+        """Get the cached flags of the corresponding block
+
+        Parameters
+        ----------
+        block_sref : StmtSRef
+            The block sref to be retrieved
+
+        Returns
+        -------
+        flags : CachedFlags
+            Three flags: affine_binding, region_cover, stage_pipeline
+
+        Note
+        -------
+        It is an API intended for internal testing use.
+ """ + ( + affine_binding, + region_cover, + stage_pipeline, + ) = _ffi_api_schedule.ScheduleStateGetCachedFlags( # pylint: disable=no-member + self, block_sref + ) + return CachedFlags( + affine_binding=bool(affine_binding.value), + region_cover=bool(region_cover.value), + stage_pipeline=bool(stage_pipeline.value), + ) + def replace( self, src_sref: StmtSRef, diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 6490f67e1b1a..7000de96dc99 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -22,6 +22,7 @@ * \brief The integer set functions */ #include +#include #include #include #include @@ -635,6 +636,77 @@ IntSet Union(const Array& sets) { return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value)); } +Array UnionRegion(const Array>& nd_int_sets) { + if (nd_int_sets.empty()) { + return {}; + } + int n = nd_int_sets.size(); + int ndim = nd_int_sets[0].size(); + Array result; + result.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + Array candidates; + candidates.reserve(n); + for (int j = 0; j < n; ++j) { + candidates.push_back(nd_int_sets[j][i]); + } + result.push_back(Union(candidates)); + } + return result; +} + +IntSet UnionLowerBound(const Array& sets) { + if (sets.size() == 0) return IntSet::Nothing(); + if (sets.size() == 1) return sets[0]; + Analyzer analyzer; + bool is_first_interval = true; + PrimExpr min_inclusive{nullptr}; + PrimExpr max_inclusive(nullptr); + for (const IntSet& int_set : sets) { + if (const auto* interval_set = int_set.as()) { + PrimExpr new_min_inclusive = interval_set->min_value; + PrimExpr new_max_inclusive = interval_set->max_value; + if (is_first_interval) { + is_first_interval = false; + min_inclusive = std::move(new_min_inclusive); + max_inclusive = std::move(new_max_inclusive); + continue; + } + bool bound_1 = is_neg_inf(new_min_inclusive) || is_pos_inf(max_inclusive) || + analyzer.CanProve(new_min_inclusive <= max_inclusive + 1); + bool bound_2 = is_neg_inf(min_inclusive) || is_pos_inf(new_max_inclusive) || + analyzer.CanProve(min_inclusive <= new_max_inclusive + 1); + if (bound_1 && bound_2) { + min_inclusive = min(min_inclusive, new_min_inclusive); + max_inclusive = max(max_inclusive, new_max_inclusive); + } + } + } + if (is_first_interval) { + return IntSet::Nothing(); + } + return IntSet::Interval(min_inclusive, max_inclusive); +} + +Array UnionRegionLowerBound(const Array>& nd_int_sets) { + if (nd_int_sets.empty()) { + return {}; + } + int n = nd_int_sets.size(); + int ndim = nd_int_sets[0].size(); + Array result; + result.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + Array candidates; + candidates.reserve(n); + for (int j = 0; j < n; ++j) { + candidates.push_back(nd_int_sets[j][i]); + } + result.push_back(UnionLowerBound(candidates)); + } + return result; +} + IntSet Intersect(const Array& sets) { if (sets.size() == 0) return IntSet::Nothing(); if (sets.size() == 1) return sets[0]; @@ -694,6 +766,18 @@ IntSet EvalSet(Range r, const std::unordered_map& dom_ma return EvalSet(r, ConvertDomMap(dom_map)); } +Array EvalSet(const Array& region, const Map& dom_map) { + Analyzer ana; + IntervalSetEvaluator m(&ana, dom_map); + Array result; + result.reserve(region.size()); + for (const Range& r : region) { + PrimExpr sum = r->min + (r->extent - 1); + result.push_back(m.Eval(IntervalSet(r->min, ana.Simplify(sum)))); + } + return result; +} + IntSet EvalSet(IntSet s, const std::unordered_map& dom_map) { Analyzer ana; auto dmap = ConvertDomMap(dom_map); @@ -731,6 +815,50 @@ IntSet EvalSet(Range r, const Map& 
dom_map) { return EvalSet(r, ConvertDomMap(dom_map)); } +Optional> EstimateRegionLowerBound(const Array& region, + const Map& var_dom, + const PrimExpr& predicate, + arith::Analyzer* analyzer) { + int ndim = region.size(); + Array iter_sum_exprs{nullptr}; + { + Array affine_indices; + affine_indices.reserve(ndim); + for (const Range& range : region) { + affine_indices.push_back(range->min); + } + iter_sum_exprs = arith::DetectIterMap( + /*indices=*/affine_indices, /*input_iters=*/var_dom, + /*predicate=*/predicate, /*require_bijective=*/false, analyzer); + } + if (iter_sum_exprs.empty()) { + return NullOpt; + } + ICHECK_EQ(iter_sum_exprs.size(), ndim); + Array result; + result.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + const arith::IterSumExpr& sum_expr = iter_sum_exprs[i]; + const Range& range = region[i]; + if (sum_expr->args.empty()) { + result.push_back(arith::IntSet::Interval(sum_expr->base, sum_expr->base + range->extent)); + continue; + } + ICHECK_EQ(sum_expr->args.size(), 1); + const arith::IterSplitExpr& split = sum_expr->args[0]; + if (!analyzer->CanProve(range->extent >= split->scale)) { + return NullOpt; + } + const PrimExpr& base = sum_expr->base; + // IterSplitExpr: (source // lower_factor) % extent * scale + // where `(source // lower_factor) % extent` is within [0, extent - 1] + // Therefore, the range of `region[i]->min` is `base + [0, (extent - 1) * scale]` + result.push_back(arith::IntSet::Interval( + base, split->extent * split->scale + base + (range->extent - split->scale) - 1)); + } + return result; +} + TVM_REGISTER_NODE_TYPE(IntervalSetNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -754,5 +882,16 @@ TVM_REGISTER_GLOBAL("arith.IntSetIsNothing").set_body_method(&IntSet::IsNothing) TVM_REGISTER_GLOBAL("arith.IntSetIsEverything").set_body_method(&IntSet::IsEverything); +TVM_REGISTER_GLOBAL("arith.EstimateRegionLowerBound") + .set_body_typed([](Array region, Map var_dom, + PrimExpr predicate) -> Optional> { + Analyzer analyzer; + return EstimateRegionLowerBound(region, var_dom, predicate, &analyzer); + }); + +TVM_REGISTER_GLOBAL("arith.PosInf").set_body_typed([]() { return SymbolicLimits::pos_inf_; }); +TVM_REGISTER_GLOBAL("arith.NegInf").set_body_typed([]() { return SymbolicLimits::neg_inf_; }); +TVM_REGISTER_GLOBAL("arith.UnionLowerBound").set_body_typed(UnionLowerBound); + } // namespace arith } // namespace tvm diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index e2a61e1d940f..52ab701008c7 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -359,6 +359,8 @@ class TIRTextPrinter : public StmtFunctor, Doc PrintBody(const Stmt& body, bool indent = true); }; +String AsTVMScript(const ObjectRef& mod, bool show_meta = false); + } // namespace tir } // namespace tvm diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index d5fb3d1e63da..fa92b8f04edc 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1116,12 +1116,12 @@ Doc TVMScriptPrinter::PrintLoopStack() { return res; } -TVM_REGISTER_GLOBAL("script.AsTVMScript") - .set_body_typed([](const ObjectRef& functions, - bool show_meta) { - ICHECK(functions.as() != nullptr || functions.as() != nullptr); - return "@tvm.script.tir\n" + TVMScriptPrinter(show_meta).Print(functions).str() + "\n"; - }); +String AsTVMScript(const ObjectRef& mod, bool show_meta) { + ICHECK(mod->IsInstance() || mod->IsInstance()); + return "@tvm.script.tir\n" + TVMScriptPrinter(show_meta).Print(mod).str() + "\n"; +} + 
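+// A worked instance of the interval arithmetic in `EstimateRegionLowerBound`
+// (registered as `arith.EstimateRegionLowerBound` in int_set.cc above): for a
+// region whose min is `i * 4` with `i` ranging over [0, 64), iter-map detection
+// yields a single split with `scale = 4` and `extent = 64`, so a region extent
+// of 8 gives the bound [0, 64 * 4 + (8 - 4) - 1] = [0, 259]. A minimal Python
+// sketch of that entry point, mirroring the `test_region_lower_bound_small_stride`
+// unit test added at the end of this patch:
+//
+//   i = tvm.tir.Var("i", "int32")
+//   # A lower-bound estimate of the union over i of [i * 4, i * 4 + 8)
+//   (result,) = tvm.arith.estimate_region_lower_bound(
+//       region=[tvm.ir.Range.from_min_extent(min_value=i * 4, extent=8)],
+//       var_dom={i: tvm.ir.Range(begin=0, end=64)},
+//       predicate=tvm.tir.IntImm("bool", 1),
+//   )
+//   assert result.min_value.value == 0
+//   assert result.max_value.value == 259  # 63 * 4 + 8 - 1
+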
+TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 07574e4fb2f1..d60ec72a7589 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -688,6 +688,17 @@ PrimExpr Substitute(PrimExpr expr, std::function(const Var&)> return IRSubstitute(vmap)(std::move(expr)); } +Array Substitute(const Array& region, const Map& vmap) { + Array result; + result.reserve(region.size()); + for (const Range& range : region) { + PrimExpr min = Substitute(range->min, vmap); + PrimExpr extent = Substitute(range->extent, vmap); + result.push_back(Range::FromMinExtent(std::move(min), std::move(extent))); + } + return result; +} + void PreOrderVisit(const ObjectRef& stmt_or_expr, const std::function& fvisit) { class PreOrderVisitor : public StmtExprVisitor { diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index b21139d37e1f..8d52a621b900 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -31,6 +31,50 @@ namespace tir { * \throw An exception will be thrown if the sref tree is not valid */ void VerifySRefTree(const ScheduleState& self); +/*! + * \brief Verify the cached flags in the schedule state, including: + * - affine_binding + * - region_cover + * - stage_pipeline + * \param self The schedule state to be verified + * \throw An exception will be thrown if some srefs are not valid + */ +void VerifyCachedFlags(const ScheduleState& self); + +/******** Binding ********/ + +/*! + * \brief Verify if the block binding in a specific BlockRealize is an affine binding. + * The binding can be represented as an injective affine map from the loop iterators. + * \param realize The BlockRealize to be analyzed + * \param loop_var_ranges The ranges of the loop variables + * \param analyzer The analyzer + * \return A boolean flag indicating if the binding is affine + */ +bool IsAffineBinding(const BlockRealize& realize, const Map& loop_var_ranges, + arith::Analyzer* analyzer); + +/*! + * \brief Extract the ranges of loop variables in a path of the sref tree + * \param low_inclusive The lowest node in the path + * \param high_exclusive The highest node in the path, defaults to the scope root if not specified + * \param extra_relax_scope If the scope is not global, the method will look beyond the limit and + * retrieve extra domains. For example, + * - if the storage scope is warp, it will look upwards for threadIdx.x + * - if the storage scope is shared, it will look for threadIdx.x/y/z + * \return The loop domain + */ +Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, + const Optional& high_exclusive = NullOpt, + const runtime::StorageScope& extra_relax_scope = // + runtime::StorageScope{runtime::StorageRank::kGlobal, ""}); + +/*! + * \brief Returns the block var binding + * \param realize The BlockRealize to be analyzed + * \return The block var binding + */ +Map GetBindings(const BlockRealize& realize); /******** Block-loop relation ********/ /*! 
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 08e7ac749e0f..e4b767bc40ad 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -21,6 +21,74 @@ namespace tvm { namespace tir { +/******** Binding ********/ + +bool IsAffineBinding(const BlockRealize& realize, const Map& loop_var_ranges, + arith::Analyzer* analyzer) { + if (loop_var_ranges.empty()) { + return true; + } + Array results = arith::DetectIterMap( + /*indices=*/realize->iter_values, + /*input_iters=*/loop_var_ranges, + /*predicate=*/realize->predicate, + /*require_bijective=*/false, + /*analyzer=*/analyzer); + if (results.empty()) { + return false; + } + for (const arith::IterSumExpr& sum_expr : results) { + const Array& args = sum_expr->args; + if (!args.empty() && !is_one(args[0]->scale)) { + return false; + } + } + return true; +} + +Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, + const Optional& high_exclusive, + const runtime::StorageScope& extra_relax_scope) { + Map result; + const StmtSRefNode* p = low_inclusive.get(); + const StmtSRefNode* limit = static_cast(high_exclusive.get()); + for (; p != limit; p = p->parent) { + const ForNode* loop = p->StmtAs(); + if (loop == nullptr) { + break; + } + result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + if (extra_relax_scope.rank != runtime::StorageRank::kGlobal) { + for (; p; p = p->parent) { + if (const ForNode* loop = p->StmtAs()) { + if (loop->kind == ForKind::kThreadBinding) { + const String& thread_tag = loop->thread_binding.value()->thread_tag; + if (CanRelaxStorageUndereThread(extra_relax_scope, + runtime::ThreadScope::Create(thread_tag))) { + result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + } + } + } + } + return result; +} + +Map GetBindings(const BlockRealize& realize) { + const BlockNode* block = realize->block.get(); + const Array& all_lhs = block->iter_vars; + const Array& all_rhs = realize->iter_values; + ICHECK_EQ(all_lhs.size(), all_rhs.size()); + Map result; + for (int i = 0, n = all_lhs.size(); i < n; ++i) { + const IterVar& lhs = all_lhs[i]; + const PrimExpr& rhs = all_rhs[i]; + result.Set(lhs->var, rhs); + } + return result; +} + /******** Block-loop relation ********/ Array GetBlocks(const ScheduleState& self, const String& name, const String& func_name) { diff --git a/src/tir/schedule/analysis/verify.cc b/src/tir/schedule/analysis/verify.cc index edb62b54cd1b..e9ee7227f6fb 100644 --- a/src/tir/schedule/analysis/verify.cc +++ b/src/tir/schedule/analysis/verify.cc @@ -142,5 +142,102 @@ class SRefTreeVerifier : public StmtVisitor { void VerifySRefTree(const ScheduleState& self) { SRefTreeVerifier::Verify(self.get()); } +void VerifyCachedFlags(const ScheduleState& self) { + std::vector block_info_not_found; + std::vector> block_info_wrong_affine_binding; + std::vector> block_info_wrong_region_cover; + std::vector> block_info_wrong_stage_pipeline; + + ScheduleState new_state(self->mod); + for (const auto& kv : new_state->stmt2ref) { + const StmtNode* stmt = kv.first; + const StmtSRef& new_sref = kv.second; + if (stmt->IsInstance() || !self->stmt2ref.count(stmt)) { + continue; + } + const BlockInfo& new_block_info = new_state->block_info.at(new_sref); + const StmtSRef& old_sref = self->stmt2ref.at(stmt); + if (!self->block_info.count(old_sref)) { + block_info_not_found.push_back(new_sref); + continue; + } + const BlockInfo& old_block_info = self->block_info.at(old_sref); + if 
(new_block_info.affine_binding != old_block_info.affine_binding) { + block_info_wrong_affine_binding.emplace_back(new_sref, // + new_block_info.affine_binding, + old_block_info.affine_binding); + } + if (new_block_info.region_cover != old_block_info.region_cover) { + block_info_wrong_region_cover.emplace_back(new_sref, // + new_block_info.region_cover, + old_block_info.region_cover); + } + if (new_block_info.scope->stage_pipeline != old_block_info.scope->stage_pipeline) { + block_info_wrong_stage_pipeline.emplace_back(new_sref, // + new_block_info.scope->stage_pipeline, + old_block_info.scope->stage_pipeline); + } + } + + bool has_not_found = !block_info_not_found.empty(); + bool has_wrong_affine_binding = !block_info_wrong_affine_binding.empty(); + bool has_wrong_region_cover = !block_info_wrong_region_cover.empty(); + bool has_wrong_stage_pipeline = !block_info_wrong_stage_pipeline.empty(); + if (!(has_not_found || has_wrong_affine_binding || has_wrong_region_cover || + has_wrong_stage_pipeline)) { + return; + } + std::ostringstream os; + if (has_not_found) { + os << "- BlockInfo not found:"; + for (const StmtSRef& block_sref : block_info_not_found) { + const auto* block = block_sref->StmtAs(); + ICHECK(block); + os << " " << block->name_hint; + } + os << std::endl; + } + if (has_wrong_affine_binding) { + os << "- Wrong affine_binding: "; + for (const std::tuple& record : block_info_wrong_affine_binding) { + const StmtSRef& block_sref = std::get<0>(record); + bool expected = std::get<1>(record); + bool actual = std::get<2>(record); + const auto* block = block_sref->StmtAs(); + ICHECK(block); + os << " (" << block->name_hint << ", expected=" << expected << ", actual=" << actual << ")"; + } + os << std::endl; + } + if (has_wrong_region_cover) { + os << "- Wrong region_cover: "; + for (const std::tuple& record : block_info_wrong_region_cover) { + const StmtSRef& block_sref = std::get<0>(record); + bool expected = std::get<1>(record); + bool actual = std::get<2>(record); + const auto* block = block_sref->StmtAs(); + ICHECK(block); + os << " (" << block->name_hint << ", expected=" << expected << ", actual=" << actual << ")"; + } + os << std::endl; + } + if (has_wrong_stage_pipeline) { + os << "- Wrong stage_pipeline: "; + for (const std::tuple& record : block_info_wrong_stage_pipeline) { + const StmtSRef& block_sref = std::get<0>(record); + bool expected = std::get<1>(record); + bool actual = std::get<2>(record); + const auto* block = block_sref->StmtAs(); + ICHECK(block); + os << " (" << block->name_hint << ", expected=" << expected << ", actual=" << actual << ")"; + } + os << std::endl; + } + LOG(FATAL) << "Schedule verification failed. 
The IR is:\n" + << AsTVMScript(self->mod) << "\nThe errors are:\n" + << os.str(); + throw; +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 947d52fba5bf..39eab1159db9 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -19,9 +19,6 @@ #ifndef TVM_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_ #define TVM_TIR_SCHEDULE_CONCRETE_SCHEDULE_H_ -#include -#include - #include #include @@ -208,7 +205,7 @@ template inline T ConcreteScheduleNode::CreateRV(const StmtSRef& sref) { T rv; this->symbol_table_.Set(rv, sref); - return rv; + return std::move(rv); } inline ExprRV ConcreteScheduleNode::CreateRV(const PrimExpr& expr) { diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index d1b899b05439..ca61dfea2768 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -26,6 +26,87 @@ using SMap = std::unordered_map; /**************** Utility functions ****************/ +/*! + * \brief Analyze the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive) + * Relaxation of the region may be used in upper-bound analysis, i.e. some extra region may be added + * to the result. + * \param region The buffer region to be analyzed + * \param dom_low_inclusive The lowest node in the sref tree path + * \param dom_high_exclusive The highest node in the sref tree path + * \return An n-dimensional integer set + */ +Array AnalyzeRegionUpperBound(const BufferRegion& region, + const StmtSRef& dom_low_inclusive, + const StmtSRef& dom_high_exclusive) { + return arith::EvalSet( + region->region, + AsIntSet(LoopDomainOfSRefTreePath( + /*low_inclusive=*/dom_low_inclusive, + /*high_exclusive=*/dom_high_exclusive, + /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer->scope)))); +} + +/*! + * \brief Analyze the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive) + * Some subregion may be discarded during the lower-bound analysis. + * \param realize The block realize that touches the buffer region + * \param region The buffer region to be analyzed + * \param dom_low_inclusive The lowest node in the sref tree path + * \param dom_high_exclusive The highest node in the sref tree path + * \param analyzer The analyzer + * \return An n-dimensional integer set + */ +Array AnalyzeRegionLowerBound(const BlockRealize& realize, + const BufferRegion& region, + const StmtSRef& dom_low_inclusive, + const StmtSRef& dom_high_exclusive, + arith::Analyzer* analyzer) { + if (Optional> result = EstimateRegionLowerBound( + /*region=*/region->region, + /*var_dom=*/ + LoopDomainOfSRefTreePath( + /*low_inclusive=*/dom_low_inclusive, + /*high_exclusive=*/dom_high_exclusive, + /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer->scope)), + /*predicate=*/realize->predicate, /*analyzer=*/analyzer)) { + return result.value(); + } + return Array(region->buffer->shape.size(), arith::IntSet::Nothing()); +} + +/*! 
+ * \brief Checks if the produced region can cover the consumed region + * \param buffer_shape The shape of the buffer + * \param produced_region The N-dimensional produced region + * \param consumed_region The N-dimensional consumed region + * \param analyzer The analyzer + * \return A boolean indicating if the produced region could cover the consumed region + */ +bool ProducerCoversConsumer(const Array& buffer_shape, + const Array& produced_region, + const Array& consumed_region, + arith::Analyzer* analyzer) { + ICHECK_EQ(buffer_shape.size(), consumed_region.size()); + ICHECK_EQ(produced_region.size(), consumed_region.size()); + int ndim = produced_region.size(); + for (int i = 0; i < ndim; ++i) { + Range buffer_size = Range::FromMinExtent(0, buffer_shape[i]); + if (produced_region[i].IsNothing()) { + return false; + } + Range produced = produced_region[i].CoverRange(buffer_size); + Range consumed = consumed_region[i].CoverRange(buffer_size); + PrimExpr produced_min = produced->min; + PrimExpr produced_max = produced->min + produced->extent; + PrimExpr consumed_min = consumed->min; + PrimExpr consumed_max = consumed->min + consumed->extent; + if (!analyzer->CanProve((produced_min <= consumed_min) && (consumed_max <= produced_max))) { + return false; + } + } + return true; +} + /*! * \brief Set the `StmtSRefNode::seq_index` field for stmt * \param self The schedule class @@ -137,7 +218,7 @@ class StateCreator : private StmtVisitor { private: explicit StateCreator(ScheduleStateNode* self) - : self_(self), srefs_{}, realizes_{}, block_frames_{} { + : self_(self), srefs_{}, block2realize_{}, block_frames_{} { block_frames_.emplace({}); } @@ -170,30 +251,163 @@ class StateCreator : private StmtVisitor { } void MakeBlockInfo(StmtSRef scope_root) { + bool is_root_block = srefs_.empty(); // Calculate `BlockInfo::scope` Array child_block_srefs = std::move(block_frames_.back()); BlockInfo& info = - self_->block_info.emplace(std::move(scope_root), BlockInfo(BlockScope(child_block_srefs))) + self_->block_info.emplace(scope_root, BlockInfo(BlockScope(child_block_srefs))) .first->second; - // TODO(@junrushao1994): calculate the flags // Set `affine_binding` - info.affine_binding = false; - // Set `region_cover` - info.region_cover = false; - // Set `stage_pipeline` - info.scope->stage_pipeline = false; + if (is_root_block) { + info.affine_binding = true; + } else { + info.affine_binding = + IsAffineBinding(/*realize=*/block2realize_.at(scope_root->stmt), + /*loop_var_ranges=*/LoopDomainOfSRefTreePath(srefs_.back()), + /*analyzer=*/&analyzer_); + } + // Set `region_cover` to true, will be updated on its scope block + info.region_cover = true; + // Set `stage_pipeline` and `region_cover` for its intermediate children + info.scope->stage_pipeline = + CheckRegionCoverAndStagePipeline(info, scope_root, child_block_srefs); + } + + bool CheckRegionCoverAndStagePipeline(const BlockInfo& info, const StmtSRef& scope_root, + const Array& child_block_srefs) { + const StmtSRefNode* limit = scope_root->parent; + bool stage_pipeline = true; + // Step 1. Unbind the read/write regions of each child block + std::unordered_map> block_reads_unbound; + std::unordered_map> block_writes_unbound; + block_reads_unbound.reserve(child_block_srefs.size()); + block_writes_unbound.reserve(child_block_srefs.size()); + for (const StmtSRef& block_sref : child_block_srefs) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Map binding = GetBindings(block2realize_.at(block)); + // Step 1.1. 
Unbind read regions + Array reads; + reads.reserve(block->reads.size()); + for (const BufferRegion& region : block->reads) { + reads.push_back(BufferRegion(region->buffer, Substitute(region->region, binding))); + } + block_reads_unbound.emplace(block_sref.get(), std::move(reads)); + // Step 1.2. Unbind write regions + Array writes; + writes.reserve(block->writes.size()); + for (const BufferRegion& region : block->writes) { + writes.push_back(BufferRegion(region->buffer, Substitute(region->region, binding))); + } + block_writes_unbound.emplace(block_sref.get(), std::move(writes)); + } + // Step 2. For each consumer, check the region cover property + for (const auto& kv : info.scope->dst2deps) { + const StmtSRef& consumer_block_sref = kv.first; + const Array& deps = kv.second; + bool& region_cover = self_->block_info.at(consumer_block_sref).region_cover = true; + // Step 2.1. Extract the path to the scope root + std::unordered_map> lca_loc; + for (const StmtSRefNode* p = consumer_block_sref.get(); p != limit; p = p->parent) { + ICHECK(p != nullptr); + lca_loc[p] = {}; + } + // Step 2.2. For each producer, find the LCA of the consumer + for (const Dependency& dep : deps) { + if (dep->kind == DepKind::kWAR || dep->kind == DepKind::kOpaque) { + stage_pipeline = false; + } + // Only care about producer-consumer relationship + if (dep->kind != DepKind::kRAW) { + continue; + } + const StmtSRef& producer = dep->src; + for (const StmtSRefNode* p = producer.get();; p = p->parent) { + ICHECK(p != nullptr); + auto it = lca_loc.find(p); + // Find the first (lowest) position in the ancestor of the consumer, + // which is the LCA by definition + if (it != lca_loc.end()) { + it->second.push_back(producer.get()); + break; + } + } + } + // Step 2.3. For each LCA, gather the produced regions, + // then check if it could cover the consumed region + for (StmtSRef lca = consumer_block_sref; region_cover && lca.get() != limit; + lca = GetRef(lca->parent)) { + const std::vector& producer_block_srefs = lca_loc.at(lca.get()); + // Skip empty LCA positions + if (producer_block_srefs.empty()) { + continue; + } + // For each buffer, record the regions generated under this loop + std::unordered_map>> touched_regions; + // Step 2.3.1. Find all the regions read by the consumer that we care about + for (const BufferRegion& region : block_reads_unbound.at(consumer_block_sref.get())) { + const BufferNode* buffer = region->buffer.get(); + touched_regions[buffer] = {}; + } + // Step 2.3.2. 
Find all the regions written by each producer + for (const StmtSRefNode* producer_block_sref : producer_block_srefs) { + const BlockRealize& producer_realize = block2realize_.at(producer_block_sref->stmt); + StmtSRef parent_sref = GetRef(producer_block_sref->parent); + for (const BufferRegion& region : block_writes_unbound.at(producer_block_sref)) { + const BufferNode* buffer = region->buffer.get(); + auto it = touched_regions.find(buffer); + // Skip the regions that is not read by the consumer + if (it != touched_regions.end()) { + std::vector>& touched_region = it->second; + // The analysis here is trying to be conservation to rule out false positive cases, + // and to make sure region cover property must be satisfied once the flag is on + // Therefore, we use lower-bound analysis for producers and upper-bound analysis for + // consumer, and require that the produced region can cover the consumed region + touched_region.push_back(AnalyzeRegionLowerBound(/*realize=*/producer_realize, + /*region=*/region, + /*dom_low_inclusive=*/parent_sref, + /*dom_high_exclusive=*/lca, + /*analyzer=*/&analyzer_)); + } + } + } + // Step 2.3.3. For each buffer, check the region cover property + { + StmtSRef parent_sref = GetRef(consumer_block_sref->parent); + for (const BufferRegion& region : block_reads_unbound.at(consumer_block_sref.get())) { + const BufferNode* buffer = region->buffer.get(); + const std::vector>& touched_region = touched_regions.at(buffer); + if (!touched_region.empty()) { + Array produced_region = + arith::UnionRegionLowerBound({touched_region.begin(), touched_region.end()}); + Array consumed_region = AnalyzeRegionUpperBound( + /*region=*/region, + /*dom_low_inclusive=*/parent_sref, + /*dom_high_exclusive=*/lca); + if (!ProducerCoversConsumer(buffer->shape, produced_region, consumed_region, + &analyzer_)) { + region_cover = false; + break; + } + } + } + } + } + stage_pipeline = stage_pipeline && region_cover; + } + return stage_pipeline; } void VisitStmt_(const ForNode* loop) final { + analyzer_.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); PushSRef(loop); VisitStmt(loop->body); PopAndRecordSRef(); } void VisitStmt_(const BlockRealizeNode* realize) final { - realizes_.push_back(realize); block_frames_.emplace_back(); const BlockNode* block = realize->block.get(); + block2realize_.emplace(block, GetRef(realize)); // Recursive visit PushSRef(block); VisitStmt(block->body); // `block->init` is not visited @@ -203,7 +417,6 @@ class StateCreator : private StmtVisitor { // Update parent scope block_frames_.pop_back(); block_frames_.back().push_back(sref); - realizes_.pop_back(); } void VisitStmt_(const SeqStmtNode* seq_stmt) final { @@ -216,10 +429,12 @@ class StateCreator : private StmtVisitor { ScheduleStateNode* self_; /*! \brief The stack frame used to indicate the current scope */ std::vector srefs_; - /*! \brief The BlockRealize in the ancestors */ - std::vector realizes_; + /*! \brief The BlockRealize corresponding to blocks */ + std::unordered_map block2realize_; /*! \brief The stack frames of blocks in the DFS visit. */ std::vector> block_frames_; + /*! 
\brief The auxilary analyzer */ + arith::Analyzer analyzer_; }; /**************** Constructor ****************/ @@ -227,7 +442,6 @@ class StateCreator : private StmtVisitor { ScheduleState::ScheduleState(IRModule mod, int debug_mode) { CHECK_GE(debug_mode, -1) << "ValueError: negative `debug_mode` other than -1 is not supported"; data_ = StateCreator::Create(mod, debug_mode); - (*this)->DebugVerify(); } ScheduleState::ScheduleState(PrimFunc func, int debug_mode) @@ -536,6 +750,7 @@ class SRefUpdater : public StmtVisitor { } else { // Insertion didn't take place, because the entry has been there before. // In this case, we assume that flags are still valid so intentionally keep them unchanged + new_info.scope->stage_pipeline = info.scope->stage_pipeline; info.scope = std::move(new_info.scope); } } @@ -806,29 +1021,24 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ new_map->at(g_var) = std::move(ref_new_func); this->mod = GetRef(new_mod); } - constexpr int kVerifySRefTree = static_cast(ScheduleDebugMask::kVerifySRefTree); - if (debug_mode == -1 || (debug_mode & kVerifySRefTree)) { + uint32_t flag = (debug_mode != -1) // + ? static_cast(debug_mode) // + : std::numeric_limits::max(); + if (flag & ScheduleDebugMask::kVerifySRefTree) { VerifySRefTree(GetRef(this)); } } void ScheduleStateNode::DebugVerify() const { - constexpr int kVerifySRefTree = static_cast(ScheduleDebugMask::kVerifySRefTree); - constexpr int kVerifyAffineBinding = static_cast(ScheduleDebugMask::kVerifyAffineBinding); - constexpr int kVerifyRegionCover = static_cast(ScheduleDebugMask::kVerifyRegionCover); - constexpr int kVerifyStagePipeline = static_cast(ScheduleDebugMask::kVerifyStagePipeline); ICHECK_GE(debug_mode, -1); - if (debug_mode == -1 || (debug_mode & kVerifySRefTree)) { + uint32_t flag = (debug_mode != -1) // + ? static_cast(debug_mode) // + : std::numeric_limits::max(); + if (flag & ScheduleDebugMask::kVerifySRefTree) { VerifySRefTree(GetRef(this)); } - if (debug_mode == -1 || (debug_mode & kVerifyAffineBinding)) { - // TODO(@junrushao1994): Verify affine block binding - } - if (debug_mode == -1 || (debug_mode & kVerifyRegionCover)) { - // TODO(@junrushao1994): Verify region cover - } - if (debug_mode == -1 || (debug_mode & kVerifyStagePipeline)) { - // TODO(@junrushao1994): Verify stage pipeline + if (flag & ScheduleDebugMask::kVerifyCachedFlags) { + VerifyCachedFlags(GetRef(this)); } } @@ -843,6 +1053,13 @@ BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const { return it->second; } +TVM_DLL Array GetCachedFlags(const ScheduleState& self, const StmtSRef& block_sref) { + const BlockInfo& info = self->GetBlockInfo(block_sref); + return {Bool(info.affine_binding), // + Bool(info.region_cover), // + Bool(info.scope->stage_pipeline)}; +} + /**************** FFI ****************/ TVM_REGISTER_NODE_TYPE(ScheduleStateNode); @@ -865,6 +1082,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetSRef") auto it = self->stmt2ref.find(stmt.get()); return it != self->stmt2ref.end() ? 
it->second : Optional(NullOpt); }); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetCachedFlags").set_body_typed(GetCachedFlags); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 63ec77dcf312..b72fd8e05706 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -20,12 +20,20 @@ #define TVM_TIR_SCHEDULE_UTILS_H_ #include +#include +#include #include #include #include +#include #include #include +#include +#include + +#include "../../printer/text_printer.h" +#include "../../runtime/thread_storage_scope.h" #include "./analysis.h" namespace tvm { @@ -87,6 +95,58 @@ namespace tir { << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \ << "`, but gets: " << (From.defined() ? From->GetTypeKey() : "None") +/******** Storage scope ********/ + +/*! + * \brief Determine if iterators of a storage scope should be relaxed + * under a specific thread scope + * \param storage_scope The storage scope that the iterators are on + * \param thread_scope The thread scope to be relaxed + * \return A boolean indicating the result + */ +inline bool CanRelaxStorageUndereThread(const runtime::StorageScope& storage_scope, + const runtime::ThreadScope& thread_scope) { + if (storage_scope.rank == runtime::StorageRank::kWarp) { + // for warp memory, we only relax threadIdx.x + return thread_scope.rank == 1 && thread_scope.dim_index == 0; + } + return static_cast(storage_scope.rank) <= static_cast(thread_scope.rank); +} + +/******** Integer set ********/ + +/*! + * \brief Converts the Ranges to IntSets + * \param var_dom The ranges of variables + * \return The integer sets of the variables + */ +inline Map AsIntSet(const Map& var_dom) { + std::unordered_map result; + result.reserve(var_dom.size()); + for (auto kv : var_dom) { + Var& var = kv.first; + Range& range = kv.second; + result.emplace(std::move(var), arith::IntSet::FromRange(std::move(range))); + } + return {result.begin(), result.end()}; +} + +/*! + * \brief Converts an N-dimensional integer set to N-dimensional region + * \param nd_int_set The integer set + * \return The region as the result of conversion + */ +inline Array AsRegion(const Array& nd_int_set, arith::Analyzer* analyzer) { + Array result; + result.reserve(nd_int_set.size()); + for (const arith::IntSet& int_set : nd_int_set) { + PrimExpr min = analyzer->Simplify(int_set.min()); + PrimExpr extent = analyzer->Simplify(int_set.max() - int_set.min() + 1); + result.push_back(Range::FromMinExtent(std::move(min), std::move(extent))); + } + return result; +} + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index a5ca67eaa036..edbafe27cf13 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -29,9 +29,9 @@ #include -#include "../../runtime/thread_storage_scope.h" #include "../../support/arena.h" #include "../../support/utils.h" +#include "../schedule/utils.h" namespace tvm { namespace tir { @@ -280,15 +280,10 @@ class BufferAccessRegionCollector : public StmtExprVisitor { return false; } ICHECK(loop->thread_binding.defined()); - IterVar binding = loop->thread_binding.value(); - runtime::ThreadScope ts = runtime::ThreadScope::Create(binding->thread_tag); - + const String& thread_tag = loop->thread_binding.value()->thread_tag; // When there is warp memory // threadIdx.x must be set to be warp index. 
- if (scope.rank == runtime::StorageRank::kWarp && ts.rank == 1 && ts.dim_index == 0) { - return true; - } - return static_cast(scope.rank) <= ts.rank; + return CanRelaxStorageUndereThread(scope, runtime::ThreadScope::Create(thread_tag)); } /**************** Class members ****************/ diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 5c4cc9491cb5..8f07df56a48b 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -123,6 +123,106 @@ def test_select(): ck.verify(tvm.tir.Select(x > 0, x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (-1, 11)) +def test_region_lower_bound_not_independent(): + i = tvm.tir.Var("i", "int32") + result = tvm.arith.estimate_region_lower_bound( + region=[ + tvm.ir.Range(begin=i, end=i + 2), + tvm.ir.Range(begin=i + 1, end=i + 4), + ], + var_dom={ + i: tvm.ir.Range(begin=0, end=64), + }, + predicate=tvm.tir.IntImm("bool", 1), + ) + assert result is None + + +def test_region_lower_bound_stride_too_wide(): + i = tvm.tir.Var("i", "int32") + result = tvm.arith.estimate_region_lower_bound( + region=[ + tvm.ir.Range(begin=i * 4, end=i * 4 + 2), + ], + var_dom={ + i: tvm.ir.Range(begin=0, end=64), + }, + predicate=tvm.tir.IntImm("bool", 1), + ) + assert result is None + + +def test_region_lower_bound_small_stride(): + i = tvm.tir.Var("i", "int32") + (result,) = tvm.arith.estimate_region_lower_bound( + region=[ + tvm.ir.Range.from_min_extent(min_value=i * 4, extent=8), + ], + var_dom={ + i: tvm.ir.Range(begin=0, end=64), + }, + predicate=tvm.tir.IntImm("bool", 1), + ) + assert result.min_value.value == 0 + assert result.max_value.value == 259 + + +def test_region_lower_bound_split_predicate(): + x_o = tvm.tir.Var("xo", "int32") + x_i = tvm.tir.Var("xi", "int32") + x = x_o * 4 + x_i + (result,) = tvm.arith.estimate_region_lower_bound( + region=[ + tvm.ir.Range.from_min_extent(min_value=x * 4, extent=8), + ], + var_dom={ + x_o: tvm.ir.Range(begin=0, end=16), + x_i: tvm.ir.Range(begin=0, end=4), + }, + predicate=x < 63, + ) + assert result.min_value.value == 0 + assert result.max_value.value == 255 + + +def test_region_lower_bound_multiple_variables(): + div = tvm.tir.floordiv + mod = tvm.tir.floormod + x = tvm.tir.Var("x", "int32") + wid = tvm.tir.Var("wid", "int32") + i = div(x, 16) + j = div(mod(x, 16), 4) * 8 + mod(x, 4) + div(wid, 32) * 4 + k = wid % 32 + (i_int_set, j_int_set, k_int_set) = tvm.arith.estimate_region_lower_bound( + region=[ + tvm.ir.Range.from_min_extent(min_value=i, extent=1), + tvm.ir.Range.from_min_extent(min_value=j, extent=1), + tvm.ir.Range.from_min_extent(min_value=k, extent=1), + ], + var_dom={ + x: tvm.ir.Range(begin=0, end=32), + wid: tvm.ir.Range(begin=0, end=64), + }, + predicate=tvm.tir.IntImm("bool", 1), + ) + assert i_int_set.min_value.value == 0 + assert i_int_set.max_value.value == 1 + assert j_int_set.min_value.value == 0 + assert j_int_set.max_value.value == 31 + assert k_int_set.min_value.value == 0 + assert k_int_set.max_value.value == 31 + + +def test_union_lower_bound(): + neg_inf = tvm.arith.int_set.neg_inf() + pos_inf = tvm.arith.int_set.pos_inf() + set_0 = tvm.arith.IntervalSet(min_value=neg_inf, max_value=0) + set_1 = tvm.arith.IntervalSet(min_value=1, max_value=pos_inf) + result = tvm.arith.int_set.union_lower_bound([set_0, set_1]) + assert result.min_value.same_as(neg_inf) + assert result.max_value.same_as(pos_inf) + + if __name__ == "__main__": test_basic() test_vector() @@ -131,3 +231,9 @@ def test_select(): 
    test_max_min()
    test_select()
    test_mod()
+    test_region_lower_bound_not_independent()
+    test_region_lower_bound_stride_too_wide()
+    test_region_lower_bound_small_stride()
+    test_region_lower_bound_split_predicate()
+    test_region_lower_bound_multiple_variables()
+    test_union_lower_bound()
diff --git a/tests/python/unittest/test_tir_block_scope.py b/tests/python/unittest/test_tir_schedule_block_scope.py
similarity index 100%
rename from tests/python/unittest/test_tir_block_scope.py
rename to tests/python/unittest/test_tir_schedule_block_scope.py
diff --git a/tests/python/unittest/test_tir_schedule_state.py b/tests/python/unittest/test_tir_schedule_state.py
index ac98725ef9f8..34041120f252 100644
--- a/tests/python/unittest/test_tir_schedule_state.py
+++ b/tests/python/unittest/test_tir_schedule_state.py
@@ -189,7 +189,8 @@ def test_replace_partial_copy0():
     sref = s.get_sref(s.mod["main"].body.block.body[0].body)
     other_part_hash = s.mod["main"].body.block.body[1].__hash__()
     s.replace(sref, target)
-    # The stmt is held by `hold_sref`, so it will be coped in copy-on-write because the ref count is not unique
+    # The stmt is held by `hold_sref`, so it will be copied in copy-on-write
+    # because the ref count is not unique
     assert ref_old_hash != s.mod["main"].body.block.body[0].__hash__()
     assert not tvm.ir.structural_equal(hold_ref.body, target)
     # The function and the other part stmt can be directly written
diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
new file mode 100644
index 000000000000..a320812b339f
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
@@ -0,0 +1,669 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
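+
+# The tests in this file construct a ScheduleState with `debug_mode=True` and
+# compare `_get_cached_flags` against hand-derived expectations. A minimal
+# sketch of the pattern, assuming `ScheduleState` takes `debug_mode` as in the
+# C++ constructor (the `_get_block` helper here is illustrative: it locates a
+# block by its name hint and returns the corresponding sref):
+#
+#     def _get_block(s, name_hint):
+#         result = None
+#
+#         def f_visit(node):
+#             nonlocal result
+#             if isinstance(node, tvm.tir.Block) and node.name_hint == name_hint:
+#                 result = node
+#
+#         post_order_visit(s.mod["main"].body, f_visit)
+#         assert result is not None
+#         return s.get_sref(result)
+#
+#     s = tir.ScheduleState(elementwise, debug_mode=True)
+#     assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
+#         affine_binding=True, region_cover=True, stage_pipeline=True
+#     )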
+# pylint: disable=missing-function-docstring,missing-module-docstring
+
+import tvm
+from tvm import tir
+from tvm.script import ty
+from tvm.tir.schedule.state import CachedFlags
+from tvm.tir.stmt_functor import post_order_visit
+
+# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+
+
+@tvm.script.tir
+def elementwise(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128), "float32")
+    C = tir.match_buffer(c, (128, 128), "float32")
+    B = tir.alloc_buffer((128, 128), "float32")
+    with tir.block([128, 128], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * 2.0
+    with tir.block([128, 128], "C") as [vi, vj]:
+        C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, [128, 128])
+    B = tir.match_buffer(b, [128, 128])
+    C = tir.match_buffer(c, [128, 128])
+    for i, j in tir.grid(128, 128):
+        with tir.block([128, 128], "init") as [vi, vj]:
+            C[vi, vj] = 0.0
+        for k in range(0, 128):
+            with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+def block_in_opaque_block(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128), "float32")
+    B = tir.match_buffer(b, (128, 128), "float32")
+    with tir.block([128], "B") as vi:
+        tir.reads([A[0:128, 0:128]])
+        tir.writes([B[0:128, 0:128]])
+        B[vi, 0] = A[vi, 0]
+        if A[vi, 0] == 0.0:
+            with tir.block([], "C"):
+                tir.reads([A[0:128, 0:128]])
+                tir.writes([B[0:128, 0:128]])
+                with tir.block([128], "D") as vj:
+                    B[vi, vj] = A[vi, vj] * 3.0
+        else:
+            with tir.block([], "E"):
+                tir.reads([A[0:128, 0:128]])
+                tir.writes([B[0:128, 0:128]])
+                with tir.block([128], "F") as vj:
+                    B[vi, vj] = A[vi, vj] * 2.0
+
+
+# Block "C" consumes B before block "B" produces it: a write-after-read dependency.
+@tvm.script.tir
+def write_after_read(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128))
+    B = tir.match_buffer(b, (128, 128))
+    C = tir.match_buffer(c, (128, 128))
+    with tir.block([128, 128], "C") as [vi, vj]:
+        C[vi, vj] = B[vi, vj] + 1.0
+    with tir.block([128, 128], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * 2.0
+
+
+# "C" at iteration vi reads B[vi - 1], which is produced only by an earlier iteration of i.
+@tvm.script.tir
+def loop_carried_dependency(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (128,))
+    B = tir.match_buffer(b, (128,))
+    C = tir.match_buffer(c, (128,))
+    for i in range(0, 128):
+        with tir.block([128], "B") as vi:
+            B[vi] = A[vi] * 2.0
+        with tir.block([128], "C") as vi:
+            C[vi] = tir.if_then_else(vi >= 1, B[vi - 1] + 1.0, 0.0, dtype="float32")
+
+
+# "A_0" and "A_1" together cover A[0:128], so the consumer "B" is fully covered.
+@tvm.script.tir
+def concatenate_multi_producer(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128,))
+    B = tir.match_buffer(b, (128,))
+    for i in range(0, 64):
+        with tir.block([64], "A_0") as vi:
+            A[vi] = vi + 1
+    for i in range(0, 64):
+        with tir.block([64], "A_1") as vi:
+            tir.bind(vi, i + 64)
+            A[vi] = vi + 2
+    with tir.block([128], "B") as vi:
+        B[vi] = A[vi] * 2.0
+
+
+# Here "A_0" covers only A[0:63], so A[63] is never produced.
+@tvm.script.tir
+def concatenate_multi_producer_uncovered(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128,))
+    B = tir.match_buffer(b, (128,))
+    for i in range(0, 63):
+        with tir.block([63], "A_0") as vi:
+            A[vi] = vi + 1
+    for i in range(0, 64):
+        with tir.block([64], "A_1") as vi:
+            tir.bind(vi, i + 64)
+            A[vi] = vi + 2
+    with tir.block([128], "B") as vi:
+        B[vi] = A[vi] * 2.0
+
+
+@tvm.script.tir
+def lca_at_loop(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (128,))
+    B = tir.match_buffer(b, (128,))
+    C = tir.match_buffer(c, (128,))
+    for i in range(0, 128):
+        with tir.block([128], "B") as vi:
+            B[vi] = A[vi] * 2.0
+        with tir.block([128], "C") as vi:
+            C[vi] = B[vi] + 1.0
+
+
+@tvm.script.tir
+def multi_producer_consumer(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128,))
+    B = tir.match_buffer(b, (128,))
+    for i in range(0, 64):
+        with tir.block([64], "A_0") as vi:
+            A[vi] = vi + 1
+    for i in range(0, 64):
+        with tir.block([64], "A_1") as vi:
+            tir.bind(vi, i + 64)
+            A[vi] = vi + 2
+    for i in range(0, 64):
+        with tir.block([64], "B_0") as vi:
+            B[vi] = A[vi] + 2.0
+    for i in range(0, 64):
+        with tir.block([64], "B_1") as vi:
+            tir.bind(vi, i + 64)
+            B[vi] = A[vi] + 3.0
+
+
+@tvm.script.tir
+def elementwise_affine_producer(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128), "float32")
+    C = tir.match_buffer(c, (128, 128), "float32")
+    B = tir.alloc_buffer((128, 128), "float32")
+    for i, j, k, l in tir.grid(16, 2, 32, 16):
+        with tir.block([128, 128], "B") as [vi, vj]:
+            tir.bind(vi, i * 8 + j * 4 + k // 8)
+            tir.bind(vj, k % 8 * 16 + l)
+            B[vi, vj] = A[vi, vj] * 2.0
+    with tir.block([128, 128], "C") as [vi, vj]:
+        C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def elementwise_subblock(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128), "float32")
+    C = tir.match_buffer(c, (128, 128), "float32")
+    B = tir.alloc_buffer((128, 128), "float32")
+    with tir.block([32, 32], "B") as [vi, vj]:
+        tir.reads([A[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]])
+        tir.writes([B[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]])
+        with tir.block([4, 4], "B_sub") as [vi_i, vj_i]:
+            B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0
+    with tir.block([128, 128], "C") as [vi, vj]:
+        C[vi, vj] = B[vi, vj] + 1.0
+
+
+# The outer block claims only a 2x2 corner of each 4x4 tile, so B is not fully produced.
+@tvm.script.tir
+def elementwise_subblock_uncovered(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128), "float32")
+    C = tir.match_buffer(c, (128, 128), "float32")
+    B = tir.alloc_buffer((128, 128), "float32")
+    with tir.block([32, 32], "B") as [vi, vj]:
+        tir.reads([A[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]])
+        tir.writes([B[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]])
+        with tir.block([2, 2], "B_sub") as [vi_i, vj_i]:
+            B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0
+    with tir.block([128, 128], "C") as [vi, vj]:
+        C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def bound_to_thread(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, [128, 128])
+    C = tir.match_buffer(c, [128, 128])
+    B = tir.alloc_buffer([128, 128], scope="shared")
+    for i in tir.thread_binding(0, 128, thread="threadIdx.x"):
+        for j in tir.serial(0, 128):
+            with tir.block([128, 128], "B") as [vi, vj]:
+                B[vi, vj] = A[vi, vj] * 2.0
+        for j in tir.serial(0, 128):
+            with tir.block([128, 128], "C") as [vi, vj]:
+                C[vj, vi] = B[vj, vi] + 1.0
+
+
+@tvm.script.tir
+def equal_ranked_threads(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, [128, 128])
+    C = tir.match_buffer(c, [128, 128])
+    B = tir.alloc_buffer([128, 128], scope="shared")
+    for i_o in tir.thread_binding(0, 16, thread="threadIdx.x"):
+        for i_i in tir.thread_binding(0, 8, thread="threadIdx.y"):
+            for j in tir.serial(0, 128):
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    tir.bind(vi, i_o * 8 + i_i)
+                    tir.bind(vj, j)
+                    B[vi, vj] = A[vi, vj] * 2.0
+            for j in tir.serial(0, 128):
+                with tir.block([128, 128], "C") as [vi, vj]:
+                    tir.bind(vi, i_o * 8 + i_i)
+                    tir.bind(vj, j)
+                    C[vj, vi] = B[vj, vi] + 1.0
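The two fixtures below move B into warp memory, so the cached flags hinge on every access staying within the warp that produced the data. When reading them, it can help to look at the elaborated TIR rather than the hybrid-script sugar; a quick inspection sketch, assuming the `tvm.script.asscript` printer available in this TVM version:

    # Print the fully elaborated TIR of a fixture, loops and block bindings included.
    print(tvm.script.asscript(warp_memory))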
+
+
+@tvm.script.tir
+def warp_memory(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, [128, 128])
+    C = tir.match_buffer(c, [128, 128])
+    B = tir.alloc_buffer([128, 4, 32], scope="warp")
+    for i_o in tir.thread_binding(0, 4, thread="threadIdx.y"):
+        for i_i in tir.thread_binding(0, 32, thread="threadIdx.x"):
+            for j in tir.serial(0, 128):
+                with tir.block([4, 32, 128], "B") as [warp_id, lane_id, vj]:
+                    B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0
+            for j in tir.serial(0, 128):
+                with tir.block([4, 32, 128], "C") as [warp_id, lane_id, vj]:
+                    C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0
+
+
+# "C" indexes the warp buffer with i_o_prime rather than the threadIdx.y that
+# produced it, so the read escapes the producing warp.
+@tvm.script.tir
+def warp_memory_negative(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, [128, 128])
+    C = tir.match_buffer(c, [128, 128])
+    B = tir.alloc_buffer([128, 4, 32], scope="warp")
+    for i_o in tir.thread_binding(0, 4, thread="threadIdx.y"):
+        for i_i in tir.thread_binding(0, 32, thread="threadIdx.x"):
+            for j in tir.serial(0, 128):
+                with tir.block([4, 32, 128], "B") as [warp_id, lane_id, vj]:
+                    B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0
+            for i_o_prime in tir.thread_binding(0, 4, thread="threadIdx.y"):
+                for j in tir.serial(0, 128):
+                    with tir.block([4, 32, 4, 128], "C") as [_warp_id, lane_id, warp_id, vj]:
+                        C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0
+
+
+# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+
+
+def _get_block(s: tir.ScheduleState, name_hint: str) -> tir.StmtSRef:
+    """Find the block whose name hint matches and return its sref."""
+    result = None
+
+    def f_visit(node):
+        nonlocal result
+        if isinstance(node, tvm.tir.Block) and node.name_hint == name_hint:
+            result = node
+
+    func = s.mod["main"]
+    post_order_visit(func.body, f_visit)
+    assert result is not None and isinstance(result, tvm.tir.Block)
+    return s.get_sref(result)
+
+
+def test_elementwise():
+    s = tir.ScheduleState(elementwise, debug_mode=True)
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    # pylint: enable=protected-access
+
+
+def test_matmul():
+    s = tir.ScheduleState(matmul, debug_mode=True)
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "init")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "update")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    # pylint: enable=protected-access
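`_get_block` above is the standard `post_order_visit` idiom: visit every node, keep the one that matches. The same pattern generalizes to any TIR node type; a minimal sketch (the `collect_buffer_stores` helper is hypothetical, not part of the patch):

    def collect_buffer_stores(func):
        # Gather every BufferStore in the PrimFunc body, in post order.
        stores = []

        def f_visit(node):
            if isinstance(node, tvm.tir.BufferStore):
                stores.append(node)

        post_order_visit(func.body, f_visit)
        return stores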
+
+
+def test_block_in_opaque_block():
+    s = tir.ScheduleState(block_in_opaque_block, debug_mode=True)
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "E")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "F")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    # pylint: enable=protected-access
+
+
+def test_write_after_read():
+    s = tir.ScheduleState(write_after_read, debug_mode=True)
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=False,  # WAR dependency between "C" and "B"
+    )
+    # pylint: enable=protected-access
+
+
+def test_loop_carried_dependency():
+    s = tir.ScheduleState(loop_carried_dependency, debug_mode=True)
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags(
+        affine_binding=True,
+        region_cover=False,  # B[vi - 1] is produced only by an earlier iteration
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=False,  # region cover fails for a child block
+    )
+    # pylint: enable=protected-access
+
+
+def test_concatenate_multi_producer_covered():  # pylint: disable=invalid-name
+    s = tir.ScheduleState(concatenate_multi_producer, debug_mode=True)
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "A_0")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "A_1")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    # pylint: enable=protected-access
+
+
+def test_concatenate_multi_producer_uncovered():  # pylint: disable=invalid-name
+    s = tir.ScheduleState(concatenate_multi_producer_uncovered, debug_mode=True)
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "A_0")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "A_1")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
+        affine_binding=True,
+        region_cover=False,  # A[63] is never produced
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=False,
+    )
+    # pylint: enable=protected-access
+
+
+def test_lca_at_loop():
+    s = tir.ScheduleState(lca_at_loop, debug_mode=True)
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    # pylint: enable=protected-access
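A note on `debug_mode=True`: it does not require the flags to be True. The verification recomputes affine_binding, region_cover, and stage_pipeline from scratch and checks that the cached values agree, which is why the fixtures above with deliberately False flags still construct successfully. Assuming `debug_mode` also accepts the raw bitmask as an int (as the Python-side signature suggests), individual verifications can be selected:

    # Hypothetical selective check: verify only the sref tree,
    # skipping the cached-flag verification.
    s = tir.ScheduleState(elementwise, debug_mode=1)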
+
+
+def test_multi_producer_consumer():
+    s = tir.ScheduleState(multi_producer_consumer, debug_mode=True)
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "A_0")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "A_1")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "B_0")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "B_1")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    # pylint: enable=protected-access
+
+
+def test_elementwise_affine_producer():
+    s = tir.ScheduleState(elementwise_affine_producer, debug_mode=True)
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    # pylint: enable=protected-access
+
+
+def test_subblock():
+    s = tir.ScheduleState(elementwise_subblock, debug_mode=True)
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "B_sub")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    # pylint: enable=protected-access
+
+
+def test_subblock_uncovered():
+    s = tir.ScheduleState(elementwise_subblock_uncovered, debug_mode=True)
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=False,  # "C" is not fully covered by the producer of B
+    )
+    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "B_sub")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags(
+        affine_binding=True,
+        region_cover=False,  # only a 2x2 corner of each 4x4 tile of B is written
+        stage_pipeline=True,
+    )
+    # pylint: enable=protected-access
+
+
+def test_thread_binding():
+    s = tir.ScheduleState(bound_to_thread, debug_mode=True)
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    # pylint: enable=protected-access
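`CachedFlags` compares componentwise, which keeps the asserts above compact. Assuming it is the plain named tuple that the Python side introduces, single flags can also be read off directly when only one property matters:

    # pylint: disable=protected-access
    s = tir.ScheduleState(bound_to_thread, debug_mode=True)
    flags = s._get_cached_flags(_get_block(s, "B"))
    assert flags.affine_binding and flags.region_cover and flags.stage_pipeline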
+
+
+def test_equal_ranked_threads():
+    s = tir.ScheduleState(equal_ranked_threads, debug_mode=True)
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    # pylint: enable=protected-access
+
+
+def test_warp_memory():
+    s = tir.ScheduleState(warp_memory, debug_mode=True)
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    # pylint: enable=protected-access
+
+
+def test_warp_memory_negative():
+    s = tir.ScheduleState(warp_memory_negative, debug_mode=True)
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "root")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=False,
+    )
+    assert s._get_cached_flags(_get_block(s, "B")) == CachedFlags(
+        affine_binding=True,
+        region_cover=True,
+        stage_pipeline=True,
+    )
+    assert s._get_cached_flags(_get_block(s, "C")) == CachedFlags(
+        affine_binding=True,
+        region_cover=False,  # the read of B escapes the producing warp
+        stage_pipeline=True,
+    )
+    # pylint: enable=protected-access
+
+
+if __name__ == "__main__":
+    test_elementwise()
+    test_matmul()
+    test_block_in_opaque_block()
+    test_write_after_read()
+    test_loop_carried_dependency()
+    test_concatenate_multi_producer_covered()
+    test_concatenate_multi_producer_uncovered()
+    test_lca_at_loop()
+    test_multi_producer_consumer()
+    test_elementwise_affine_producer()
+    test_subblock()
+    test_subblock_uncovered()
+    test_thread_binding()
+    test_equal_ranked_threads()
+    test_warp_memory()
+    test_warp_memory_negative()
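The explicit call list above has to be maintained by hand as tests are added. An equivalent entry point, assuming pytest is available (as it is elsewhere in the TVM unit-test suite):

    if __name__ == "__main__":
        import sys

        import pytest

        # Run every test in this file, forwarding any extra CLI arguments.
        sys.exit(pytest.main([__file__] + sys.argv[1:]))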