Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][scf] scf.while uplifting: optimize op matching #88813

Merged
merged 2 commits into from
Apr 16, 2024

Conversation

Hardcode84
Copy link
Contributor

Instead of iterating over potential induction var uses looking for suitable arith.addi, try to trace it back from yield argument.

Instead of iterating over potential induction var uses looking for suitable `arith.addi`, try to trace it back from yield argument.
@llvmbot
Copy link
Member

llvmbot commented Apr 15, 2024

@llvm/pr-subscribers-mlir-scf

Author: Ivan Butygin (Hardcode84)

Changes

Instead of iterating over potential induction var uses looking for suitable arith.addi, try to trace it back from yield argument.


Full diff: https://github.com/llvm/llvm-project/pull/88813.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp (+14-22)
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index fea2f659535bb4..7b4024b6861a72 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -101,38 +101,30 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
 
   Block *afterBody = loop.getAfterBody();
   scf::YieldOp afterTerm = loop.getYieldOp();
-  auto argNumber = inductionVar.getArgNumber();
-  auto afterTermIndArg = afterTerm.getResults()[argNumber];
+  unsigned argNumber = inductionVar.getArgNumber();
+  Value afterTermIndArg = afterTerm.getResults()[argNumber];
 
-  auto inductionVarAfter = afterBody->getArgument(argNumber);
-
-  Value step;
+  Value inductionVarAfter = afterBody->getArgument(argNumber);
 
   // Find suitable `addi` op inside `after` block, one of the args must be an
   // Induction var passed from `before` block and second arg must be defined
   // outside of the loop and will be considered step value.
   // TODO: Add `subi` support?
-  for (auto &use : inductionVarAfter.getUses()) {
-    auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
-    if (!owner)
-      continue;
-
-    auto other =
-        (inductionVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
-    if (!dom.properlyDominates(other, loop))
-      continue;
-
-    if (afterTermIndArg != owner.getResult())
-      continue;
+  auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>();
+  if (!addOp)
+    return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
 
-    step = other;
-    break;
+  Value step;
+  if (addOp.getLhs() == inductionVarAfter) {
+    step = addOp.getRhs();
+  } else if (addOp.getRhs() == inductionVarAfter) {
+    step = addOp.getLhs();
   }
 
-  if (!step)
-    return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
+  if (!step || !dom.properlyDominates(step, loop))
+    return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form");
 
-  auto lb = loop.getInits()[argNumber];
+  Value lb = loop.getInits()[argNumber];
 
   assert(lb.getType().isIntOrIndex());
   assert(lb.getType() == ub.getType());

@llvmbot
Copy link
Member

llvmbot commented Apr 15, 2024

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

Instead of iterating over potential induction var uses looking for suitable arith.addi, try to trace it back from yield argument.


Full diff: https://github.com/llvm/llvm-project/pull/88813.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp (+14-22)
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index fea2f659535bb4..7b4024b6861a72 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -101,38 +101,30 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
 
   Block *afterBody = loop.getAfterBody();
   scf::YieldOp afterTerm = loop.getYieldOp();
-  auto argNumber = inductionVar.getArgNumber();
-  auto afterTermIndArg = afterTerm.getResults()[argNumber];
+  unsigned argNumber = inductionVar.getArgNumber();
+  Value afterTermIndArg = afterTerm.getResults()[argNumber];
 
-  auto inductionVarAfter = afterBody->getArgument(argNumber);
-
-  Value step;
+  Value inductionVarAfter = afterBody->getArgument(argNumber);
 
   // Find suitable `addi` op inside `after` block, one of the args must be an
   // Induction var passed from `before` block and second arg must be defined
   // outside of the loop and will be considered step value.
   // TODO: Add `subi` support?
-  for (auto &use : inductionVarAfter.getUses()) {
-    auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
-    if (!owner)
-      continue;
-
-    auto other =
-        (inductionVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
-    if (!dom.properlyDominates(other, loop))
-      continue;
-
-    if (afterTermIndArg != owner.getResult())
-      continue;
+  auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>();
+  if (!addOp)
+    return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
 
-    step = other;
-    break;
+  Value step;
+  if (addOp.getLhs() == inductionVarAfter) {
+    step = addOp.getRhs();
+  } else if (addOp.getRhs() == inductionVarAfter) {
+    step = addOp.getLhs();
   }
 
-  if (!step)
-    return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
+  if (!step || !dom.properlyDominates(step, loop))
+    return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form");
 
-  auto lb = loop.getInits()[argNumber];
+  Value lb = loop.getInits()[argNumber];
 
   assert(lb.getType().isIntOrIndex());
   assert(lb.getType() == ub.getType());

@Hardcode84 Hardcode84 merged commit 1ca6b44 into llvm:main Apr 16, 2024
7 checks passed
@Hardcode84 Hardcode84 deleted the scf-uplift-optimize branch April 16, 2024 09:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants