Skip to content

Commit

Permalink
[MetaSchedule] Arithmetic analysis (#10403)
Browse files Browse the repository at this point in the history
This PR changes the normal form of the affine detector and supports a single var predicate. It also enhances ModularSet detector to enable floor mod patterns.
  • Loading branch information
spectrometerHBH authored Feb 28, 2022
1 parent 9ca2139 commit 7127296
Show file tree
Hide file tree
Showing 10 changed files with 164 additions and 94 deletions.
175 changes: 100 additions & 75 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,9 @@ class IterMapRewriter : public ExprMutator {
return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
}

IterSumExpr RewriteIterConstraint(const PrimExpr& expr, const PrimExpr& predicate_induced_min,
const PrimExpr& predicate_induced_max) {
IterSumExpr RewriteIterConstraint(const PrimExpr& expr,
const Optional<PrimExpr>& predicate_induced_min,
const Optional<PrimExpr>& predicate_induced_max) {
return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_min,
predicate_induced_max);
}
Expand Down Expand Up @@ -494,16 +495,17 @@ class IterMapRewriter : public ExprMutator {
* \param predicate_induced_max Open upper bound from iter constraint, maybe undefined.
* \return The Normalized expression.
*/
IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, PrimExpr predicate_induced_min,
PrimExpr predicate_induced_max) {
IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, Optional<PrimExpr> predicate_induced_min,
Optional<PrimExpr> predicate_induced_max) {
// normalize to zero base
PrimExpr base = expr->base;
if (!is_zero(base)) {
expr.CopyOnWrite()->base = 0;
if (predicate_induced_min.defined()) predicate_induced_min = predicate_induced_min - base;
if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max - base;
if (predicate_induced_min.defined())
predicate_induced_min = predicate_induced_min.value() - base;
if (predicate_induced_max.defined())
predicate_induced_max = predicate_induced_max.value() - base;
}
if (expr->args.size() < 1) return expr;
Optional<IterSumExpr> opt = TryFuseIters(expr);
ICHECK(!opt.defined() || opt.value()->args.size() == 1);
// scale should be 1
Expand All @@ -522,10 +524,10 @@ class IterMapRewriter : public ExprMutator {
PrimExpr iter_min = mark_offset;
PrimExpr iter_max = iter_min + mark->extent;
if (predicate_induced_min.defined()) {
iter_min = max(predicate_induced_min, iter_min);
iter_min = max(predicate_induced_min.value(), iter_min);
}
if (predicate_induced_max.defined()) {
iter_max = min(predicate_induced_max, iter_max);
iter_max = min(predicate_induced_max.value(), iter_max);
}
if (!is_zero(iter_min)) {
// structured form's offset should be updated
Expand All @@ -536,7 +538,6 @@ class IterMapRewriter : public ExprMutator {
}
mark.CopyOnWrite()->extent = iter_max - iter_min;
sum_fuse_map_[flattened_form] = {mark, iter_min};

// we need to note down the flattened form of constrained iterators
// to check the validity of constraints, see also CheckConstraints()
constrained_iters_flattened_.push_back(flattened_form);
Expand Down Expand Up @@ -771,14 +772,15 @@ class IterMapRewriter : public ExprMutator {
struct IterConstraint {
// The expr of the iter
PrimExpr iter;
// The expr of the lower_bound
PrimExpr lower_bound;
// The expr of the upper_bound
PrimExpr upper_bound;
// The expr of the lower_bound, maybe undefined
Optional<PrimExpr> lower_bound;
// The expr of the upper_bound, maybe undefined
Optional<PrimExpr> upper_bound;
// The size of the iter, which is the number of nodes
size_t expr_size = 0;

IterConstraint(PrimExpr iter, PrimExpr lower_bound, PrimExpr upper_bound, size_t size)
IterConstraint(PrimExpr iter, Optional<PrimExpr> lower_bound, Optional<PrimExpr> upper_bound,
size_t size)
: iter(std::move(iter)),
lower_bound(std::move(lower_bound)),
upper_bound(std::move(upper_bound)),
Expand All @@ -788,11 +790,12 @@ struct IterConstraint {
/*!
* \brief Split the predicate into `(a < b) && (c < d) && ...`
* \param pred The predicate to be split.
* \param input_iters The input iterators.
* \param result The result of predicate split.
* \return A list of IterConstraint, empty if the split failed.
*/
std::vector<IterConstraint> MatchBoundConstraints(PrimExpr pred,
const Map<Var, Range>& input_iters) {
std::vector<IterConstraint> result;
bool MatchBoundConstraints(PrimExpr pred, Map<Var, Range>* input_iters,
std::vector<IterConstraint>* result) {
arith::PVar<PrimExpr> lhs, rhs, rest;
for (;;) {
// try extract comparisions
Expand Down Expand Up @@ -821,78 +824,94 @@ std::vector<IterConstraint> MatchBoundConstraints(PrimExpr pred,
is_equal = true;
is_finish = true;
} else {
return std::vector<IterConstraint>();
return false;
}
PrimExpr lhs_expr = lhs.Eval();
PrimExpr rhs_expr = rhs.Eval();
// we only accept predicate of integers
if (!((lhs_expr->dtype.is_int() || lhs_expr->dtype.is_uint()) &&
(rhs_expr->dtype.is_int() || rhs_expr->dtype.is_uint()))) {
return std::vector<IterConstraint>();
return false;
}
// determine iter and bound, if we can not distinguish them simply,
// try divide (lhs - rhs) into itervar aware and itervar free parts
auto f_use_itervar = [&input_iters](const VarNode* v) {
return input_iters.count(GetRef<Var>(v));
return input_iters->count(GetRef<Var>(v));
};
bool bound_at_left;
if (is_const_int(lhs_expr) || !UsesVar(lhs_expr, f_use_itervar)) {
bound_at_left = true;
} else if (is_const_int(rhs_expr) || !UsesVar(rhs_expr, f_use_itervar)) {
bound_at_left = false;
} else {
bound_at_left = false; // accumulate bound to rhs
PrimExpr sum_parts = lhs_expr - rhs_expr;
lhs_expr = 0;
rhs_expr = 0;
std::function<void(const PrimExpr&, bool)> f_extract =
[&lhs_expr, &rhs_expr, f_use_itervar, &f_extract](const PrimExpr& part, bool sign) {
if (const AddNode* add = part.as<AddNode>()) {
f_extract(add->a, sign);
f_extract(add->b, sign);
} else if (const SubNode* sub = part.as<SubNode>()) {
f_extract(sub->a, sign);
f_extract(sub->b, !sign);
} else if (UsesVar(part, f_use_itervar)) {
lhs_expr = sign ? lhs_expr + part : lhs_expr - part;
} else {
rhs_expr = sign ? rhs_expr - part : rhs_expr + part;
}
};
f_extract(sum_parts, true);
arith::Analyzer analyzer;
lhs_expr = analyzer.Simplify(lhs_expr);
rhs_expr = analyzer.Simplify(rhs_expr);
}
PrimExpr lower_bound, upper_bound, iter;
if (is_greater) {
if (bound_at_left) {
// bound > iter
upper_bound = is_equal ? lhs_expr + 1 : lhs_expr;
iter = rhs_expr;
if (UsesVar(lhs_expr, f_use_itervar) || UsesVar(rhs_expr, f_use_itervar)) {
// At least it uses one input iter
if (is_const_int(lhs_expr) || !UsesVar(lhs_expr, f_use_itervar)) {
bound_at_left = true;
} else if (is_const_int(rhs_expr) || !UsesVar(rhs_expr, f_use_itervar)) {
bound_at_left = false;
} else {
// iter > bound
lower_bound = is_equal ? rhs_expr : rhs_expr + 1;
iter = lhs_expr;
bound_at_left = false; // accumulate bound to rhs
PrimExpr sum_parts = lhs_expr - rhs_expr;
lhs_expr = 0;
rhs_expr = 0;
std::function<void(const PrimExpr&, bool)> f_extract =
[&lhs_expr, &rhs_expr, f_use_itervar, &f_extract](const PrimExpr& part, bool sign) {
if (const AddNode* add = part.as<AddNode>()) {
f_extract(add->a, sign);
f_extract(add->b, sign);
} else if (const SubNode* sub = part.as<SubNode>()) {
f_extract(sub->a, sign);
f_extract(sub->b, !sign);
} else if (UsesVar(part, f_use_itervar)) {
lhs_expr = sign ? lhs_expr + part : lhs_expr - part;
} else {
rhs_expr = sign ? rhs_expr - part : rhs_expr + part;
}
};
f_extract(sum_parts, true);
arith::Analyzer analyzer;
lhs_expr = analyzer.Simplify(lhs_expr);
rhs_expr = analyzer.Simplify(rhs_expr);
}
} else {
if (bound_at_left) {
// bound < iter
lower_bound = is_equal ? lhs_expr : lhs_expr + 1;
iter = rhs_expr;
Optional<PrimExpr> lower_bound = NullOpt, upper_bound = NullOpt;
PrimExpr iter;
if (is_greater) {
if (bound_at_left) {
// bound > iter / bound >= iter
upper_bound = is_equal ? lhs_expr + 1 : lhs_expr;
iter = rhs_expr;
} else {
// iter > bound / iter >= bound
lower_bound = is_equal ? rhs_expr : rhs_expr + 1;
iter = lhs_expr;
}
} else {
// iter < bound
upper_bound = is_equal ? rhs_expr + 1 : rhs_expr;
iter = lhs_expr;
if (bound_at_left) {
// bound < iter / bound <= iter
lower_bound = is_equal ? lhs_expr : lhs_expr + 1;
iter = rhs_expr;
} else {
// iter < bound / iter <= bound
upper_bound = is_equal ? rhs_expr + 1 : rhs_expr;
iter = lhs_expr;
}
}
// If it is a predicate for a single input iter
if (const auto* var_ptr = iter.as<VarNode>()) {
auto it = input_iters->find(GetRef<Var>(var_ptr));
if (it != input_iters->end()) {
PrimExpr iter_min = (*it).second->min;
PrimExpr iter_max = (*it).second->min + (*it).second->extent;
if (lower_bound.defined()) iter_min = max(iter_min, lower_bound.value());
if (upper_bound.defined()) iter_max = min(iter_max, upper_bound.value());
input_iters->Set(GetRef<Var>(var_ptr), Range(iter_min, iter_max));
}
} else {
result->emplace_back(iter, lower_bound, upper_bound, 0);
}
}
result.emplace_back(iter, lower_bound, upper_bound, 0);
if (is_finish) {
break;
}
pred = rest.Eval();
}
return result;
return true;
}

bool IterRangeSanityCheck(const Map<Var, Range>& iter_ranges) {
Expand All @@ -912,13 +931,14 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
// - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
// - Step1: IterIndependenceChecker checks if the iterator are independent.
if (!IterRangeSanityCheck(input_iters)) return Array<IterSumExpr>();
std::vector<IterConstraint> constraints = MatchBoundConstraints(predicate, input_iters);
if (!is_one(predicate) && constraints.empty()) {
Map<Var, Range> constrained_input_iters = input_iters;
std::vector<IterConstraint> constraints;
if (!is_one(predicate) &&
!MatchBoundConstraints(predicate, &constrained_input_iters, &constraints)) {
diag_ctx.Emit(Diagnostic::Error(predicate->span)
<< "Fail to collect constraints from iteration predicate: " << predicate);
return Array<IterSumExpr>();
}

// We have to make sure when we visit an iterator, all the constraints related with its successors
// in the iter var graph has been visited, where the expression of this iterator will contain the
// expression of its successor, so we sort them by their sizes.
Expand All @@ -930,10 +950,11 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
constraints.begin(), constraints.end(),
[](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; });

IterMapRewriter rewriter(analyzer, input_iters, diag_ctx);
IterMapRewriter rewriter(analyzer, constrained_input_iters, diag_ctx);
// Step0.0: rewrite constraints in the order from size-small ones to size-big ones
for (const IterConstraint& constraint : constraints) {
rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound, constraint.upper_bound);
auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound,
constraint.upper_bound);
if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
}
if (!rewriter.CheckConstraints()) {
Expand All @@ -945,7 +966,10 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
Array<IterSumExpr> results;
for (PrimExpr value : indices) {
results.push_back(rewriter.Rewrite(value));
if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
if (rewriter.unresolved_count() != 0) {
diag_ctx.Emit(Diagnostic::Error(predicate->span) << "Affine mapping detection failed");
return Array<IterSumExpr>();
}
}
// Step1: IterIndependenceChecker checks if the iterator are independent.
if (!rewriter.CheckMapping(results, require_bijective)) {
Expand Down Expand Up @@ -1306,7 +1330,8 @@ class IterMapToExprNormalizer : public ExprMutator {
} else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor * expr->extent)) {
return floordiv(source, expr->lower_factor) * expr->scale;
} else {
return floormod(floordiv(source, expr->lower_factor), expr->extent) * expr->scale;
return floordiv(floormod(source, expr->lower_factor * expr->extent), expr->lower_factor) *
expr->scale;
}
}

Expand Down
12 changes: 12 additions & 0 deletions src/arith/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,18 @@ class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(co
return Everything();
}

Entry VisitExpr_(const FloorModNode* op) final {
Entry b = VisitExpr(op->b);
if (b.is_const()) {
int64_t c2 = b.base;
ICHECK(c2 != 0) << "MathError: the divisor is 0";
Entry a = VisitExpr(op->a);
int64_t coeff = ZeroAwareGCD(a.coeff, c2);
return Entry(coeff, a.base % c2);
}
return Everything();
}

Entry VisitExpr_(const MinNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Expand Down
12 changes: 12 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
TVM_TRY_REWRITE(truncdiv(x, c1) * c1 + truncmod(x, c1), x);
// floor div
TVM_TRY_REWRITE(floordiv(x, c1) * c1 + floormod(x, c1), x);
TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2),
c2.Eval()->value > 0);

// canonicalization rule
// will try rewrite again after canonicalization.
Expand Down Expand Up @@ -785,6 +787,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), x * floordiv(c1, c2) + floordiv(y, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), floordiv(x, floordiv(c2, c1)),
c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
c2.Eval()->value % c1.Eval()->value == 0 &&
CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0));

TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

Expand All @@ -794,6 +801,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(y, c2) + x * floordiv(c1, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(x, floordiv(c2, c1)),
c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
c2.Eval()->value % c1.Eval()->value == 0 &&
CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0));

TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), min(floordiv(y, c2), x * floordiv(c1, c2)),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

Expand Down
16 changes: 10 additions & 6 deletions src/tir/schedule/primitive/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
for (int i = 0; i < n; i++) {
const PrimExpr& factor = factors[i];
Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i));
substitute_value = substitute_value * factor + var;
if (!is_one(factor)) substitute_value = substitute_value * factor + var;
analyzer.Bind(var, Range::FromMinExtent(0, factor));
new_loop_vars.emplace_back(std::move(var));
}
Expand Down Expand Up @@ -505,11 +505,14 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix);
Array<PrimExpr> substitute_value;
substitute_value.resize(loops.size());
PrimExpr tot = fused_var;
for (int i = static_cast<int>(loops.size()) - 1; i >= 0; i--) {
substitute_value.Set(i, floormod(tot, loops[i]->extent));
tot = floordiv(tot, loops[i]->extent);
}
PrimExpr lower = 1;
for (int i = static_cast<int>(loops.size()) - 1; i > 0; i--) {
substitute_value.Set(i, is_one(loops[i]->extent)
? 0
: floordiv(floormod(fused_var, lower * loops[i]->extent), lower));
lower = lower * loops[i]->extent;
}
substitute_value.Set(0, is_one(loops[0]->extent) ? 0 : floordiv(fused_var, lower));
Stmt new_stmt = loops.back()->body;
Map<Block, Block> opaque_block_reuse;
auto f_substitute = [&](const Var& v) -> Optional<PrimExpr> {
Expand All @@ -534,6 +537,7 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse);
return self->stmt2ref.at(new_stmt.get());
}

/*!
* \brief Collect an array of loop srefs into a set
* \param self The schedule state
Expand Down
5 changes: 4 additions & 1 deletion tests/python/unittest/test_arith_intset.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ def test_mod():
ck.verify(
flm(y, 8),
{y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)},
(x * 4 - 8 * fld(x * 4, 8), x * 4 - 8 * fld(x * 4, 8) + 3),
(
z * 8 + x * 4 - 8 * fld(z * 8 + x * 4, 8),
z * 8 + x * 4 + 3 - 8 * fld(z * 8 + x * 4, 8),
),
)
ck1 = IntSetChecker()
ck1.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 2))
Expand Down
Loading

0 comments on commit 7127296

Please sign in to comment.