Skip to content

Commit

Permalink
[TensorIR][M2a] Verification of cached flags (#8114)
Browse files Browse the repository at this point in the history
* [TensorIR][M2a] Verification of cached flags

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>

* Address comments

* Update src/tir/schedule/analysis/verify.cc

Co-authored-by: Cody Yu <comaniac0422@gmail.com>

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Cody Yu <comaniac0422@gmail.com>
  • Loading branch information
7 people authored May 25, 2021
1 parent 32608d9 commit 47c8e47
Show file tree
Hide file tree
Showing 24 changed files with 1,615 additions and 76 deletions.
48 changes: 46 additions & 2 deletions include/tvm/arith/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ using tir::IterVar;
using tir::Var;
using tir::VarNode;

class Analyzer;

//-----------------------------------------------
// Integer set data structure.
//
Expand Down Expand Up @@ -190,6 +192,14 @@ IntSet EvalSet(IntSet s, const std::unordered_map<const VarNode*, IntSet>& dom_m
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(Range r, const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes Array<Range>
*
* \param region The range to be evaluated.
* \param dom_map The domain of each variable.
* \return An array of integer sets that can cover all the possible values.
*/
Array<IntSet> EvalSet(const Array<Range>& region, const Map<Var, IntSet>& dom_map);
/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectPtrHash, ObjectPtrEqual>;
/*!
Expand All @@ -204,19 +214,53 @@ ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e,
const std::unordered_map<const VarNode*, IntSet>& dom_map);

/*!
* \brief Create an union set of all sets
* \param sets The sets to be unioned
* \brief Create a union set of all sets, possibly relaxed
* \param sets The sets to be combined
* \return the set after union
*/
IntSet Union(const Array<IntSet>& sets);

/*!
* \brief The union of N-dimensional integer sets
* \param nd_int_sets A list of N-dimensional integer sets
* \return An N-dimensional integer set as the result of union
*/
Array<IntSet> UnionRegion(const Array<Array<IntSet>>& nd_int_sets);

/*!
* \brief Create a lower-bound of union set, where some of the segments may be dropped
* \param sets The sets to be combined
* \return the set after union
*/
IntSet UnionLowerBound(const Array<IntSet>& sets);

/*!
* \brief The union of N-dimensional integer sets
* \param nd_int_sets A list of N-dimensional integer sets
* \return An N-dimensional integer set as the result of union
*/
Array<IntSet> UnionRegionLowerBound(const Array<Array<IntSet>>& nd_int_sets);

/*!
* \brief Create an union set of all sets
* \param sets The sets to be intersected
* \return the set after intersected
*/
IntSet Intersect(const Array<IntSet>& sets);

/*!
* \brief Analyze the region with affine map, given the domain of variables and their predicate
* \param region The region to be analyzed
* \param var_dom The ranges of the variables
* \param predicate The predicate for the affine map
* \param analyzer The analyzer used
* \return NullOpt if the detection fails, or an array of arith::IntSet as the result of analysis
*/
TVM_DLL Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
const Map<Var, Range>& var_dom,
const PrimExpr& predicate,
arith::Analyzer* analyzer);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_INT_SET_H_
4 changes: 3 additions & 1 deletion include/tvm/tir/schedule/block_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,9 @@ class BlockScopeNode : public Object {
* equivalent to of a stage pipeline. Under the following conditions:
*
* 1) The region cover property holds for every of its child blocks
* 2) No write-after-read dependency
* 2) No write-after-read dependency or opaque dependency, only read-after-write and
* write-after-write are allowed
* 3) All the statements in the scope are schedulable statements, i.e. Block and For
*/
bool stage_pipeline{false};

Expand Down
4 changes: 1 addition & 3 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,7 @@ class Schedule : public runtime::ObjectRef {
* \sa ScheduleDebugMask
* \note The checks performed includes:
* 1) VerifySRefTree
* 2) VerifyAffineBinding
* 3) VerifyRegionCover
* 4) VerifyStagePipeline
* 2) VerifyCachedFlags
*/
TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
Expand Down
15 changes: 5 additions & 10 deletions include/tvm/tir/schedule/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,11 @@ struct BlockInfo {
* \brief The bitmask of the debug flag in the ScheduleStateNode.
* \sa ScheduleStateNode
*/
enum class ScheduleDebugMask : int32_t {
enum ScheduleDebugMask : uint32_t {
/*! \brief Verify the correctness of the sref tree */
kVerifySRefTree = 1,
/*! \brief Verify the correctness of affine_binding */
kVerifyAffineBinding = 2,
/*! \brief Verify the correctness of region_cover */
kVerifyRegionCover = 4,
/*! \brief Verify the correctness of stage_pipeline */
kVerifyStagePipeline = 8,
/*! \brief Verify the correctness of affine_binding, region_cover and stage_pipeline */
kVerifyCachedFlags = 2,
};

/*!
Expand Down Expand Up @@ -135,9 +131,8 @@ class ScheduleStateNode : public Object {
/*!
* \brief Trigger the verification according to the `debug_mode` bitmask.
* 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree.
* 2) If the bitmask `kVerifyAffineBinding` is on, verify the correctness of `affine_binding`
* 3) If the bitmask `kVerifyRegionCover` is on, verify the correctness of `region_cover`
* 4) If the bitmask `kVerifyStagePipeline` is on, verify the correctness of `stage_pipeline`
* 2) If the bitmask `kVerifyCachedFlags` is on, verify the correctness of `affine_binding`,
* `region_cover` and `stage_pipeline`
*/
TVM_DLL void DebugVerify() const;

Expand Down
8 changes: 8 additions & 0 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,14 @@ TVM_DLL Stmt Substitute(Stmt stmt, std::function<Optional<PrimExpr>(const Var& v
*/
TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function<Optional<PrimExpr>(const Var& var)> vmap);

/*!
* \brief Substitute the var specified by vmap.
* \param region The object whose vars are to be substituted
* \param vmap The map of new values.
* \return The result.
*/
TVM_DLL Array<Range> Substitute(const Array<Range>& region, const Map<Var, PrimExpr>& vmap);

/*!
* \brief Sugar for substitute via a given map.
* \param input The input to be updated.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
"""Integer bound analysis, simplification and pattern detection."""

from .int_set import IntSet, IntervalSet
from .int_set import IntSet, IntervalSet, estimate_region_lower_bound
from .analyzer import ModularSet, ConstIntBound, Analyzer
from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound
Expand Down
60 changes: 60 additions & 0 deletions python/tvm/arith/int_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,63 @@ class IntervalSet(IntSet):

def __init__(self, min_value, max_value):
self.__init_handle_by_constructor__(_ffi_api.IntervalSet, min_value, max_value)


def estimate_region_lower_bound(region, var_dom, predicate):
"""Analyze the region with affine map, given the domain of variables and their predicate
Parameters
----------
region : List[Range]
The region to be analyzed.
var_dom : Dict[Var, Range]
The ranges of the variables
predicate : PrimExpr
The predicate for the affine map
Returns
----------
region_int_set : Optional[List[IntSet]]
None if the detection fails, or an array of IntSets as the result of analysis
"""
return _ffi_api.EstimateRegionLowerBound(region, var_dom, predicate)


def pos_inf():
"""Returns the symbolic positive infinity
Returns
----------
pos_inf : Var
A symbolic var that indicates positive infinity
"""
return _ffi_api.PosInf()


def neg_inf():
"""Returns the symbolic positive infinity
Returns
----------
neg_inf : Var
A symbolic var that indicates positive infinity
"""
return _ffi_api.NegInf()


def union_lower_bound(sets):
"""Create a lower-bound of union set, where some of the segments may be dropped
Parameters
----------
sets : List[IntSet]
The sets to be combined
Returns
----------
union_lower_bound : List[IntSet]
An N-dimensional integer set, the lower bound of the union
"""
return _ffi_api.UnionLowerBound(sets)
4 changes: 1 addition & 3 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def __init__(
----------
The checks performed includes:
1) VerifySRefTree
2) VerifyAffineBinding
3) VerifyRegionCover
4) VerifyStagePipeline
2) VerifyCachedFlags
"""
if isinstance(debug_mode, bool):
if debug_mode:
Expand Down
45 changes: 36 additions & 9 deletions python/tvm/tir/schedule/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""This file defines ScheduleState, the core data structure of TensorIR scheduling."""
from collections import namedtuple
from enum import IntEnum
from typing import Dict, Optional, Union

Expand All @@ -26,6 +27,8 @@
from . import _ffi_api_schedule
from .block_scope import BlockScope, StmtSRef

CachedFlags = namedtuple("CachedFlags", ["affine_binding", "region_cover", "stage_pipeline"])


class ScheduleDebugMask(IntEnum):
"""The bitmask of the `debug_mode` flag in the ScheduleState class.
Expand All @@ -38,18 +41,12 @@ class ScheduleDebugMask(IntEnum):
----------
VERIFY_SREF_TREE : int = 1
Verify the correctness of the sref tree
VERIFY_AFFINE_BINDING : int = 2
Verify the correctness of affine_binding
VERIFY_REGION_COVER : int = 4
Verify the correctness of region_cover
VERIFY_STAGE_PIPELINE: int = 8
Verify the correctness of stage_pipeline
VERIFY_CACHED_FLAGS : int = 2
Verify the correctness of affine_binding, region_cover and stage_pipeline
"""

VERIFY_SREF_TREE = 1
VERIFY_AFFINE_BINDING = 2
VERIFY_REGION_COVER = 4
VERIFY_STAGE_PIPELINE = 8
VERIFY_CACHED_FLAGS = 2


@register_object("tir.ScheduleState")
Expand Down Expand Up @@ -140,6 +137,36 @@ def get_block_scope(self, block_sref: StmtSRef) -> BlockScope:
self, block_sref
)

def _get_cached_flags(self, block_sref: StmtSRef) -> CachedFlags:
"""Get the cached flags of the corresponding block
Parameters
----------
block_sref : StmtSRef
The block sref to be retrieved
Returns
-------
flags : CachedFlags
Three flags: affine_binding, region_cover, stage_pipeline
Note
-------
It is an API intended for internal testing use.
"""
(
affine_binding,
region_cover,
stage_pipeline,
) = _ffi_api_schedule.ScheduleStateGetCachedFlags( # pylint: disable=no-member
self, block_sref
)
return CachedFlags(
affine_binding=bool(affine_binding.value),
region_cover=bool(region_cover.value),
stage_pipeline=bool(stage_pipeline.value),
)

def replace(
self,
src_sref: StmtSRef,
Expand Down
Loading

0 comments on commit 47c8e47

Please sign in to comment.