Skip to content

Commit

Permalink
[ARITH] Simplify nested if_then_else when constant is appearing in th…
Browse files Browse the repository at this point in the history
…en_expr (#16227)

Simplify nested if_then_else when constant is appearing in then_expr
  • Loading branch information
rutkoor authored Dec 17, 2023
1 parent 870246a commit 799e810
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/arith/ir_mutator_with_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) {
WithRecordIterPredicate(cond, [&] { true_value = this->VisitExpr(op->args[1]); });
}
{
With<ConstraintContext> constraint(analyzer_, analyzer_->rewrite_simplify(Not(cond)));
false_value = this->VisitExpr(op->args[2]);
PrimExpr not_cond = Not(cond);
With<ConstraintContext> constraint(analyzer_, not_cond);
WithRecordIterPredicate(not_cond, [&] { false_value = this->VisitExpr(op->args[2]); });
}
if (is_zero(cond)) {
return false_value;
Expand Down
12 changes: 12 additions & 0 deletions tests/python/tir-transform/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -1757,5 +1757,17 @@ def expected(a: T.handle):
A[T.int64(1)] = T.float32(0)


class TestNestedIfElimination(BaseBeforeAfter):
def before(a: T.Buffer((2, 8), "int32"), b: T.Buffer((2, 8), "int32")):
for i0, j0 in T.grid(2, 8):
b[i0, j0] = T.if_then_else(
i0 == 1 and 6 <= j0, 0, T.max(0, T.if_then_else(i0 == 1 and 6 <= j0, 0, a[i0, j0]))
)

def expected(a: T.Buffer((2, 8), "int32"), b: T.Buffer((2, 8), "int32")):
for i0, j0 in T.grid(2, 8):
b[i0, j0] = T.if_then_else(i0 == 1 and 6 <= j0, 0, T.max(0, a[i0, j0]))


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 799e810

Please sign in to comment.