Skip to content

Commit

Permalink
[TIR] Simplify (x<=y && y<=x)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Sep 15, 2022
1 parent 0716129 commit 7c531c8
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1537,6 +1537,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
TVM_TRY_REWRITE(x && !x, cfalse);
TVM_TRY_REWRITE(x <= y && y < x, cfalse);
TVM_TRY_REWRITE(y < x && x <= y, cfalse);
TVM_TRY_REWRITE(x <= y && y <= x, x == y);

TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value);
Expand Down
10 changes: 10 additions & 0 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,16 @@ class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
transform = tvm.tir.transform.Simplify()


class TestSimplifyLENodeToEqualNode(BaseBeforeAfter):
"""If neither value is greater than each other, they are equal."""

def before(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32):
A[0] = i <= j and j <= i

def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32):
A[0] = i == j


class TestLoadStoreNoop(BaseBeforeAfter):
"""Store of a value that was just read from the same location is a no-op."""

Expand Down

0 comments on commit 7c531c8

Please sign in to comment.