Skip to content

Commit

Permalink
[TIR] Simplify chains of AndNode, chains of OrNode
Browse files Browse the repository at this point in the history
Previously, application of pairwise simplification rules applied to a
nested call to `And(And(a, b), c)` would attempt to simplify `a` and
`b`, but would not attempt to simplify `a` and `c`.  As a result,
while `(i==j && i!=j)` would simplify to `false`, `(i==j && j==5) &&
i!=j` would not simplify.

After this PR, all conditions that contribute to an `And` are checked
for pairwise simplifications.
  • Loading branch information
Lunderberg committed Sep 15, 2022
1 parent 7c531c8 commit 6d9216a
Show file tree
Hide file tree
Showing 4 changed files with 302 additions and 86 deletions.
28 changes: 0 additions & 28 deletions src/arith/ir_mutator_with_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,33 +230,5 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const ReduceNode* op) {
return StmtExprMutator::VisitExpr_(op);
}

PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const AndNode* op) {
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b;
{
With<ConstraintContext> constraint(analyzer_, a);
b = this->VisitExpr(op->b);
}
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
return And(a, b);
}
}

PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const OrNode* op) {
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b;
{
With<ConstraintContext> constraint(analyzer_, Not(a));
b = this->VisitExpr(op->b);
}
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
return Or(a, b);
}
}

} // namespace arith
} // namespace tvm
3 changes: 0 additions & 3 deletions src/arith/ir_mutator_with_analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
PrimExpr VisitExpr_(const tir::CallNode* op) override;
PrimExpr VisitExpr_(const tir::ReduceNode* op) override;

PrimExpr VisitExpr_(const tir::AndNode* op) override;
PrimExpr VisitExpr_(const tir::OrNode* op) override;

protected:
/*! \brief internal analyzer field. */
Analyzer* analyzer_;
Expand Down
243 changes: 188 additions & 55 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1516,82 +1516,215 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) {
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<AndNode>();
if (auto const_res = TryConstFold<And>(op->a, op->b)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
std::vector<PrimExpr> subexprs = ExtractConstraints(GetRef<PrimExpr>(op), false);
ICHECK_GE(subexprs.size(), 2);

// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVar<int> lanes;
bool modified = false;

if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes));
// Simplify each of the subexpressions under the assumption that all
// other subexpressions are true.
for (size_t i = 0; i < subexprs.size(); i++) {
std::vector<With<ConstraintContext>> context;
context.reserve(subexprs.size() - 1);

PrimExpr remainder = Bool(true);
for (size_t j = 0; j < subexprs.size(); j++) {
if (i != j) {
context.emplace_back(analyzer_, subexprs[j]);
}
}

PrimExpr simplified = VisitExpr(subexprs[i]);
if (!simplified.same_as(subexprs[i])) {
modified = true;
}
subexprs[i] = simplified;

while (context.size()) {
context.pop_back();
}
}

auto cfalse = PConst<PrimExpr>(make_const(op->dtype, false));
TVM_TRY_REWRITE(x == y && x != y, cfalse);
TVM_TRY_REWRITE(x != y && x == y, cfalse);
TVM_TRY_REWRITE(x && !x, cfalse);
TVM_TRY_REWRITE(x <= y && y < x, cfalse);
TVM_TRY_REWRITE(y < x && x <= y, cfalse);
TVM_TRY_REWRITE(x <= y && y <= x, x == y);
// Rules to simplify a pair of conditions. Returns NullOpt if no
// simplification is possible.
auto pairwise_simplify = [this](const PrimExpr& a, const PrimExpr& b) -> Optional<PrimExpr> {
if (auto const_res = TryConstFold<And>(a, b)) return const_res.value();

TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value);
PrimExpr ret = a && b;

TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse, c2.Eval()->value >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse, c2.Eval()->value >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse, c2.Eval()->value >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse, c2.Eval()->value >= c1.Eval()->value);
if (auto match = TryMatchLiteralConstraint(ret)) {
return match.value();
}

TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse, c2.Eval()->value > c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse, c2.Eval()->value > c1.Eval()->value);
// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVar<int> lanes;

if (ret->dtype.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes));
}

auto cfalse = PConst<PrimExpr>(make_const(ret->dtype, false));
TVM_TRY_REWRITE(x == y && x != y, cfalse);
TVM_TRY_REWRITE(x != y && x == y, cfalse);
TVM_TRY_REWRITE(x && !x, cfalse);
TVM_TRY_REWRITE(x <= y && y < x, cfalse);
TVM_TRY_REWRITE(y < x && x <= y, cfalse);
TVM_TRY_REWRITE(x <= y && y <= x, x == y);

TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value);

TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2);
TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2);
TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse, c2.Eval()->value >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse, c2.Eval()->value >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse, c2.Eval()->value >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse, c2.Eval()->value >= c1.Eval()->value);

TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse, c2.Eval()->value > c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse, c2.Eval()->value > c1.Eval()->value);

TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2);
TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2);

return NullOpt;
};

// Check each pairwise set of subexpressions for simplifications
for (size_t i = 0; i < subexprs.size(); i++) {
for (size_t j = i + 1; j < subexprs.size(); j++) {
PrimExpr& a = subexprs[i];
PrimExpr& b = subexprs[j];
if (a.defined() && b.defined()) {
if (Optional<PrimExpr> pairwise = pairwise_simplify(a, b)) {
a = PrimExpr();
b = pairwise.value();
modified = true;
}
}
}
}

if (!modified) {
return GetRef<PrimExpr>(op);
}

// Merge all remaining subexpressions
PrimExpr ret = Bool(true);
for (const auto& subexpr : subexprs) {
if (subexpr.defined()) {
ret = ret && subexpr;
}
}
return ret;
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<OrNode>();
if (auto const_res = TryConstFold<Or>(op->a, op->b)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
std::vector<PrimExpr> subexprs = ExtractComponents(GetRef<PrimExpr>(op));
ICHECK_GE(subexprs.size(), 2);

// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVar<int> lanes;
bool modified = false;

if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes));
// Simplify each of the subexpressions under the assumption that all
// other subexpressions are false.
for (size_t i = 0; i < subexprs.size(); i++) {
std::vector<With<ConstraintContext>> context;
context.reserve(subexprs.size() - 1);

PrimExpr remainder = Bool(true);
for (size_t j = 0; j < subexprs.size(); j++) {
if (i != j) {
context.emplace_back(analyzer_, RewriteBooleanOperators(Not(subexprs[j])));
}
}

PrimExpr simplified = VisitExpr(subexprs[i]);
if (!simplified.same_as(subexprs[i])) {
modified = true;
}

subexprs[i] = simplified;

while (context.size()) {
context.pop_back();
}
}

auto ctrue = PConst<PrimExpr>(make_const(op->dtype, true));
// Rules to simplify a pair of conditions. Returns NullOpt if no
// simplification is possible.
auto pairwise_simplify = [this](const PrimExpr& a, const PrimExpr& b) -> Optional<PrimExpr> {
if (auto const_res = TryConstFold<Or>(a, b)) return const_res.value();

TVM_TRY_REWRITE(x == y || x != y, ctrue);
TVM_TRY_REWRITE(x != y || x == y, ctrue);
TVM_TRY_REWRITE(x || !x, ctrue);
TVM_TRY_REWRITE(x <= y || y < x, ctrue);
TVM_TRY_REWRITE(y < x || x <= y, ctrue);
PrimExpr ret = a || b;

TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, c2.Eval()->value < c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, c2.Eval()->value < c1.Eval()->value);
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();

TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue, c2.Eval()->value <= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value);
TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue, c2.Eval()->value <= c1.Eval()->value);
// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVar<int> lanes;

TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value + 1);
TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value + 1);
if (ret->dtype.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes));
}

TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2);
TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2);
auto ctrue = PConst<PrimExpr>(make_const(ret->dtype, true));

TVM_TRY_REWRITE(x == y || x != y, ctrue);
TVM_TRY_REWRITE(x != y || x == y, ctrue);
TVM_TRY_REWRITE(x || !x, ctrue);
TVM_TRY_REWRITE(x <= y || y < x, ctrue);
TVM_TRY_REWRITE(y < x || x <= y, ctrue);
TVM_TRY_REWRITE(x <= y || y <= x, ctrue);
TVM_TRY_REWRITE(x <= y || x == y, x <= y);
TVM_TRY_REWRITE(x == y || x <= y, x <= y);
TVM_TRY_REWRITE(x < y || x == y, x <= y);
TVM_TRY_REWRITE(x == y || x <= y, x <= y);

TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, c2.Eval()->value < c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, c2.Eval()->value < c1.Eval()->value);

TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue, c2.Eval()->value <= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value);
TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue, c2.Eval()->value <= c1.Eval()->value);

TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value + 1);
TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value + 1);

TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2);
TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2);
return NullOpt;
};

// Check each pairwise set of subexpressions for simplifications
for (size_t i = 0; i < subexprs.size(); i++) {
for (size_t j = i + 1; j < subexprs.size(); j++) {
PrimExpr& a = subexprs[i];
PrimExpr& b = subexprs[j];
if (a.defined() && b.defined()) {
if (Optional<PrimExpr> pairwise = pairwise_simplify(a, b)) {
a = PrimExpr();
b = pairwise.value();
modified = true;
}
}
}
}

if (!modified) {
return GetRef<PrimExpr>(op);
}

// Merge all remaining subexpressions
PrimExpr ret = Bool(false);
for (const auto& subexpr : subexprs) {
if (subexpr.defined()) {
ret = ret || subexpr;
}
}
return ret;
}

Expand Down
Loading

0 comments on commit 6d9216a

Please sign in to comment.