From 4ad10157a8a4206f8fc993e9fb3e90831893e31f Mon Sep 17 00:00:00 2001 From: quic_rutkoor Date: Mon, 11 Dec 2023 21:34:33 -0800 Subject: [PATCH] Simplify nested if_then_else when constant is appearing in then_expr --- src/arith/ir_mutator_with_analyzer.cc | 5 +++-- .../tir-transform/test_tir_transform_simplify.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 2ee427beb86c..d26ac3667620 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -173,8 +173,9 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { WithRecordIterPredicate(cond, [&] { true_value = this->VisitExpr(op->args[1]); }); } { - With constraint(analyzer_, analyzer_->rewrite_simplify(Not(cond))); - false_value = this->VisitExpr(op->args[2]); + PrimExpr not_cond = Not(cond); + With constraint(analyzer_, not_cond); + WithRecordIterPredicate(not_cond, [&] { false_value = this->VisitExpr(op->args[2]); }); } if (is_zero(cond)) { return false_value; diff --git a/tests/python/tir-transform/test_tir_transform_simplify.py b/tests/python/tir-transform/test_tir_transform_simplify.py index c779d92f9c47..6bad817c4955 100644 --- a/tests/python/tir-transform/test_tir_transform_simplify.py +++ b/tests/python/tir-transform/test_tir_transform_simplify.py @@ -1757,5 +1757,17 @@ def expected(a: T.handle): A[T.int64(1)] = T.float32(0) +class TestNestedIfElimination(BaseBeforeAfter): + def before(a: T.Buffer((2, 8), "int32"), b: T.Buffer((2, 8), "int32")): + for i0, j0 in T.grid(2, 8): + b[i0, j0] = T.if_then_else( + i0 == 1 and 6 <= j0, 0, T.max(0, T.if_then_else(i0 == 1 and 6 <= j0, 0, a[i0, j0])) + ) + + def expected(a: T.Buffer((2, 8), "int32"), b: T.Buffer((2, 8), "int32")): + for i0, j0 in T.grid(2, 8): + b[i0, j0] = T.if_then_else(i0 == 1 and 6 <= j0, 0, T.max(0, a[i0, j0])) + + if __name__ == "__main__": tvm.testing.main()