[mlir][sparse] fold sparse convert into producer linalg op. #89999
Merged
Conversation
PeimingLiu requested review from aartbik, yinying-lisa-li and matthias-springer as code owners on April 24, 2024 at 22:12
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/89999.diff

5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 5e523ec428aefb..550e28813b4e9b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -90,17 +90,20 @@ inline MemRefType getMemRefType(T &&t) {
SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
/// Returns true iff MLIR operand has any sparse operand.
-inline bool hasAnySparseOperand(Operation *op) {
- return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
- return getSparseTensorEncoding(t) != nullptr;
+inline bool hasAnySparseType(TypeRange types) {
+ return llvm::any_of(types, [](Type type) {
+ return getSparseTensorEncoding(type) != nullptr;
});
}
+/// Returns true iff MLIR operand has any sparse operand.
+inline bool hasAnySparseOperand(Operation *op) {
+ return hasAnySparseType(op->getOperands().getTypes());
+}
+
/// Returns true iff MLIR operand has any sparse result.
inline bool hasAnySparseResult(Operation *op) {
- return llvm::any_of(op->getResults().getTypes(), [](Type t) {
- return getSparseTensorEncoding(t) != nullptr;
- });
+ return hasAnySparseType(op->getResults().getTypes());
}
/// Returns true iff MLIR operand has any sparse operand or result.
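
For context, the refactored helpers simply test each type for a sparse_tensor encoding attribute; as a minimal sketch (the #CSR encoding and shapes here are hypothetical, not from the patch), only the second tensor type below would be detected as sparse:

  #CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
  // tensor<4x8xf32>        -> no encoding attached, not sparse
  // tensor<4x8xf32, #CSR>  -> carries an encoding, sparse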
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 5a39dfc6207707..641dcc61d7d09c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -289,6 +289,37 @@ struct FuseExtractSliceWithConcat
}
};
+/// Rewriting rule that folds a sparse_tensor.convert into its producer linalg op.
+struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ConvertOp op,
+ PatternRewriter &rewriter) const override {
+ auto producer = op.getSource().getDefiningOp<GenericOp>();
+ if (!producer || producer.getDpsInits().size() != 1 ||
+ !isMaterializing(producer.getDpsInitOperand(0), false) ||
+ !producer.getResult(0).hasOneUse()) {
+ return failure();
+ }
+ rewriter.modifyOpInPlace(producer, [&]() {
+ producer.getResult(0).setType(op.getResult().getType());
+ });
+
+ Operation *materializeOp =
+ producer.getDpsInitOperand(0)->get().getDefiningOp();
+
+ rewriter.modifyOpInPlace(materializeOp, [&]() {
+ materializeOp->getResult(0).setType(op.getResult().getType());
+ });
+
+ rewriter.replaceAllOpUsesWith(op, producer);
+ op->erase();
+
+ return success();
+ }
+};
+
/// Rewriting rule that converts direct yield of zero with initial allocation.
struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
public:
@@ -1506,9 +1537,10 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
//===---------------------------------------------------------------------===//
void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
- patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
- FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
- GenSemiRingSelect, PrintRewriter>(patterns.getContext());
+ patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
+ FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
+ GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
+ patterns.getContext());
}
void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
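
To illustrate the new FoldConvertIntoProducer pattern, here is a minimal before/after sketch (the #CSR encoding, shapes, and %arg0 are hypothetical placeholders; the pattern requires a single materializing init and a single-use result, as the match conditions above show):

  #CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
  #map = affine_map<(d0, d1) -> (d0, d1)>

  // Before: a dense linalg.generic whose only use is the sparse conversion.
  %0 = tensor.empty() : tensor<4x8xf32>
  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<4x8xf32>) outs(%0 : tensor<4x8xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
  } -> tensor<4x8xf32>
  %2 = sparse_tensor.convert %1 : tensor<4x8xf32> to tensor<4x8xf32, #CSR>

  // After: the generic and its tensor.empty init are retyped in place, the
  // convert is erased, and its uses point at the (now sparse) generic result.
  %0 = tensor.empty() : tensor<4x8xf32, #CSR>
  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<4x8xf32>) outs(%0 : tensor<4x8xf32, #CSR>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
  } -> tensor<4x8xf32, #CSR>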
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index cd046b670d9a8e..0a9bb40b458d68 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -403,6 +403,22 @@ static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
}
+static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
+ Value sparseOut, ValueRange ivs, Value v) {
+ scf::IfOp condInsert =
+ builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true);
+ // True branch.
+ builder.setInsertionPointToStart(condInsert.thenBlock());
+ Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
+ builder.create<scf::YieldOp>(loc, res);
+ // False branch.
+ builder.setInsertionPointToStart(condInsert.elseBlock());
+ builder.create<scf::YieldOp>(loc, sparseOut);
+ // Value assignment.
+ builder.setInsertionPointAfter(condInsert);
+ return condInsert.getResult(0);
+}
+
/// Generates insertion code to implement dynamic tensor store.
static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
Value rhs) {
@@ -423,23 +439,21 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
// return updated chain
// else
// return unmodified chain
- scf::IfOp ifValidLexInsert = builder.create<scf::IfOp>(
- loc, chain.getType(), env.getValidLexInsert(),
- /*else=*/true);
- // True branch.
- builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
- Value res = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
- builder.create<scf::YieldOp>(loc, res);
- // False branch.
- builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
- builder.create<scf::YieldOp>(loc, chain);
- // Value assignment.
- builder.setInsertionPointAfter(ifValidLexInsert);
- env.updateInsertionChain(ifValidLexInsert.getResult(0));
+ Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
+ chain, ivs, rhs);
+ env.updateInsertionChain(out);
} else {
+ Value sparseOut;
+ if (!hasAnySparseType(env.op().getInputs().getTypes())) {
+ // This is an all-dense -> sparse kernel, test rhs != 0 before
+ // insertion.
+ Value nz = genIsNonzero(builder, loc, rhs);
+ sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
+ } else {
+ sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
+ }
// Generates regular insertion chain.
- env.updateInsertionChain(
- builder.create<tensor::InsertOp>(loc, rhs, chain, ivs));
+ env.updateInsertionChain(sparseOut);
}
return;
}
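
The factored-out genConditionalInsert emits an scf.if that guards the insertion, roughly as sketched below (%cond, %chain, %v, %i, and the tensor type are placeholders, not code from the patch). With the new else branch above, the same guard now also skips zero values in all-dense-to-sparse kernels, in addition to the existing valid-lex-insert case:

  %res = scf.if %cond -> (tensor<?xf32, #sparse>) {
    // Guard holds: extend the insertion chain with the new value.
    %ins = tensor.insert %v into %chain[%i] : tensor<?xf32, #sparse>
    scf.yield %ins : tensor<?xf32, #sparse>
  } else {
    // Guard fails: yield the chain unmodified.
    scf.yield %chain : tensor<?xf32, #sparse>
  }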
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
new file mode 100644
index 00000000000000..077dde230fd156
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification | FileCheck %s
+
+#trait = {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+}
+
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+
+// CHECK-LABEL: func.func @test(
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.if
+// CHECK-NEXT: tensor.insert
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: else
+// CHECK-NEXT: scf.yield
+// CHECK: scf.yield
+// CHECK: scf.yield
+// CHECK: scf.yield
+// CHECK: sparse_tensor.load
+func.func @test(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #sparse> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %cst_0 = arith.constant 1.000000e+00 : f32
+ %cst_1 = arith.constant 1.000000e+00 : f32
+ %0 = tensor.empty() : tensor<128x32x32x1xf32>
+ %1 = linalg.generic #trait
+ ins(%arg0, %arg1, %arg2 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
+ outs(%0 : tensor<128x32x32x1xf32>) {
+ ^bb0(%in: f32, %in_2: f32, %in_3: f32, %out: f32):
+ %3 = arith.subf %cst_0, %in_2 : f32
+ %4 = arith.mulf %in, %3 : f32
+ %5 = arith.mulf %4, %cst_1 : f32
+ %6 = arith.addf %5, %in_3 : f32
+ %7 = arith.subf %6, %cst_0 : f32
+ %8 = arith.cmpf uge, %7, %cst : f32
+ %9 = arith.uitofp %8 : i1 to f32
+ linalg.yield %9 : f32
+ } -> tensor<128x32x32x1xf32>
+ %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #sparse>
+ return %2 : tensor<128x32x32x1xf32, #sparse>
+}
diff --git a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
index bbc7f397e793fe..f2f64567d5bd01 100644
--- a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
+++ b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
@@ -24,7 +24,6 @@ module {
// CHECK: arith.constant
// CHECK: tensor.empty()
// CHECK: linalg.generic
- // CHECK: sparse_tensor.convert
// CHECK: return
//
func.func @avoid_fold(%0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
@@ -44,4 +43,3 @@ module {
return %cast : tensor<10x20x30xf64, #sparse>
}
}
-
PeimingLiu changed the title from "[mlir][sparse] fold sparse convert into producer generic op." to "[mlir][sparse] fold sparse convert into producer linalg op." on Apr 24, 2024
aartbik reviewed on Apr 24, 2024
yinying-lisa-li approved these changes on Apr 24, 2024
PeimingLiu force-pushed the fuse-convert branch 3 times, most recently from ab8c15a to 7f432f3 on April 25, 2024 at 00:07
aartbik approved these changes on Apr 26, 2024
PeimingLiu force-pushed the fuse-convert branch from 7f432f3 to 50294d3 on April 26, 2024 at 16:32
PeimingLiu force-pushed the fuse-convert branch from 50294d3 to c689fbb on April 26, 2024 at 16:35
This was referenced on Apr 29, 2024