From a905cb262686a347549d3eb811359046812232e6 Mon Sep 17 00:00:00 2001
From: Peiming Liu
Date: Wed, 24 Apr 2024 22:10:27 +0000
Subject: [PATCH 1/2] [mlir][sparse] fold sparse convert into producer generic operation.

---
 .../Dialect/SparseTensor/IR/SparseTensor.h    | 15 +++---
 .../Transforms/SparseTensorRewriting.cpp      | 38 +++++++++++++--
 .../Transforms/Sparsification.cpp             | 44 +++++++++++------
 .../fuse_sparse_convert_into_producer.mlir    | 48 +++++++++++++++++++
 .../SparseTensor/no_fold_into_consumer.mlir   |  2 -
 5 files changed, 121 insertions(+), 26 deletions(-)
 create mode 100644 mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 5e523ec428aefb9..550e28813b4e9b1 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -90,17 +90,20 @@ inline MemRefType getMemRefType(T &&t) {
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 
 /// Returns true iff MLIR operand has any sparse operand.
-inline bool hasAnySparseOperand(Operation *op) {
-  return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
-    return getSparseTensorEncoding(t) != nullptr;
+inline bool hasAnySparseType(TypeRange types) {
+  return llvm::any_of(types, [](Type type) {
+    return getSparseTensorEncoding(type) != nullptr;
   });
 }
 
+/// Returns true iff MLIR operand has any sparse operand.
+inline bool hasAnySparseOperand(Operation *op) {
+  return hasAnySparseType(op->getOperands().getTypes());
+}
+
 /// Returns true iff MLIR operand has any sparse result.
 inline bool hasAnySparseResult(Operation *op) {
-  return llvm::any_of(op->getResults().getTypes(), [](Type t) {
-    return getSparseTensorEncoding(t) != nullptr;
-  });
+  return hasAnySparseType(op->getResults().getTypes());
 }
 
 /// Returns true iff MLIR operand has any sparse operand or result.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 5a39dfc6207707f..641dcc61d7d09c7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -289,6 +289,37 @@ struct FuseExtractSliceWithConcat
   }
 };
 
+/// Rewriting rule that converts direct yield of zero with initial allocation.
+struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConvertOp op,
+                                PatternRewriter &rewriter) const override {
+    auto producer = op.getSource().getDefiningOp<GenericOp>();
+    if (!producer || producer.getDpsInits().size() != 1 ||
+        !isMaterializing(producer.getDpsInitOperand(0), false) ||
+        !producer.getResult(0).hasOneUse()) {
+      return failure();
+    }
+    rewriter.modifyOpInPlace(producer, [&]() {
+      producer.getResult(0).setType(op.getResult().getType());
+    });
+
+    Operation *materializeOp =
+        producer.getDpsInitOperand(0)->get().getDefiningOp();
+
+    rewriter.modifyOpInPlace(materializeOp, [&]() {
+      materializeOp->getResult(0).setType(op.getResult().getType());
+    });
+
+    rewriter.replaceAllOpUsesWith(op, producer);
+    op->erase();
+
+    return success();
+  }
+};
+
 /// Rewriting rule that converts direct yield of zero with initial allocation.
 struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
 public:
@@ -1506,9 +1537,10 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
 //===---------------------------------------------------------------------===//
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
-  patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
-               GenSemiRingReduction, GenSemiRingSelect,
-               PrintRewriter>(patterns.getContext());
+  patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
+               FoldInvariantYield, GenSemiRingReduction, GenSemiRingSelect,
+               PrintRewriter>(
+      patterns.getContext());
 }
 
 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index cd046b670d9a8e3..0a9bb40b458d68f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -403,6 +403,22 @@ static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
   return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
 }
 
+static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
+                                  Value sparseOut, ValueRange ivs, Value v) {
+  scf::IfOp condInsert =
+      builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true);
+  // True branch.
+  builder.setInsertionPointToStart(condInsert.thenBlock());
+  Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
+  builder.create<scf::YieldOp>(loc, res);
+  // False branch.
+  builder.setInsertionPointToStart(condInsert.elseBlock());
+  builder.create<scf::YieldOp>(loc, sparseOut);
+  // Value assignment.
+  builder.setInsertionPointAfter(condInsert);
+  return condInsert.getResult(0);
+}
+
 /// Generates insertion code to implement dynamic tensor store.
 static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                               Value rhs) {
@@ -423,23 +439,21 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
     //     return updated chain
     //   else
     //     return unmodified chain
-    scf::IfOp ifValidLexInsert = builder.create<scf::IfOp>(
-        loc, chain.getType(), env.getValidLexInsert(),
-        /*else=*/true);
-    // True branch.
-    builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
-    Value res = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
-    builder.create<scf::YieldOp>(loc, res);
-    // False branch.
-    builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
-    builder.create<scf::YieldOp>(loc, chain);
-    // Value assignment.
-    builder.setInsertionPointAfter(ifValidLexInsert);
-    env.updateInsertionChain(ifValidLexInsert.getResult(0));
+    Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
+                                     chain, ivs, rhs);
+    env.updateInsertionChain(out);
   } else {
+    Value sparseOut;
+    if (!hasAnySparseType(env.op().getInputs().getTypes())) {
+      // This is an all-dense -> sparse kernel; test rhs != 0 before
+      // insertion.
+      Value nz = genIsNonzero(builder, loc, rhs);
+      sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
+    } else {
+      sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
+    }
     // Generates regular insertion chain.
-    env.updateInsertionChain(
-        builder.create<tensor::InsertOp>(loc, rhs, chain, ivs));
+    env.updateInsertionChain(sparseOut);
   }
   return;
 }
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
new file mode 100644
index 000000000000000..077dde230fd1563
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification | FileCheck %s
+
+#trait = {
+  indexing_maps = [
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+  ],
+  iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+}
+
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+
+// CHECK-LABEL: func.func @test(
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.if
+// CHECK-NEXT: tensor.insert
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: else
+// CHECK-NEXT: scf.yield
+// CHECK: scf.yield
+// CHECK: scf.yield
+// CHECK: scf.yield
+// CHECK: sparse_tensor.load
+func.func @test(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #sparse> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_0 = arith.constant 1.000000e+00 : f32
+  %cst_1 = arith.constant 1.000000e+00 : f32
+  %0 = tensor.empty() : tensor<128x32x32x1xf32>
+  %1 = linalg.generic #trait
+      ins(%arg0, %arg1, %arg2 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
+      outs(%0 : tensor<128x32x32x1xf32>) {
+  ^bb0(%in: f32, %in_2: f32, %in_3: f32, %out: f32):
+    %3 = arith.subf %cst_0, %in_2 : f32
+    %4 = arith.mulf %in, %3 : f32
+    %5 = arith.mulf %4, %cst_1 : f32
+    %6 = arith.addf %5, %in_3 : f32
+    %7 = arith.subf %6, %cst_0 : f32
+    %8 = arith.cmpf uge, %7, %cst : f32
+    %9 = arith.uitofp %8 : i1 to f32
+    linalg.yield %9 : f32
+  } -> tensor<128x32x32x1xf32>
+  %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #sparse>
+  return %2 : tensor<128x32x32x1xf32, #sparse>
+}
diff --git a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
index bbc7f397e793fe1..f2f64567d5bd010 100644
--- a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
+++ b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
@@ -24,7 +24,6 @@ module {
   // CHECK: arith.constant
   // CHECK: tensor.empty()
   // CHECK: linalg.generic
-  // CHECK: sparse_tensor.convert
   // CHECK: return
   //
   func.func @avoid_fold(%0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
@@ -44,4 +43,3 @@ module {
     return %cast : tensor<10x20x30xf64, #sparse>
   }
 }
-

From c689fbbce4b6033f1ee6854fdfb3dd411303b16e Mon Sep 17 00:00:00 2001
From: Peiming Liu
Date: Wed, 24 Apr 2024 23:25:10 +0000
Subject: [PATCH 2/2] address comments

---
 .../Dialect/SparseTensor/IR/SparseTensor.h    |  2 +-
 .../Transforms/SparseTensorRewriting.cpp      |  2 +-
 .../fuse_sparse_convert_into_producer.mlir    | 40 ++++++++++++++---
 .../SparseTensor/no_fold_into_consumer.mlir   | 45 -------------------
 4 files changed, 37 insertions(+), 52 deletions(-)
 delete mode 100644 mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 550e28813b4e9b1..b182b4c72b9535c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -89,7 +89,7 @@ inline MemRefType getMemRefType(T &&t) {
 /// Returns null-attribute for any type without an encoding.
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 
-/// Returns true iff MLIR operand has any sparse operand.
+/// Returns true iff the type range has any sparse tensor type.
 inline bool hasAnySparseType(TypeRange types) {
   return llvm::any_of(types, [](Type type) {
     return getSparseTensorEncoding(type) != nullptr;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 641dcc61d7d09c7..9a8c6422a7ff623 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -289,7 +289,7 @@ struct FuseExtractSliceWithConcat
   }
 };
 
-/// Rewriting rule that converts direct yield of zero with initial allocation.
+/// Rewriting rule that fuses sparse_tensor.convert into producer.
 struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
index 077dde230fd1563..efa92e565ba5759 100644
--- a/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir
@@ -1,3 +1,4 @@
+// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map | FileCheck %s --check-prefix=CHECK-FOLD
 // RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification | FileCheck %s
 
 #trait = {
@@ -10,9 +11,12 @@
   iterator_types = ["parallel", "parallel", "parallel", "parallel"]
 }
 
-#sparse = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 
-// CHECK-LABEL: func.func @test(
+#COO = #sparse_tensor.encoding<{map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa))}>
+#CCCD = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+
+// CHECK-LABEL: func.func @fold_convert(
 // CHECK: scf.for
 // CHECK: scf.for
 // CHECK: scf.for
@@ -25,7 +29,10 @@
 // CHECK: scf.yield
 // CHECK: scf.yield
 // CHECK: sparse_tensor.load
-func.func @test(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #sparse> {
+
+// CHECK-FOLD-LABEL: func.func @fold_convert(
+// CHECK-FOLD-NOT: sparse_tensor.convert
+func.func @fold_convert(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #CCCD> {
   %cst = arith.constant 0.000000e+00 : f32
   %cst_0 = arith.constant 1.000000e+00 : f32
   %cst_1 = arith.constant 1.000000e+00 : f32
@@ -43,6 +50,29 @@ func.func @fold_convert(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>,
     %9 = arith.uitofp %8 : i1 to f32
     linalg.yield %9 : f32
   } -> tensor<128x32x32x1xf32>
-  %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #sparse>
-  return %2 : tensor<128x32x32x1xf32, #sparse>
+  %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #CCCD>
+  return %2 : tensor<128x32x32x1xf32, #CCCD>
 }
+
+
+// FIXME: The following kernel is not sparsifiable because the `arith.select`
+// operation is not handled by the sparse compiler at the moment.
+//
+// CHECK-FOLD-LABEL: func.func @fold_cast(
+// CHECK-FOLD-NOT: sparse_tensor.convert
+func.func @fold_cast(%0: tensor<10x20x30xf64, #COO>) -> tensor<10x20x30xf64, #COO> {
+  %cst = arith.constant 0.000000e+00 : f64
+  %1 = tensor.empty() : tensor<10x20x30xf64>
+  %2 = linalg.generic { indexing_maps = [#map, #map],
+                        iterator_types = ["parallel", "parallel", "parallel"]
+                      }
+  ins (%0 : tensor<10x20x30xf64, #COO>)
+  outs(%1 : tensor<10x20x30xf64>) {
+    ^bb0(%in: f64, %out: f64):
+      %4 = arith.cmpf ugt, %in, %cst : f64
+      %5 = arith.select %4, %in, %cst : f64
+      linalg.yield %5 : f64
+  } -> tensor<10x20x30xf64>
+  %cast = tensor.cast %2 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #COO>
+  return %cast : tensor<10x20x30xf64, #COO>
+}
diff --git a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
deleted file mode 100644
index f2f64567d5bd010..000000000000000
--- a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
+++ /dev/null
@@ -1,45 +0,0 @@
-// RUN: mlir-opt %s --canonicalize --pre-sparsification-rewrite | FileCheck %s
-
-#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-
-#sparse = #sparse_tensor.encoding<{
-  map = (d0, d1, d2) ->
-        (d0 : compressed(nonunique),
-         d1 : singleton(nonunique, soa),
-         d2 : singleton(soa)),
-  posWidth = 64,
-  crdWidth = 64
-}>
-
-
-module {
-  //
-  // This IR should not end up in an infinite loop trying to fold
-  // the linalg producer into the tensor cast consumer (even though
-  // static sizes can fold, the different encodings cannot). The
-  // cast was sloppy to begin with (but it has been observed by
-  // external sources) and can be easily repaired by the sparsifier.
-  //
-  // CHECK-LABEL: func @avoid_fold
-  // CHECK: arith.constant
-  // CHECK: tensor.empty()
-  // CHECK: linalg.generic
-  // CHECK: return
-  //
-  func.func @avoid_fold(%0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
-    %1 = tensor.empty() : tensor<10x20x30xf64>
-    %2 = linalg.generic { indexing_maps = [#map, #map],
-                          iterator_types = ["parallel", "parallel", "parallel"]
-                        }
-    ins (%0 : tensor<10x20x30xf64, #sparse>)
-    outs(%1 : tensor<10x20x30xf64>) {
-      ^bb0(%in: f64, %out: f64):
-        %cst = arith.constant 0.000000e+00 : f64
-        %4 = arith.cmpf ugt, %in, %cst : f64
-        %5 = arith.select %4, %in, %cst : f64
-        linalg.yield %5 : f64
-    } -> tensor<10x20x30xf64>
-    %cast = tensor.cast %2 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #sparse>
-    return %cast : tensor<10x20x30xf64, #sparse>
-  }
-}
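
For readers who want to exercise the new rewrite outside the mlir-opt pipelines used in the tests, the sketch below shows one way to run the pre-sparsification patterns, which now include FoldConvertIntoProducer, over a module. Only populatePreSparsificationRewriting is an API touched by this patch; the surrounding helper function and includes are illustrative boilerplate, not part of the series.

    // Minimal sketch (assumed boilerplate, not part of this patch): apply the
    // pre-sparsification rewrites with the standard greedy pattern driver.
    #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    static mlir::LogicalResult runPreSparsificationRewrites(mlir::ModuleOp module) {
      mlir::RewritePatternSet patterns(module.getContext());
      mlir::populatePreSparsificationRewriting(patterns);
      // FoldConvertIntoProducer fires here: a sparse_tensor.convert whose
      // source is a single-use, materializing linalg.generic is folded away
      // by retyping the generic's result (and its tensor.empty init) with
      // the sparse encoding of the convert's result.
      return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
    }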