Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ArmNeon] Updates LowerContractionToSMMLAPattern with vecmat unroll patterns #86005

Merged
merged 2 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,41 +40,45 @@ static Type matchContainerType(Type element, Type container) {

/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
/// any vector.contract into multiple smmla instructions with unrolling so long
/// as [2,2,8] is a divisor of its shape. If no unrolling is necessary, a single
/// smmla instruction is emitted.
/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
/// = 1 (either explicitly or inferred if LHS has only dimK). If no unrolling
/// is necessary, a single smmla instruction is emitted.
class LowerContractionToSMMLAPattern
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
// Check index maps that represent M N K in contract.
auto indexingMaps = op.getIndexingMapsArray();
if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
affineMap.getNumResults() != 2;
})) {
return failure();
}
// Check iterator types for contract.
auto iteratorTypes = op.getIteratorTypesArray();
if (iteratorTypes.size() != 3 ||
iteratorTypes[0] != vector::IteratorType::parallel ||
iteratorTypes[1] != vector::IteratorType::parallel ||
iteratorTypes[2] != vector::IteratorType::reduction) {
return failure();
}
// Infer tile sizes from operands; Note: RHS is not transposed.
// Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
// Note: RHS is not transposed.
mlir::VectorType lhsType = op.getLhsType();
mlir::VectorType rhsType = op.getRhsType();
auto dimM = lhsType.getDimSize(0);
auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
auto dimK = lhsType.getDimSize(1);

auto dimK = rhsType.getDimSize(1);
bool isVecmat = dimM == 1 ? true : false;
if (lhsType.getDimSize(lhsType.getRank() - 1) !=
rhsType.getDimSize(rhsType.getRank() - 1)) {
return failure(); // dimK mismatch
}
// Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
// tiling.
if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
return failure();
}

// Check iterator types for contract. All iterators except inner-most
// dimension must be parallel.
auto iteratorTypes = op.getIteratorTypesArray();
if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
vector::IteratorType::reduction) {
KoolJBlack marked this conversation as resolved.
Show resolved Hide resolved
return failure();
}
if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
[](vector::IteratorType iteratorType) {
return iteratorType != vector::IteratorType::parallel;
})) {
return failure();
}

Expand Down Expand Up @@ -120,11 +124,14 @@ class LowerContractionToSMMLAPattern
loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));

SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
SmallVector<int64_t> smmlaShape{2, 2, 8};
SmallVector<int64_t> loopOrder{0, 1, 2};
SmallVector<int64_t> smmlaShape{2, 8};
SmallVector<int64_t> loopOrder{0, 1};
if (unrolledSize.size() == 3) {
smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
loopOrder.push_back(2);
}
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {

// Helper to compute the new shape of each operand and extract the slice.
auto extractOperand = [&](Value operand, AffineMap permutationMap,
ArrayRef<int64_t> operandOffsets) {
Expand All @@ -150,16 +157,40 @@ class LowerContractionToSMMLAPattern
Value tiledAcc =
extractOperand(op.getAcc(), accPermutationMap, accOffsets);

auto inputElementType =
tiledLhs.getType().cast<ShapedType>().getElementType();
auto accElementType =
tiledAcc.getType().cast<ShapedType>().getElementType();
auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
auto outputExpandedType = VectorType::get({2, 2}, accElementType);

// With vecmat, tiled LHS and ACC will contain only one of 2 necessary
// rows along dimM. Expand their shapes to match the smmla op.
if (isVecmat) {
auto expandForSMMLA = [&](Value tiledOperand,
VectorType expandedTypeType) {
auto emptyOperand = rewriter.create<arith::ConstantOp>(
loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
SmallVector<int64_t> offsets(
emptyOperand.getType().cast<ShapedType>().getRank(), 0);
SmallVector<int64_t> strides(
tiledOperand.getType().cast<ShapedType>().getRank(), 1);
return rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, tiledOperand, emptyOperand, offsets, strides);
};
tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
}

// Collapse tiled operands to 1D vectors required by smmla intrinsic
auto collapsedInputType = VectorType::get(
tiledLhs.getType().cast<ShapedType>().getNumElements(),
tiledLhs.getType().cast<ShapedType>().getElementType());
auto collapsedOutputType = VectorType::get(
{4}, tiledAcc.getType().cast<ShapedType>().getElementType());
auto collapsedInputType =
VectorType::get(inputExpandedType.getNumElements(), inputElementType);
auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledLhs.getLoc(), collapsedInputType, tiledLhs);
auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledRhs.getLoc(), collapsedInputType, tiledRhs);
auto collapsedOutputType =
VectorType::get(outputExpandedType.getNumElements(), accElementType);
auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
tiledAcc.getLoc(), collapsedOutputType, tiledAcc);

Expand All @@ -172,6 +203,11 @@ class LowerContractionToSMMLAPattern
Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);

// With vecmat, only one row of tiled ACC can be inserted into the final result.
if (isVecmat) {
tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
}

// Insert the tiled result back into the non tiled result of the
// contract op.
SmallVector<int64_t> strides(
Expand Down
Loading
Loading