Skip to content

Commit

Permalink
[Arith] Fix iter_affine_map with non-const extent (apache#7437)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 authored and trevor-m committed Mar 2, 2021
1 parent fbbd3f7 commit 10db633
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 17 deletions.
36 changes: 19 additions & 17 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,8 @@ class IterMapRewriter : public ExprMutator {
return analyzer_->CanProve(floormod(lhs, rhs) == 0);
}

PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs);
PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs);
PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr& orig);
PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs, const PrimExpr& orig);

static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
tir::ExprDeepEqual equal;
Expand Down Expand Up @@ -584,7 +584,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) {
if (a->IsInstance<IterMapExprNode>() && b->IsInstance<IterMapExprNode>()) {
// cannot multiply two iterators, mark as unresolved.
++unresolved_count_;
return Mul(a, b);
return GetRef<PrimExpr>(op);
}

if (!a->IsInstance<IterMapExprNode>()) {
Expand All @@ -603,7 +603,8 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) {
}
}

PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) {
PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs,
const PrimExpr& orig) {
// floordiv(x*scale, rhs)
if (is_one(rhs)) return std::move(lhs);
if (!is_one(lhs->scale)) {
Expand All @@ -619,7 +620,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) {
} else {
// mark as unresolved.
++unresolved_count_;
return floordiv(lhs, rhs);
return orig;
}
}
}
Expand All @@ -641,7 +642,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) {
} else {
// mark as unresolved.
++unresolved_count_;
return floordiv(lhs, rhs);
return orig;
}
}

Expand Down Expand Up @@ -669,25 +670,26 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) {
if (b->IsInstance<IterMapExprNode>()) {
// cannot divide an iterator, mark as unresolved.
++unresolved_count_;
return FloorDiv(a, b);
return GetRef<PrimExpr>(op);
}

if (a->IsInstance<IterSumExprNode>()) {
IterSumExpr ret = Downcast<IterSumExpr>(a);
if (auto opt = TryFuseIters(ret)) {
return SplitFloorDivConst(opt.value(), b);
return SplitFloorDivConst(opt.value(), b, GetRef<PrimExpr>(op));
} else {
++unresolved_count_;
return FloorDiv(a, b);
return GetRef<PrimExpr>(op);
}
} else {
ICHECK(a->IsInstance<IterSplitExprNode>());
IterSplitExpr ret = Downcast<IterSplitExpr>(std::move(a));
return SplitFloorDivConst(ret, b);
return SplitFloorDivConst(ret, b, GetRef<PrimExpr>(op));
}
}

PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) {
PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs,
const PrimExpr& orig) {
// floormod(x*scale, rhs)
if (is_one(rhs)) return make_zero(lhs->dtype);
if (!is_one(lhs->scale)) {
Expand All @@ -701,7 +703,7 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) {
} else {
// mark as unresolved.
++unresolved_count_;
return floormod(lhs, rhs);
return orig;
}
}
}
Expand All @@ -715,7 +717,7 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) {
} else {
// mark as unresolved.
++unresolved_count_;
return floormod(lhs, rhs);
return orig;
}
}

Expand Down Expand Up @@ -743,21 +745,21 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
if (b->IsInstance<IterMapExprNode>()) {
// cannot mod an iterator, mark as unresolved.
++unresolved_count_;
return FloorMod(a, b);
return GetRef<PrimExpr>(op);
}

if (a->IsInstance<IterSumExprNode>()) {
IterSumExpr ret = Downcast<IterSumExpr>(a);
if (auto opt = TryFuseIters(ret)) {
return SplitFloorModConst(opt.value(), b);
return SplitFloorModConst(opt.value(), b, GetRef<PrimExpr>(op));
} else {
++unresolved_count_;
return FloorMod(a, b);
return GetRef<PrimExpr>(op);
}
} else {
ICHECK(a->IsInstance<IterSplitExprNode>());
IterSplitExpr ret = Downcast<IterSplitExpr>(std::move(a));
return SplitFloorModConst(ret, b);
return SplitFloorModConst(ret, b, GetRef<PrimExpr>(op));
}
}

Expand Down
3 changes: 3 additions & 0 deletions tests/python/unittest/test_arith_iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ def test_split():
assert len(res) == 1
assert_iter_sum_pattern(res[0], 8, 0, scale=2)

res = tvm.arith.detect_iter_map([fld(x, flm(flm(y, 8), 6))], var_dom([(x, 24), (y, 8)]))
assert len(res) == 0


def test_compound():
x = tvm.tir.Var("x", "int32"), 10
Expand Down

0 comments on commit 10db633

Please sign in to comment.