From 4df0170c78c5f2114b3c7a4729f7fcbe9edc8446 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 15 Jul 2024 16:24:46 +0200 Subject: [PATCH] Fix simplification of x + x//c*-c to x mod c. There was no check that rhs is actually a multiplication. --- mlir/lib/IR/AffineExpr.cpp | 4 +++- mlir/unittests/IR/AffineExprTest.cpp | 8 ++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index bfb7c4849356eb..75cc01ee9a146a 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -751,8 +751,10 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) { } // Process lrhs, which is 'expr floordiv c'. + // expr + (expr // c * -c) = expr % c AffineBinaryOpExpr lrBinOpExpr = dyn_cast(lrhs); - if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv) + if (!lrBinOpExpr || rhs.getKind() != AffineExprKind::Mul || + lrBinOpExpr.getKind() != AffineExprKind::FloorDiv) return nullptr; llrhs = lrBinOpExpr.getLHS(); diff --git a/mlir/unittests/IR/AffineExprTest.cpp b/mlir/unittests/IR/AffineExprTest.cpp index a0affc4341b0b4..75c893334943d3 100644 --- a/mlir/unittests/IR/AffineExprTest.cpp +++ b/mlir/unittests/IR/AffineExprTest.cpp @@ -98,3 +98,11 @@ TEST(AffineExprTest, divisionSimplification) { ASSERT_EQ((d0 * 6).ceilDiv(4).getKind(), AffineExprKind::CeilDiv); ASSERT_EQ((d0 * 6).ceilDiv(-2), d0 * -3); } + +TEST(AffineExprTest, modSimplificationRegression) { + MLIRContext ctx; + OpBuilder b(&ctx); + auto d0 = b.getAffineDimExpr(0); + auto sum = d0 + d0.floorDiv(3).floorDiv(-3); + ASSERT_EQ(sum.getKind(), AffineExprKind::Add); +}