-
Notifications
You must be signed in to change notification settings - Fork 12.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ArmSME] Lower vector.outerproduct to FMOPA/BFMOPA #65621
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -361,6 +361,111 @@ struct MoveVectorToTileSliceToArmSMELowering | |
} | ||
}; | ||
|
||
/// Lower `vector.outerproduct` to SME MOPA intrinsics. | ||
/// | ||
/// Example: | ||
/// | ||
/// %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} | ||
/// : vector<[4]xf32>, vector<[4]xf32> | ||
/// | ||
/// is converted to: | ||
/// | ||
/// "arm_sme.intr.mopa"(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs) | ||
/// : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, | ||
/// vector<[4]xf32>) -> () | ||
/// | ||
/// Currently only supports FMOPA and BFMOPA (non-widening). | ||
struct VectorOuterProductToArmSMELowering | ||
: public ConvertOpToLLVMPattern<vector::OuterProductOp> { | ||
using ConvertOpToLLVMPattern<vector::OuterProductOp>::ConvertOpToLLVMPattern; | ||
|
||
LogicalResult | ||
matchAndRewrite(vector::OuterProductOp outerProductOp, | ||
vector::OuterProductOp::Adaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
auto isSupportedType = [](VectorType vectorType) { | ||
// TODO: the FP outer product instruction variants are predicated on | ||
// different features: | ||
// | ||
// * FMOPA (non-widening) | ||
// * half-precision - +sme2p1,+sme-f16f16 | ||
c-rhodes marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// * single-precision - +sme | ||
// * double-precision - +sme-f64f64 | ||
// * BFMOPA | ||
// * half-precision - +sme2p1,+b16b16 | ||
// | ||
// It should be possible to control lowering based on target features. | ||
if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable()) | ||
return false; | ||
|
||
auto elementType = vectorType.getElementType(); | ||
|
||
if (!elementType.isF16() && !elementType.isBF16() && | ||
!elementType.isF32() && !elementType.isF64()) | ||
return false; | ||
|
||
unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits / | ||
vectorType.getElementTypeBitWidth(); | ||
if (vectorType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts})) | ||
return false; | ||
|
||
return true; | ||
}; | ||
|
||
auto resultVectorType = outerProductOp.getResultVectorType(); | ||
if (!isSupportedType(resultVectorType)) | ||
return outerProductOp.emitError("unsupported type"); | ||
|
||
vector::CombiningKind kind = outerProductOp.getKind(); | ||
if (kind != vector::CombiningKind::ADD) | ||
// TODO: support subtract. | ||
return outerProductOp.emitError("unsupported kind"); | ||
|
||
auto maskableOp = | ||
cast<vector::MaskableOpInterface>(outerProductOp.getOperation()); | ||
if (maskableOp.isMasked()) | ||
// TODO: support masking. | ||
return outerProductOp.emitError("masking is currently unsupported"); | ||
|
||
if (!isa<VectorType>(outerProductOp.getOperandTypeRHS())) | ||
// AXPY operation not suited for SME. | ||
return failure(); | ||
|
||
auto loc = outerProductOp.getLoc(); | ||
|
||
Value acc = outerProductOp.getAcc(); | ||
if (!acc) | ||
// Initialize accumulator with zero. | ||
acc = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType); | ||
|
||
unsigned elementWidth = resultVectorType.getElementTypeBitWidth(); | ||
auto tileId = rewriter.create<arm_sme::CastVectorToTile>( | ||
loc, rewriter.getIntegerType(elementWidth), acc); | ||
|
||
// Create all active predicate mask. | ||
auto one = rewriter.create<arith::ConstantOp>( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should use a ConstantMaskOp here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Could maybe use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Feels like a fairly important gap for us to fill, but not necessarily in this patch. |
||
loc, rewriter.getI1Type(), | ||
rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); | ||
auto predTy = | ||
VectorType::get(resultVectorType.getShape()[0], rewriter.getI1Type(), | ||
/*scalableDims=*/{true}); | ||
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one); | ||
|
||
auto tileI32 = castTileIDToI32(tileId, loc, rewriter); | ||
|
||
// Create 'arm_sme.intr.mopa' outer product intrinsic. | ||
rewriter.create<arm_sme::aarch64_sme_mopa>( | ||
loc, tileI32, allActiveMask, allActiveMask, outerProductOp.getLhs(), | ||
outerProductOp.getRhs()); | ||
|
||
// Create `CastTileToVectorOp` to use as the output. | ||
rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>( | ||
outerProductOp, resultVectorType, tileId); | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void mlir::configureArmSMELegalizeForExportTarget( | ||
|
@@ -374,8 +479,10 @@ void mlir::configureArmSMELegalizeForExportTarget( | |
arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz, | ||
arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz, | ||
arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_write_horiz, | ||
arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>(); | ||
arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable, | ||
arm_sme::aarch64_sme_za_disable>(); | ||
target.addLegalOp<GetTileID>(); | ||
target.addIllegalOp<vector::OuterProductOp>(); | ||
|
||
// Mark 'func.func' ops as legal if either: | ||
// 1. no 'arm_za' function attribute is present. | ||
|
@@ -405,7 +512,8 @@ void mlir::configureArmSMELegalizeForExportTarget( | |
void mlir::populateArmSMELegalizeForLLVMExportPatterns( | ||
LLVMTypeConverter &converter, RewritePatternSet &patterns) { | ||
patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext()); | ||
patterns.add<ZeroOpConversion, StoreTileSliceToArmSMELowering, | ||
LoadTileSliceToArmSMELowering, | ||
MoveVectorToTileSliceToArmSMELowering>(converter); | ||
patterns | ||
.add<ZeroOpConversion, StoreTileSliceToArmSMELowering, | ||
LoadTileSliceToArmSMELowering, MoveVectorToTileSliceToArmSMELowering, | ||
VectorOuterProductToArmSMELowering>(converter); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these two have the same name?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This intrinsic takes two operands to mask the inputs in the op itself so to support masking you would have to propagate the mask from the producer ops... That's interesting because it looks like the op knows how to merge both masks without requiring independent mask manipulation operations.
How do we plan to implement proper support for this? I see two options:
I think doing all of that as part of the lowering (1) might be too much for a lowering, esp. if finding the masks through the use-def chain is not trivial. (2) seems simpler to me but I wouldn't implement that on top of an llvm intrinsic. I think for that we should have a proper sme op, which should be fine now that we have the sme dialect.
Thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you're referring to the predicate (?) that's because they're the same, both all active
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The only examples I've seen of masking (from grepping around the codebase) are where the mask is applied to the result of the outerproduct e.g.
vector.mask { vector.outerproduct ... }
, I just figured we'd need some way to correlate this to the inputs, but hadn't given it much thought. Appreciate your input — I'll add a custom op, that way there's more flexibility when it comes to masking, and I will also look into how it would be used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given the expected complexity with generating correct masks, I am also leaning towards a custom op. Having said that, IMHO this PR is fine as is and we could iterate in the follow-up patches.
I guess that for this to work, we'd need something like a custom op.
So, we'd allow 2 masks, both of which would be optional: when a mask is absent, it would default to
`ptrue`
(all lanes are active). WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We perhaps could use
vector.mask
for masking the result so that we don't have to disambiguate the semantics based on the number of masks...