Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith] Merge surjective/non-surjective iter mapping detections #11287

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 52 additions & 62 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,53 +259,29 @@ class IterSumExpr : public IterMapExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSumExprNode);
};

/*! \brief Mapping level for iterators. */
enum IterMapLevel {
wrongtest-intellif marked this conversation as resolved.
Show resolved Hide resolved
// Require the mapping to be bijective.
Bijective = 0,
// Require the mapping to be surjective.
Surjective = 1,
// No mapping safety check.
NoCheck = 3
};

/*!
* \brief Detect if indices can be written as
* [y_0 + c_0, y_1 + c_1, ..., y_n + c_n]
*
* Here y = some-quasi-affine-iter-map(input_iters)
* and c are symbolic constants.
*
* We also requires that y_i and y_j to be independent for i != j.
*
* For returned value rv, the following is always true:
* - rv[i]->args.size() <=1: only one iterator per element.
*
* \param indices The indices to detect pattern for.
* \param input_iters Map from variable to iterator's range.
* \param predicate The predicate constraints on the input iterators
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
* \param analyzer Analyzer used to get context information.
* \param simplify_trivial_iterators If true, iterators with extent of
* 1 will be replaced with a constant value.
*
* \return The detected pattern if a match exists,
* otherwise return an empty array.
* \brief Result of DetectIterMap.
*/
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer, bool simplify_trivial_iterators = true);
class IterMapResultNode : public Object {
public:
// The detected pattern if a match exists.
Array<IterSumExpr> indices;

/*! \brief A utility struct for return values from DetectPaddedIterMap
*/
struct PaddedIterMapResult {
// Any errors that occurred while converting the input indices. If
// the array is empty, the conversion was successful.
Array<String> errors;

// The detected pattern if a match exists.
Array<IterSumExpr> indices;

/* \brief Boolean expression indicating if padding was required
*
* `requires_padding` evaluates to true if the returned indices
* contain padding relative to the provided expressions, and false
* otherwise. If `input_iters` contains a variable extent, this
* expression may be in terms of those variables.
*/
PrimExpr requires_padding;

/* \brief Boolean expression indicating if a specific value w
/*! \brief Boolean expression indicating if a specific value w
*
* `padding_predicate` evaluates to true for a set of indices that
* are outside the bounds of the provided index iterators, but
Expand All @@ -314,56 +290,70 @@ struct PaddedIterMapResult {
* `input_iters`.
*/
PrimExpr padding_predicate;

// overrides
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("errors", &errors);
v->Visit("indices", &indices);
v->Visit("padding_predicate", &padding_predicate);
}

static constexpr const char* _type_key = "arith.IterMapResult";
TVM_DECLARE_FINAL_OBJECT_INFO(IterMapResultNode, Object);
};

/*!
* \brief Managed reference to IterMapResultNode.
* \sa IterMapResultNode
*/
class IterMapResult : public ObjectRef {
public:
// constructor
IterMapResult() { data_ = make_object<IterMapResultNode>(); }

/*! \return mutable pointers to the node. */
IterMapResultNode* operator->() const { return static_cast<IterMapResultNode*>(get_mutable()); }
};

/*!
* \brief Detect if indices can be written as
* [y_0 + c_0, y_1 + c_1, ..., y_n + c_n]
*
* Here y = some-quasi-affine-iter-map(input_iters) and c are
* symbolic constants. The y_i iterators may be padded to fit this
* representation.
* Here y = some-quasi-affine-iter-map(input_iters)
* and c are symbolic constants.
*
* We also requires that y_i and y_j to be independent for i != j.
*
* For returned value rv, the following is always true:
* - rv.indices[i]->args.size() <=1: only one iterator per element.
* - rv[i]->args.size() <=1: only one iterator per element.
*
* \param indices The indices to detect pattern for.
*
* \param input_iters Map from variable to iterator's range.
*
* \param predicate The predicate constraints on the input iterators
*
* \param require_bijective A boolean flag that indicates whether the
* mapping should be bijective. If true, no padding may be
* introduced.
*
* \param check_level The iter mapping checking level.
* \param analyzer Analyzer used to get context information.
*
* \param simplify_trivial_iterators If true, iterators with extent of
* 1 will be replaced with a constant value.
*
* \return An instance of PaddedIterMapResult.
* \return The detected iteration result.
* The return object's .indices is empty on failure.
*/
PaddedIterMapResult DetectPaddedIterMap(const Array<PrimExpr>& indices,
const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer,
bool simplify_trivial_iterators = true);
IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, IterMapLevel check_level,
arith::Analyzer* analyzer, bool simplify_trivial_iterators = true);

/*!
* \brief Use IterVarMap detector to rewrite and simplify the indices
*
* \param indices The indices to detect pattern for.
* \param input_iters Map from variable to iterator's range.
* \param input_pred The predicate constraints on the input iterators
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
* \param check_level The iter mapping checking level.
*
* \return The indices after rewrite
*/
Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& input_pred, bool require_bijective);
const PrimExpr& input_pred, IterMapLevel check_level);

/*!
* \brief Apply the inverse of the affine transformation to the outputs.
Expand Down Expand Up @@ -403,7 +393,7 @@ Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
* \param input_iters Map from variable to iterator's range.
* \param sub_iters Iterators of subspace.
* \param predicate The predicate constraints on the input iterators
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
* \param check_level The iter mapping checking level.
* \param analyzer Analyzer used to get context information.
*
* \return The result list has length len(bindings) + 1
Expand All @@ -416,7 +406,7 @@ Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
const Map<Var, Range>& input_iters,
const Array<Var>& sub_iters, const PrimExpr& predicate,
bool require_bijective, arith::Analyzer* analyzer);
IterMapLevel check_level, arith::Analyzer* analyzer);

/*!
* \brief Given an expression that may contain IterMapExpr, transform it to normal PrimExpr.
Expand Down
53 changes: 43 additions & 10 deletions python/tvm/arith/iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
""" Iterator (quasi)affine mapping patterns."""
from enum import IntEnum
import tvm._ffi
from tvm.runtime import Object
from tvm.ir import PrimExpr
Expand Down Expand Up @@ -88,11 +89,35 @@ def __init__(self, args, base):
self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base)


class IterMapLevel(IntEnum):
"""Possible kinds of iter mapping check level."""

Bijective = 0
Surjective = 1
NoCheck = 3

@staticmethod
def from_str(name: str):
"""Helper to create level enum from string"""
if name is None:
return IterMapLevel.NoCheck
name = name.lower()
if name == "bijective":
check_level = IterMapLevel.Bijective
elif name == "surjective":
check_level = IterMapLevel.Surjective
elif name == "nocheck":
check_level = IterMapLevel.NoCheck
else:
raise ValueError(f"Unknown check level {name}")
return check_level


def detect_iter_map(
indices,
input_iters,
predicate=True,
require_bijective=False,
check_level=IterMapLevel.Surjective,
simplify_trivial_iterators=True,
):
"""Detect if indices can be written as mapped iters from input iters
Expand All @@ -108,22 +133,26 @@ def detect_iter_map(
predicate : PrimExpr
The predicate constraints on the input iterators

require_bijective : bool
A boolean flag that indicates whether the mapping should be bijective
check_level : Union[str, IterMapLevel]
Checking level of iteration mapping

simplify_trivial_iterators: bool
If true, iterators with extent of 1 will be replaced with a
constant value.

Returns
-------
results : List[IterSumExpr]
results : IterMapResult
The iter map matching result.
Empty array if no match can be found.
The result's .indices is empty array if no match can be found.

"""
if isinstance(check_level, str):
check_level = IterMapLevel.from_str(check_level)
elif check_level is None:
check_level = IterMapLevel.NoCheck
return _ffi_api.DetectIterMap(
indices, input_iters, predicate, require_bijective, simplify_trivial_iterators
indices, input_iters, predicate, check_level, simplify_trivial_iterators
)


Expand All @@ -143,7 +172,9 @@ def normalize_iter_map_to_expr(expr):
return _ffi_api.NormalizeIterMapToExpr(expr)


def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bijective=False):
def subspace_divide(
bindings, input_iters, sub_iters, predicate=True, check_level=IterMapLevel.Surjective
):
"""Detect if bindings can be written as
[a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n]
where a = some-quasi-affine-iter-map(input_iters set_minus sub_iters)
Expand Down Expand Up @@ -172,8 +203,8 @@ def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bi
predicate : PrimExpr
The predicate constraints on the input iterators

require_bijective : bool
A boolean flag that indicates whether the bindings should be bijective
check_level : Union[str, IterMapLevel]
Checking level of iteration mapping

Returns
-------
Expand All @@ -185,7 +216,9 @@ def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bi
len(bindings): the predicate of outer space and inner space
Empty array if no match can be found.
"""
return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, require_bijective)
if isinstance(check_level, str):
check_level = IterMapLevel.from_str(check_level)
return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, check_level)


def inverse_affine_iter_map(iter_map, outputs):
Expand Down
5 changes: 3 additions & 2 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -867,9 +867,10 @@ Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
for (const Range& range : region) {
affine_indices.push_back(range->min);
}
iter_sum_exprs = DetectIterMap(
auto res = DetectIterMap(
/*indices=*/affine_indices, /*input_iters=*/var_dom,
/*predicate=*/predicate, /*require_bijective=*/false, analyzer);
/*predicate=*/predicate, /*check_level=*/IterMapLevel::Surjective, analyzer);
iter_sum_exprs = res->indices;
}
if (iter_sum_exprs.empty()) {
return NullOpt;
Expand Down
Loading