Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. #105566

Merged
merged 1 commit into from
Aug 23, 2024

Conversation

…lementation.

stack-info: PR: #105566, branch: users/PeimingLiu/stack/2
@llvmbot
Copy link
Member

llvmbot commented Aug 21, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

Stacked PRs:

  • #105567
  • ->#105566
  • #105565

[mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation.


Full diff: https://github.com/llvm/llvm-project/pull/105566.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (+36-82)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index d6c0da4a9e457..f7fcabb0220b5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -244,88 +244,41 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
     std::unique_ptr<SparseIterator> it =
         iterSpace.extractIterator(rewriter, loc);
 
-    if (it->iteratableByFor()) {
-      auto [lo, hi] = it->genForCond(rewriter, loc);
-      Value step = constantIndex(rewriter, loc, 1);
-      SmallVector<Value> ivs;
-      for (ValueRange inits : adaptor.getInitArgs())
-        llvm::append_range(ivs, inits);
-      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs);
-
-      Block *loopBody = op.getBody();
-      OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
-      if (failed(typeConverter->convertSignatureArgs(
-              loopBody->getArgumentTypes(), bodyTypeMapping)))
-        return failure();
-      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
-      rewriter.eraseBlock(forOp.getBody());
-      Region &dstRegion = forOp.getRegion();
-      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-
-      auto yieldOp =
-          llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator());
-
-      rewriter.setInsertionPointToEnd(forOp.getBody());
-      // replace sparse_tensor.yield with scf.yield.
-      rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
-      rewriter.eraseOp(yieldOp);
-
-      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
-      rewriter.replaceOp(op, forOp.getResults(), resultMapping);
-    } else {
-      SmallVector<Value> ivs;
-      // TODO: put iterator at the end of argument list to be consistent with
-      // coiterate operation.
-      llvm::append_range(ivs, it->getCursor());
-      for (ValueRange inits : adaptor.getInitArgs())
-        llvm::append_range(ivs, inits);
-
-      assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
-
-      TypeRange types = ValueRange(ivs).getTypes();
-      auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
-      SmallVector<Location> l(types.size(), op.getIterator().getLoc());
-
-      // Generates loop conditions.
-      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
-      rewriter.setInsertionPointToStart(before);
-      ValueRange bArgs = before->getArguments();
-      auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
-      assert(remArgs.size() == adaptor.getInitArgs().size());
-      rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
-
-      // Generates loop body.
-      Block *loopBody = op.getBody();
-      OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
-      if (failed(typeConverter->convertSignatureArgs(
-              loopBody->getArgumentTypes(), bodyTypeMapping)))
-        return failure();
-      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
-      Region &dstRegion = whileOp.getAfter();
-      // TODO: handle uses of coordinate!
-      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-      ValueRange aArgs = whileOp.getAfterArguments();
-      auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
-          whileOp.getAfterBody()->getTerminator());
-
-      rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
+    SmallVector<Value> ivs;
+    for (ValueRange inits : adaptor.getInitArgs())
+      llvm::append_range(ivs, inits);
+
+    // Type conversion on iterate op block.
+    OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
+    if (failed(typeConverter->convertSignatureArgs(
+            op.getBody()->getArgumentTypes(), blockTypeMapping)))
+      return rewriter.notifyMatchFailure(
+          op, "failed to convert iterate region argurment types");
+    rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);
+
+    Block *block = op.getBody();
+    ValueRange ret = genLoopWithIterator(
+        rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
+        [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
+                SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
+          SmallVector<Value> blockArgs(it->getCursor());
+          // TODO: Also appends coordinates if used.
+          // blockArgs.push_back(it->deref(rewriter, loc));
+          llvm::append_range(blockArgs, reduc);
+
+          Block *dstBlock = &loopBody.getBlocks().front();
+          rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
+                                     blockArgs);
+          auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
+          // We can not use ValueRange as the operation holding the values will
+          // be destoryed.
+          SmallVector<Value> result(yield.getResults());
+          rewriter.eraseOp(yield);
+          return result;
+        });
 
-      aArgs = it->linkNewScope(aArgs);
-      ValueRange nx = it->forward(rewriter, loc);
-      SmallVector<Value> yields;
-      llvm::append_range(yields, nx);
-      llvm::append_range(yields, yieldOp.getResults());
-
-      // replace sparse_tensor.yield with scf.yield.
-      rewriter.eraseOp(yieldOp);
-      rewriter.create<scf::YieldOp>(loc, yields);
-      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
-      rewriter.replaceOp(
-          op, whileOp.getResults().drop_front(it->getCursor().size()),
-          resultMapping);
-    }
+    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+    rewriter.replaceOp(op, ret, resultMapping);
     return success();
   }
 };
@@ -366,9 +319,10 @@ class SparseCoIterateOpConverter
       Block *block = &region.getBlocks().front();
       OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
       if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
-                                                     blockTypeMapping)))
+                                                     blockTypeMapping))) {
         return rewriter.notifyMatchFailure(
             op, "failed to convert coiterate region argurment types");
+      }
 
       rewriter.applySignatureConversion(block, blockTypeMapping);
     }

kraj pushed a commit to kraj/llvm-project that referenced this pull request Aug 21, 2024
…lementation.

stack-info: PR: llvm#105566, branch: users/PeimingLiu/stack/2
@PeimingLiu PeimingLiu changed the base branch from users/PeimingLiu/stack/1 to main August 22, 2024 23:30
@PeimingLiu PeimingLiu force-pushed the users/PeimingLiu/stack/2 branch from 937bcd8 to 5d73f23 Compare August 22, 2024 23:30
@PeimingLiu PeimingLiu changed the base branch from main to users/PeimingLiu/stack/1 August 22, 2024 23:30
jollaitbot pushed a commit to sailfishos-mirror/llvm-project that referenced this pull request Aug 23, 2024
…lementation.

stack-info: PR: llvm/llvm-project#105566, branch: users/PeimingLiu/stack/2
Base automatically changed from users/PeimingLiu/stack/1 to main August 23, 2024 17:47
@PeimingLiu PeimingLiu force-pushed the users/PeimingLiu/stack/2 branch from 5d73f23 to 984d8d5 Compare August 23, 2024 17:47
@PeimingLiu PeimingLiu merged commit 7186704 into main Aug 23, 2024
8 checks passed
@PeimingLiu PeimingLiu deleted the users/PeimingLiu/stack/2 branch August 23, 2024 18:21
5chmidti pushed a commit that referenced this pull request Aug 24, 2024
dmpolukhin pushed a commit to dmpolukhin/llvm-project that referenced this pull request Sep 2, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants