[mlir][sparse] implement loose-compressed/2:4 on direct IR codegen path (llvm#71461)

Fills in the missing cases for direct IR codegen.
Note that non-permutation handling is still TBD.
aartbik committed Nov 7, 2023
1 parent 16a395b commit 160d483
Showing 4 changed files with 144 additions and 84 deletions.
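
For orientation, the level types this patch adds to the direct codegen path differ from a plain compressed level mainly in what they store per level. A minimal C++ sketch of the buffers involved, with invented contents and made-up names (illustration only, not code from this patch):

#include <cstdint>
#include <vector>

int main() {
  // compressed: one positions entry per parent plus one; the children of
  // parent p live in coordinates[positions[p] .. positions[p+1]).
  std::vector<std::int64_t> comprPos = {0, 2, 3};
  std::vector<std::int64_t> comprCrd = {1, 4, 0};

  // loose_compressed: an explicit (lo, hi) pair per parent, so the positions
  // array is roughly twice as long and segments may leave slack between them.
  std::vector<std::int64_t> loosePos = {0, 2, 4, 5};
  std::vector<std::int64_t> looseCrd = {1, 4, 0, 0, 0}; // slots 2-3 unused

  // 2:4 (two out of four): exactly two nonzeros per block of four, so no
  // positions array at all -- only in-block coordinates, much like singleton.
  std::vector<std::int64_t> blockCrd = {1, 3, 0, 2};
  return 0;
}

This is why, in the changes below, loose-compressed levels double the position count or index wherever a compressed level uses it directly, and 2:4 levels follow the same code path as singleton levels.
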
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -633,8 +633,8 @@ void sparse_tensor::fillDimShape(OpBuilder &builder, Location loc,
SmallVectorImpl<Value> &out) {
out.clear();
out.reserve(stt.getDimRank());
for (const Size sh : stt.getDimShape()) {
const auto s = ShapedType::isDynamic(sh) ? 0 : sh;
for (const Size sz : stt.getDimShape()) {
const auto s = ShapedType::isDynamic(sz) ? 0 : sz;
out.push_back(constantIndex(builder, loc, s));
}
}
156 changes: 78 additions & 78 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -18,8 +18,6 @@
#include "CodegenUtils.h"
#include "SparseTensorDescriptor.h"

#include "llvm/Support/FormatVariadic.h"

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -116,31 +114,36 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
const SparseTensorType stt(desc.getRankedTensorType());
Value linear = constantIndex(builder, loc, 1);
const Level lvlRank = stt.getLvlRank();
for (Level l = startLvl; l < lvlRank; l++) {
const auto dlt = stt.getLvlType(l);
if (isCompressedDLT(dlt)) {
for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
const auto dlt = stt.getLvlType(lvl);
if (isCompressedDLT(dlt) || isLooseCompressedDLT(dlt)) {
// Append linear x positions, initialized to zero. Since each compressed
// dimension initially already has a single zero entry, this maintains
// the desired "linear + 1" length property at all times.
// the desired "linear + 1" length property at all times. For loose
// compression, we multiply linear by two in order to append both the
// lo/hi positions.
Value posZero = constantZero(builder, loc, stt.getPosType());
createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, l,
posZero, linear);
if (isLooseCompressedDLT(dlt)) {
Value two = constantIndex(builder, loc, 2);
linear = builder.create<arith::MulIOp>(loc, linear, two);
}
createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
/*value=*/posZero, /*repeat=*/linear);
return;
}
if (isSingletonDLT(dlt)) {
} else if (isSingletonDLT(dlt) || is2OutOf4DLT(dlt)) {
return; // nothing to do
}
// Keep compounding the size, but nothing needs to be initialized
// at this level. We will eventually reach a compressed level or
// otherwise the values array for the from-here "all-dense" case.
assert(isDenseDLT(dlt));
Value size = desc.getLvlSize(builder, loc, l);
Value size = desc.getLvlSize(builder, loc, lvl);
linear = builder.create<arith::MulIOp>(loc, linear, size);
}
// Reached values array so prepare for an insertion.
Value valZero = constantZero(builder, loc, stt.getElementType());
createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
std::nullopt, valZero, linear);
std::nullopt, /*value=*/valZero, /*repeat=*/linear);
}
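
// Rough illustration of the bookkeeping above, with invented numbers (not
// code from this patch). Suppose linear = 3 parent slots have been
// linearized so far when a sparse level is reached:
//   compressed:        the positions array already holds one leading zero;
//                      pushing back 3 more zeros keeps its length at
//                      linear + 1 = 4.
//   loose_compressed:  linear is doubled to 6 first, because every parent
//                      owns a (lo, hi) pair in the positions array.
//   singleton / 2:4:   nothing to prepare here; only coordinates and values
//                      grow on insertion.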

/// Creates allocation operation.
@@ -157,12 +160,9 @@ static Value createAllocation(OpBuilder &builder, Location loc,
}

/// Creates allocation for each field in sparse tensor type. Note that
/// for all dynamic memrefs, the memory size is really the capacity of
/// the "vector", while the actual size resides in the sizes array.
///
/// TODO: for efficiency, we will need heuristics to make educated guesses
/// on the required capacities (see heuristic variable).
///
for all dynamic memrefs in the sparse tensor storage layout, the
/// memory size is really the capacity of the "vector", while the actual
/// size resides in the sizes array.
static void createAllocFields(OpBuilder &builder, Location loc,
SparseTensorType stt, ValueRange dynSizes,
bool enableInit, SmallVectorImpl<Value> &fields,
@@ -206,6 +206,8 @@ static void createAllocFields(OpBuilder &builder, Location loc,
constantIndex(builder, loc, 16);
}

// Initializes all fields. An initial storage specifier and allocated
// positions/coordinates/values memrefs (with heuristic capacity).
foreachFieldAndTypeInSparseTensor(
stt,
[&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
@@ -218,14 +220,16 @@ static void createAllocFields(OpBuilder &builder, Location loc,
field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
break;
case SparseTensorFieldKind::PosMemRef:
field = createAllocation(builder, loc, cast<MemRefType>(fType),
posHeuristic, enableInit);
break;
case SparseTensorFieldKind::CrdMemRef:
field = createAllocation(builder, loc, cast<MemRefType>(fType),
crdHeuristic, enableInit);
break;
case SparseTensorFieldKind::ValMemRef:
field = createAllocation(
builder, loc, cast<MemRefType>(fType),
(fKind == SparseTensorFieldKind::PosMemRef) ? posHeuristic
: (fKind == SparseTensorFieldKind::CrdMemRef) ? crdHeuristic
: valHeuristic,
enableInit);
field = createAllocation(builder, loc, cast<MemRefType>(fType),
valHeuristic, enableInit);
break;
}
assert(field);
@@ -234,21 +238,19 @@ static void createAllocFields(OpBuilder &builder, Location loc,
return true;
});

// Initialize the storage scheme to an empty tensor. Sets the lvlSizes
// and gives all position fields an initial zero entry, so that it is
// easier to maintain the "linear + 1" length property.
MutSparseTensorDescriptor desc(stt, fields);

// Initialize the storage scheme to an empty tensor. Initialized memSizes
// to all zeros, sets the dimSizes to known values and gives all position
// fields an initial zero entry, so that it is easier to maintain the
// "linear + 1" length property.
Value posZero = constantZero(builder, loc, stt.getPosType());
for (Level lvlRank = stt.getLvlRank(), l = 0; l < lvlRank; l++) {
// Fills dim sizes array.
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
// FIXME: `toOrigDim` is deprecated.
desc.setLvlSize(builder, loc, l, dimSizes[toOrigDim(stt.getEncoding(), l)]);
// Pushes a leading zero to positions memref.
if (stt.isCompressedLvl(l))
createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, l,
posZero);
desc.setLvlSize(builder, loc, lvl,
dimSizes[toOrigDim(stt.getEncoding(), lvl)]);
const auto dlt = stt.getLvlType(lvl);
if (isCompressedDLT(dlt) || isLooseCompressedDLT(dlt))
createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
/*value=*/posZero);
}
allocSchemeForRank(builder, loc, desc, /*rank=*/0);
}
@@ -347,7 +349,7 @@ static Value genCompressed(OpBuilder &builder, Location loc,
Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
genStore(builder, loc, mszp1, positionsAtLvl, pp1);
createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
lvlCoords[lvl]);
/*value=*/lvlCoords[lvl]);
// Prepare the next level "as needed".
if ((lvl + 1) < lvlRank)
allocSchemeForRank(builder, loc, desc, lvl + 1);
@@ -371,8 +373,6 @@ static void genEndInsert(OpBuilder &builder, Location loc,
const Level lvlRank = stt.getLvlRank();
for (Level l = 0; l < lvlRank; l++) {
const auto dlt = stt.getLvlType(l);
if (isLooseCompressedDLT(dlt))
llvm_unreachable("TODO: Not yet implemented");
if (isCompressedDLT(dlt)) {
// Compressed dimensions need a position cleanup for all entries
// that were not visited during the insertion pass.
@@ -407,7 +407,8 @@ static void genEndInsert(OpBuilder &builder, Location loc,
builder.setInsertionPointAfter(loop);
}
} else {
assert(isDenseDLT(dlt) || isSingletonDLT(dlt));
assert(isDenseDLT(dlt) || isLooseCompressedDLT(dlt) ||
isSingletonDLT(dlt) || is2OutOf4DLT(dlt));
}
}
}
@@ -483,33 +484,37 @@ class SparseInsertGenerator
Value value = args.back();
Value parentPos = constantZero(builder, loc, builder.getIndexType());
// Generate code for every level.
for (Level l = 0; l < lvlRank; l++) {
const auto dlt = stt.getLvlType(l);
if (isCompressedDLT(dlt)) {
for (Level lvl = 0; lvl < lvlRank; lvl++) {
const auto dlt = stt.getLvlType(lvl);
if (isCompressedDLT(dlt) || isLooseCompressedDLT(dlt)) {
// Create:
// if (!present) {
// coordinates[l].push_back(coords[l])
// <update positions and prepare level l + 1>
// coordinates[lvl].push_back(coords[lvl])
// <update positions and prepare level lvl + 1>
// }
// positions[l] = coordinates.size() - 1
// <insert @ positions[l] at next level l + 1>
// positions[lvl] = coordinates.size() - 1
// <insert @ positions[lvl] at next level lvl + 1>
if (isLooseCompressedDLT(dlt)) {
Value two = constantIndex(builder, loc, 2);
parentPos = builder.create<arith::MulIOp>(loc, parentPos, two);
}
parentPos =
genCompressed(builder, loc, desc, coords, value, parentPos, l);
} else if (isSingletonDLT(dlt)) {
genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
} else if (isSingletonDLT(dlt) || is2OutOf4DLT(dlt)) {
// Create:
// coordinates[l].push_back(coords[l])
// positions[l] = positions[l-1]
// <insert @ positions[l] at next level l + 1>
createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, l,
coords[l]);
// coordinates[lvl].push_back(coords[lvl])
// positions[lvl] = positions[lvl-1]
// <insert @ positions[lvl] at next level lvl + 1>
createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
lvl, /*value=*/coords[lvl]);
} else {
assert(isDenseDLT(dlt));
// Construct the new position as:
// positions[l] = size * positions[l-1] + coords[l]
// <insert @ positions[l] at next level l + 1>
Value size = desc.getLvlSize(builder, loc, l);
// positions[lvl] = size * positions[lvl-1] + coords[lvl]
// <insert @ positions[lvl] at next level lvl + 1>
Value size = desc.getLvlSize(builder, loc, lvl);
Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
parentPos = builder.create<arith::AddIOp>(loc, mult, coords[l]);
parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]);
}
}
// Reached the actual value append/insert.
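// Worked example with an invented 4 x 8 CSR-like tensor (dense, compressed),
// not taken from this patch: inserting the element at (i, j) = (2, 5).
//   lvl 0 (dense):       parentPos = 4 * 0 + 2 = 2
//   lvl 1 (compressed):  genCompressed reads pos[2] and pos[3] to find the
//                        segment owned by row 2 and appends coordinate 5 to
//                        the coordinates array if it is not already present.
//   lvl 1 (loose form):  parentPos is doubled first (2 -> 4) so that it
//                        indexes the row's (lo, hi) pair.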
@@ -526,7 +531,6 @@ class SparseInsertGenerator
// <namePrefix>_<DLT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
constexpr const char kInsertFuncNamePrefix[] = "_insert_";
const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));

SmallString<32> nameBuffer;
llvm::raw_svector_ostream nameOstream(nameBuffer);
nameOstream << kInsertFuncNamePrefix;
@@ -543,8 +547,8 @@ class SparseInsertGenerator
// Static dim sizes are used in the generated code while dynamic sizes are
// loaded from the dimSizes buffer. This is the reason for adding the shape
// to the function name.
for (const auto sh : stt.getDimShape())
nameOstream << sh << "_";
for (const auto sz : stt.getDimShape())
nameOstream << sz << "_";
// Permutation information is also used in generating insertion.
if (!stt.isIdentity())
nameOstream << stt.getDimToLvl() << "_";
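// For illustration only (assumed, not generated by this patch): with the
// format above, an identity-ordered 4 x 8 dense-by-compressed f64 tensor
// with default position/coordinate widths would get a name along the lines
// of _insert_dense_compressed_4_8_f64_0_0.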
@@ -607,7 +611,6 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
assert(retOffset < newCall.getNumResults());
auto retType = ret.getType();
if (failed(typeConverter->convertType(retType, sparseFlat)))
// This should never happen.
llvm_unreachable("Failed to convert type in sparse tensor codegen");

// Converted types cannot be empty when the type conversion succeeds.
@@ -755,9 +758,7 @@ class SparseTensorAllocConverter
const auto resType = getSparseTensorType(op);
if (!resType.hasEncoding())
return failure();

// Construct allocation for each field.
const Location loc = op.getLoc();
Location loc = op.getLoc();
if (op.getCopy()) {
auto desc = getDescriptorFromTensorTuple(adaptor.getCopy());
SmallVector<Value> fields;
@@ -778,18 +779,18 @@ class SparseTensorAllocConverter
return success();
}

const Value sizeHint = op.getSizeHint();
const ValueRange dynSizes = adaptor.getDynamicSizes();
// Construct allocation for each field.
Value sizeHint = op.getSizeHint();
ValueRange dynSizes = adaptor.getDynamicSizes();
const size_t found = dynSizes.size();
const int64_t expected = resType.getNumDynamicDims();
if (found != static_cast<size_t>(expected))
return rewriter.notifyMatchFailure(
op, llvm::formatv(
"Got wrong number of dynamic sizes: Found={0}, Expected={1}",
found, expected));
return rewriter.notifyMatchFailure(op,
"Got wrong number of dynamic sizes");
SmallVector<Value> fields;
createAllocFields(rewriter, loc, resType, dynSizes,
enableBufferInitialization, fields, sizeHint);

// Replace operation with resulting memrefs.
rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
return success();
@@ -817,19 +818,18 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
return failure();

// Construct allocation for each field.
const Location loc = op.getLoc();
const Value sizeHint; // none
Location loc = op.getLoc();
Value sizeHint; // none
const ValueRange dynSizes = adaptor.getDynamicSizes();
const size_t found = dynSizes.size();
const int64_t expected = resType.getNumDynamicDims();
if (found != static_cast<size_t>(expected))
return rewriter.notifyMatchFailure(
op, llvm::formatv(
"Got wrong number of dynamic sizes: Found={0}, Expected={1}",
found, expected));
return rewriter.notifyMatchFailure(op,
"Got wrong number of dynamic sizes");
SmallVector<Value> fields;
createAllocFields(rewriter, loc, resType, dynSizes,
enableBufferInitialization, fields, sizeHint);

// Replace operation with resulting memrefs.
rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
return success();
@@ -1496,7 +1496,6 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
SmallVector<Value> fields;
createAllocFields(rewriter, loc, dstTp, dynSizes, /*enableInit=*/false,
fields, nse);
MutSparseTensorDescriptor desc(dstTp, fields);

// Now construct the dim2lvl and lvl2dim buffers.
Value dim2lvlBuffer;
Expand All @@ -1505,6 +1504,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
dim2lvlBuffer, lvl2dimBuffer);

// Read the COO tensor data.
MutSparseTensorDescriptor desc(dstTp, fields);
Value xs = desc.getAOSMemRef();
Value ys = desc.getValMemRef();
const Type boolTp = rewriter.getIntegerType(1);
@@ -380,14 +380,13 @@ class SparseTensorAllocConverter
LogicalResult
matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op.getCopy())
return rewriter.notifyMatchFailure(op,
"sparse tensor copy not implemented");
Location loc = op.getLoc();
const auto stt = getSparseTensorType(op);
if (!stt.hasEncoding())
return failure();
if (op.getCopy())
return rewriter.notifyMatchFailure(op, "alloc copy not implemented");
// Gather all dimension sizes as SSA values.
Location loc = op.getLoc();
const Dimension dimRank = stt.getDimRank();
SmallVector<Value> dimSizes;
dimSizes.reserve(dimRank);