Skip to content

Commit

Permalink
[mlir][MemRef] Add ExtractStridedMetadataOpCollapseShapeFolder (llvm#…
Browse files Browse the repository at this point in the history
…89954)

This PR adds a new pattern to the set of patterns used to resolve the offset, sizes and
stride of a memref. Similar to `ExtractStridedMetadataOpSubviewFolder`, the new
pattern resolves strided_metadata(collapse_shape) directly, without introduce a
reshape_cast op.
  • Loading branch information
dcaballe authored Apr 26, 2024
1 parent d74e42a commit 450ac01
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 60 deletions.
201 changes: 142 additions & 59 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,89 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
groupStrides)};
}

/// From `reshape_like(memref, subSizes, subStrides))` compute
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
/// extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// sizes = subSizes
/// \endverbatim
///
/// and return {baseBuffer, baseOffset, sizes, strides}
template <typename ReassociativeReshapeLikeOp>
static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
function_ref<SmallVector<OpFoldResult>(
ReassociativeReshapeLikeOp, OpBuilder &,
ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/)>
getReshapedSizes,
function_ref<SmallVector<OpFoldResult>(
ReassociativeReshapeLikeOp, OpBuilder &,
ArrayRef<OpFoldResult> /*origSizes*/,
ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
getReshapedStrides) {
// Build a plain extract_strided_metadata(memref) from
// extract_strided_metadata(reassociative_reshape_like(memref)).
Location origLoc = reshape.getLoc();
Value source = reshape.getSrc();
auto sourceType = cast<MemRefType>(source.getType());
unsigned sourceRank = sourceType.getRank();

auto newExtractStridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);

// Collect statically known information.
auto [strides, offset] = getStridesAndOffset(sourceType);
MemRefType reshapeType = reshape.getResultType();
unsigned reshapeRank = reshapeType.getRank();

OpFoldResult offsetOfr =
ShapedType::isDynamic(offset)
? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
: rewriter.getIndexAttr(offset);

// Get the special case of 0-D out of the way.
if (sourceRank == 0) {
SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
/*sizes=*/ones, /*strides=*/ones};
}

SmallVector<OpFoldResult> finalSizes;
finalSizes.reserve(reshapeRank);
SmallVector<OpFoldResult> finalStrides;
finalStrides.reserve(reshapeRank);

// Compute the reshaped strides and sizes from the base strides and sizes.
SmallVector<OpFoldResult> origSizes =
getAsOpFoldResult(newExtractStridedMetadata.getSizes());
SmallVector<OpFoldResult> origStrides =
getAsOpFoldResult(newExtractStridedMetadata.getStrides());
unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
for (; idx != endIdx; ++idx) {
SmallVector<OpFoldResult> reshapedSizes =
getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);

unsigned groupSize = reshapedSizes.size();
for (unsigned i = 0; i < groupSize; ++i) {
finalSizes.push_back(reshapedSizes[i]);
finalStrides.push_back(reshapedStrides[i]);
}
}
assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
(isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
"We should have visited all the input dimensions");
assert(finalSizes.size() == reshapeRank &&
"We should have populated all the values");

return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
finalSizes, finalStrides};
}

/// Replace `baseBuffer, offset, sizes, strides =
/// extract_strided_metadata(reshapeLike(memref))`
/// With
Expand Down Expand Up @@ -580,68 +663,65 @@ struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {

LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
PatternRewriter &rewriter) const override {
// Build a plain extract_strided_metadata(memref) from
// extract_strided_metadata(reassociative_reshape_like(memref)).
Location origLoc = reshape.getLoc();
Value source = reshape.getSrc();
auto sourceType = cast<MemRefType>(source.getType());
unsigned sourceRank = sourceType.getRank();

auto newExtractStridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);

// Collect statically known information.
auto [strides, offset] = getStridesAndOffset(sourceType);
MemRefType reshapeType = reshape.getResultType();
unsigned reshapeRank = reshapeType.getRank();

OpFoldResult offsetOfr =
ShapedType::isDynamic(offset)
? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
: rewriter.getIndexAttr(offset);

// Get the special case of 0-D out of the way.
if (sourceRank == 0) {
SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
offsetOfr, /*sizes=*/ones, /*strides=*/ones);
rewriter.replaceOp(reshape, memrefDesc.getResult());
return success();
FailureOr<StridedMetadata> stridedMetadata =
resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>(
rewriter, reshape, getReshapedSizes, getReshapedStrides);
if (failed(stridedMetadata)) {
return rewriter.notifyMatchFailure(reshape,
"failed to resolve reshape metadata");
}

SmallVector<OpFoldResult> finalSizes;
finalSizes.reserve(reshapeRank);
SmallVector<OpFoldResult> finalStrides;
finalStrides.reserve(reshapeRank);

// Compute the reshaped strides and sizes from the base strides and sizes.
SmallVector<OpFoldResult> origSizes =
getAsOpFoldResult(newExtractStridedMetadata.getSizes());
SmallVector<OpFoldResult> origStrides =
getAsOpFoldResult(newExtractStridedMetadata.getStrides());
unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
for (; idx != endIdx; ++idx) {
SmallVector<OpFoldResult> reshapedSizes =
getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);

unsigned groupSize = reshapedSizes.size();
for (unsigned i = 0; i < groupSize; ++i) {
finalSizes.push_back(reshapedSizes[i]);
finalStrides.push_back(reshapedStrides[i]);
}
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
reshape, reshape.getType(), stridedMetadata->basePtr,
stridedMetadata->offset, stridedMetadata->sizes,
stridedMetadata->strides);
return success();
}
};

/// Pattern to replace `extract_strided_metadata(collapse_shape)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
/// extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subSizes#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \verbatim
///
/// with `baseBuffer`, `offset`, `sizes` and `strides` being
/// the replacements for the original `extract_strided_metadata`.
struct ExtractStridedMetadataOpCollapseShapeFolder
: OpRewritePattern<memref::ExtractStridedMetadataOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
PatternRewriter &rewriter) const override {
auto collapseShapeOp =
op.getSource().getDefiningOp<memref::CollapseShapeOp>();
if (!collapseShapeOp)
return failure();

FailureOr<StridedMetadata> stridedMetadata =
resolveReshapeStridedMetadata<memref::CollapseShapeOp>(
rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride);
if (failed(stridedMetadata)) {
return rewriter.notifyMatchFailure(
op,
"failed to resolve metadata in terms of source collapse_shape op");
}
assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
(isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
"We should have visited all the input dimensions");
assert(finalSizes.size() == reshapeRank &&
"We should have populated all the values");
auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
offsetOfr, finalSizes, finalStrides);
rewriter.replaceOp(reshape, memrefDesc.getResult());

Location loc = collapseShapeOp.getLoc();
SmallVector<Value> results;
results.push_back(stridedMetadata->basePtr);
results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
stridedMetadata->offset));
results.append(
getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
stridedMetadata->strides));
rewriter.replaceOp(op, results);
return success();
}
};
Expand Down Expand Up @@ -1018,9 +1098,11 @@ void memref::populateExpandStridedMetadataPatterns(
getCollapsedStride>,
ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
ExtractStridedMetadataOpCollapseShapeFolder,
ExtractStridedMetadataOpGetGlobalFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpSubviewFolder,
ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
Expand All @@ -1030,6 +1112,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
RewritePatternSet &patterns) {
patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
ExtractStridedMetadataOpCollapseShapeFolder,
ExtractStridedMetadataOpGetGlobalFolder,
ExtractStridedMetadataOpSubviewFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
Expand Down
24 changes: 23 additions & 1 deletion mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1513,4 +1513,26 @@ func.func @zero_sized_memred(%arg0: f32) -> (memref<f16, 3>, index,index,index)
%sizes, %strides :
memref<f16,3>, index,
index, index
}
}

// -----

func.func @extract_strided_metadata_of_collapse_shape(%base: memref<5x4xf32>)
-> (memref<f32>, index, index, index) {

%collapse = memref.collapse_shape %base[[0, 1]] :
memref<5x4xf32> into memref<20xf32>

%base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %collapse :
memref<20xf32> -> memref<f32>, index, index, index

return %base_buffer, %offset, %size, %stride :
memref<f32>, index, index, index
}

// CHECK-LABEL: func @extract_strided_metadata_of_collapse_shape
// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[SIZE:.*]] = arith.constant 20 : index
// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata
// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref<f32>, index, index, index

0 comments on commit 450ac01

Please sign in to comment.