[mlir][sparse] fold sparse convert into producer linalg op. #89999

Merged on Apr 26, 2024 (2 commits).
15 changes: 9 additions & 6 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -89,18 +89,21 @@ inline MemRefType getMemRefType(T &&t) {
/// Returns null-attribute for any type without an encoding.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type);

+/// Returns true iff the type range has any sparse tensor type.
+inline bool hasAnySparseType(TypeRange types) {
+  return llvm::any_of(types, [](Type type) {
+    return getSparseTensorEncoding(type) != nullptr;
+  });
+}
+
/// Returns true iff the MLIR operation has any sparse operand.
inline bool hasAnySparseOperand(Operation *op) {
-  return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
-    return getSparseTensorEncoding(t) != nullptr;
-  });
+  return hasAnySparseType(op->getOperands().getTypes());
}

/// Returns true iff the MLIR operation has any sparse result.
inline bool hasAnySparseResult(Operation *op) {
-  return llvm::any_of(op->getResults().getTypes(), [](Type t) {
-    return getSparseTensorEncoding(t) != nullptr;
-  });
+  return hasAnySparseType(op->getResults().getTypes());
}

/// Returns true iff the MLIR operation has any sparse operand or result.
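For orientation, here is a minimal, hypothetical call site showing how the refactored helpers compose. Only `hasAnySparseOperand`, `hasAnySparseResult`, and the new `hasAnySparseType` come from the header above; the pattern itself is illustrative:

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Hypothetical pattern: the early-exit guard is the point of the sketch;
// the actual rewrite body is elided.
struct SparseAwarePattern : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // All-dense kernels are left to the dense pipeline.
    if (!hasAnySparseOperand(op) && !hasAnySparseResult(op))
      return failure();
    // hasAnySparseType also works on any TypeRange directly, e.g. to
    // distinguish all-dense inputs feeding a sparse output.
    bool allDenseInputs = !hasAnySparseType(op.getInputs().getTypes());
    (void)allDenseInputs;
    // ... sparse-specific rewriting would go here ...
    return failure();
  }
};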
@@ -289,6 +289,37 @@ struct FuseExtractSliceWithConcat
  }
};

+/// Rewriting rule that fuses a sparse_tensor.convert into its producer
+/// linalg.generic by retyping the producer's result in place.
+struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConvertOp op,
+                                PatternRewriter &rewriter) const override {
+    // Match only a convert whose source is the single use of a generic that
+    // writes into one freshly materialized (payload-free) init tensor.
+    auto producer = op.getSource().getDefiningOp<GenericOp>();
+    if (!producer || producer.getDpsInits().size() != 1 ||
+        !isMaterializing(producer.getDpsInitOperand(0), false) ||
+        !producer.getResult(0).hasOneUse()) {
+      return failure();
+    }
+    // Retype the producer's result to the convert's (sparse) result type.
+    rewriter.modifyOpInPlace(producer, [&]() {
+      producer.getResult(0).setType(op.getResult().getType());
+    });
+
+    // Retype the materializing op (e.g., tensor.empty) to match.
+    Operation *materializeOp =
+        producer.getDpsInitOperand(0)->get().getDefiningOp();
+
+    rewriter.modifyOpInPlace(materializeOp, [&]() {
+      materializeOp->getResult(0).setType(op.getResult().getType());
+    });
+
+    // The convert is now a no-op cast; forward its uses and drop it.
+    rewriter.replaceAllOpUsesWith(op, producer);
+    op->erase();
+
+    return success();
+  }
+};
+
/// Rewriting rule that folds a direct yield of zero into the initial
/// allocation.
struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
public:
@@ -1506,9 +1537,10 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
//===---------------------------------------------------------------------===//

void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
-  patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
-               FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
-               GenSemiRingSelect, PrintRewriter>(patterns.getContext());
+  patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
+               FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
+               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
+      patterns.getContext());
}

void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
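A note on the `isMaterializing(..., /*isZero=*/false)` precondition in `FoldConvertIntoProducer` above: retyping the producer's result and its init in place is only sound when the init tensor is a fresh, payload-free allocation. A hedged sketch of what such a check boils down to follows; `isFreshAllocation` is a hypothetical name, and the in-tree `isMaterializing` helper is more general:

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Sketch: an init operand "materializes" when it is defined by a fresh
// allocation, so changing its result type cannot alter observable values.
static bool isFreshAllocation(OpOperand *initOperand) {
  Operation *def = initOperand->get().getDefiningOp();
  return isa_and_nonnull<tensor::EmptyOp, bufferization::AllocTensorOp>(def);
}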
44 changes: 29 additions & 15 deletions mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -403,6 +403,22 @@ static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
  return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
}

+/// Generates a guarded insertion: if `cond` holds, insert `v` into
+/// `sparseOut` at `ivs` and yield the updated tensor; otherwise yield the
+/// tensor unchanged.
+static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
+                                  Value sparseOut, ValueRange ivs, Value v) {
+  scf::IfOp condInsert =
+      builder.create<scf::IfOp>(loc, sparseOut.getType(), cond, true);
+  // True branch.
+  builder.setInsertionPointToStart(condInsert.thenBlock());
+  Value res = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
+  builder.create<scf::YieldOp>(loc, res);
+  // False branch.
+  builder.setInsertionPointToStart(condInsert.elseBlock());
+  builder.create<scf::YieldOp>(loc, sparseOut);
+  // Value assignment.
+  builder.setInsertionPointAfter(condInsert);
+  return condInsert.getResult(0);
+}
+
/// Generates insertion code to implement dynamic tensor store.
static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                              Value rhs) {
@@ -423,23 +439,21 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
      //     return updated chain
      //   else
      //     return unmodified chain
-      scf::IfOp ifValidLexInsert = builder.create<scf::IfOp>(
-          loc, chain.getType(), env.getValidLexInsert(),
-          /*else=*/true);
-      // True branch.
-      builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
-      Value res = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
-      builder.create<scf::YieldOp>(loc, res);
-      // False branch.
-      builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
-      builder.create<scf::YieldOp>(loc, chain);
-      // Value assignment.
-      builder.setInsertionPointAfter(ifValidLexInsert);
-      env.updateInsertionChain(ifValidLexInsert.getResult(0));
+      Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
+                                       chain, ivs, rhs);
+      env.updateInsertionChain(out);
    } else {
+      Value sparseOut;
+      if (!hasAnySparseType(env.op().getInputs().getTypes())) {
+        // This is an all-dense -> sparse kernel; test rhs != 0 before
+        // insertion so that zeros are never stored explicitly.
+        Value nz = genIsNonzero(builder, loc, rhs);
+        sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
+      } else {
+        sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
+      }
      // Generates regular insertion chain.
-      env.updateInsertionChain(
-          builder.create<tensor::InsertOp>(loc, rhs, chain, ivs));
+      env.updateInsertionChain(sparseOut);
    }
    return;
  }
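The new test below exercises the whole flow: the CHECK block pins down the scf.if / tensor.insert / scf.yield nest that genConditionalInsert emits, while CHECK-FOLD only verifies that the convert disappeared. Its RUN lines can also be reproduced programmatically; a minimal sketch, assuming an already-parsed `module` (the wrapper function is illustrative, the pass-creation call is the in-tree one behind `--pre-sparsification-rewrite`):

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/Support/raw_ostream.h"

// Runs the first leg of the test's pipeline; --sparse-reinterpret-map and
// --sparsification would be appended the same way for the full CHECK run.
static mlir::LogicalResult runPreSparsification(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createPreSparsificationRewritePass());
  if (mlir::failed(pm.run(module))) {
    llvm::errs() << "pre-sparsification rewriting failed\n";
    return mlir::failure();
  }
  return mlir::success();
}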
@@ -0,0 +1,78 @@
// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map | FileCheck %s --check-prefix=CHECK-FOLD
// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification | FileCheck %s

#trait = {
  indexing_maps = [
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
  ],
  iterator_types = ["parallel", "parallel", "parallel", "parallel"]
}

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

#COO = #sparse_tensor.encoding<{map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa))}>
#CCCD = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>

// CHECK-LABEL: func.func @fold_convert(
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.if
// CHECK-NEXT: tensor.insert
// CHECK-NEXT: scf.yield
// CHECK-NEXT: else
// CHECK-NEXT: scf.yield
// CHECK: scf.yield
// CHECK: scf.yield
// CHECK: scf.yield
// CHECK: sparse_tensor.load

// CHECK-FOLD-LABEL: func.func @fold_convert(
// CHECK-FOLD-NOT: sparse_tensor.convert
func.func @fold_convert(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #CCCD> {
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant 1.000000e+00 : f32
  %cst_1 = arith.constant 1.000000e+00 : f32
  %0 = tensor.empty() : tensor<128x32x32x1xf32>
  %1 = linalg.generic #trait
       ins(%arg0, %arg1, %arg2 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
       outs(%0 : tensor<128x32x32x1xf32>) {
  ^bb0(%in: f32, %in_2: f32, %in_3: f32, %out: f32):
    %3 = arith.subf %cst_0, %in_2 : f32
    %4 = arith.mulf %in, %3 : f32
    %5 = arith.mulf %4, %cst_1 : f32
    %6 = arith.addf %5, %in_3 : f32
    %7 = arith.subf %6, %cst_0 : f32
    %8 = arith.cmpf uge, %7, %cst : f32
    %9 = arith.uitofp %8 : i1 to f32
    linalg.yield %9 : f32
  } -> tensor<128x32x32x1xf32>
  %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #CCCD>
  return %2 : tensor<128x32x32x1xf32, #CCCD>
}


// FIXME: The following kernel is not sparsifiable because the `arith.select`
// operation is not handled by the sparse compiler at the moment.
//
// CHECK-FOLD-LABEL: func.func @fold_cast(
// CHECK-FOLD-NOT: sparse_tensor.convert
func.func @fold_cast(%0: tensor<10x20x30xf64, #COO>) -> tensor<10x20x30xf64, #COO> {
  %cst = arith.constant 0.000000e+00 : f64
  %1 = tensor.empty() : tensor<10x20x30xf64>
  %2 = linalg.generic { indexing_maps = [#map, #map],
                        iterator_types = ["parallel", "parallel", "parallel"]
       }
       ins (%0 : tensor<10x20x30xf64, #COO>)
       outs(%1 : tensor<10x20x30xf64>) {
  ^bb0(%in: f64, %out: f64):
    %4 = arith.cmpf ugt, %in, %cst : f64
    %5 = arith.select %4, %in, %cst : f64
    linalg.yield %5 : f64
  } -> tensor<10x20x30xf64>
  %cast = tensor.cast %2 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #COO>
  return %cast : tensor<10x20x30xf64, #COO>
}
47 changes: 0 additions & 47 deletions mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir

This file was deleted.
