[mlir][ArmSME] Lower vector.outerproduct to FMOPA/BFMOPA #65621
Conversation
This patch changes vector type conversion to return failure on n-D scalable vector types instead of asserting. This is an alternative approach to llvm#65261 that aims to enable lowering of Vector ops directly to ArmSME intrinsics where possible, and seems more consistent with other type conversions. It's trivial to hit the assert at the moment, and it could be read as implying that n-D scalable vector types are a bug, when they're valid types in the Vector dialect. Returning failure fails more gracefully, particularly in release builds or other builds where assertions are disabled.
This patch adds support for lowering vector.outerproduct to the ArmSME MOPA intrinsic for the following types:

  vector<[8]xf16>,  vector<[8]xf16>  -> vector<[8]x[8]xf16>
  vector<[8]xbf16>, vector<[8]xbf16> -> vector<[8]x[8]xbf16>
  vector<[4]xf32>,  vector<[4]xf32>  -> vector<[4]x[4]xf32>
  vector<[2]xf64>,  vector<[2]xf64>  -> vector<[2]x[2]xf64>

The FP variants are lowered to FMOPA (non-widening) [1] and BFloat to BFMOPA (non-widening) [2].

Note that at the ISA level these variants are implemented by different architecture features, listed below:

FMOPA (non-widening)
* half-precision   - +sme2p1,+sme-f16f16
* single-precision - +sme
* double-precision - +sme-f64f64

BFMOPA (non-widening)
* half-precision   - +sme2p1,+b16b16

There's currently no way to target different features when lowering to ArmSME. Integration tests are added for F32 and F64. We use QEMU to run the integration tests, but SME2 support isn't available yet (it's targeted for QEMU 9.0), so integration tests for these variants are excluded.

Masking is currently unsupported.

Depends on llvm#65450.

[1] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/FMOPA--non-widening---Floating-point-outer-product-and-accumulate-
[2] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/BFMOPA--non-widening---BFloat16-floating-point-outer-product-and-accumulate-
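For illustration, here's a minimal sketch of the IR this lowering accepts and the intrinsic it produces (SSA names are illustrative; the intrinsic form matches the conversion comment quoted in the review below):

```mlir
// Scalable outer product with accumulator, f32 variant:
// vector<[4]xf32>, vector<[4]xf32> -> vector<[4]x[4]xf32>
%result = vector.outerproduct %lhs, %rhs, %acc
    : vector<[4]xf32>, vector<[4]xf32>

// ... is converted to (per the conversion docs; %tile_id and the
// all-active predicates %ptrue_s are materialized by the lowering):
//   "arm_sme.intr.mopa"(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs)
```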
LGTM % the comments below, thanks!
///
/// is converted to:
///
///   "arm_sme.intr.mopa"(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs)
these two have the same name?
This intrinsic takes two operands to mask the inputs in the op itself, so to support masking you would have to propagate the masks from the producer ops... That's interesting because it looks like the op knows how to merge both masks without requiring independent mask manipulation operations.
How do we plan to implement proper support for this? I see two options:
- In one shot, we search for the two masks in the use-def chain and use them directly in the intrinsic. If there is any mask manipulation operation in between, it should hopefully become dead and go away.
- In two steps, we pass the single mask of the masked vector.outerproduct operation to both operands, and later run a pass that replaces this mask with the two masks from the operands.
I think doing all of that as part of the lowering (1) might be too much for a lowering, especially if finding the masks through the use-def chain is not trivial. (2) seems simpler to me, but I wouldn't implement that on top of an LLVM intrinsic. I think for that we should have a proper SME op, which should be fine now that we have the SME dialect.
Thoughts?
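To make option (1) concrete, here's a hedged sketch of the IR shape such a lowering might look for, assuming the 2-D mask is produced by a vector.create_mask with one bound per dimension (names illustrative):

```mlir
// %m bounds the rows (LHS lanes), %n the columns (RHS lanes).
%mask = vector.create_mask %m, %n : vector<[4]x[4]xi1>
%res = vector.mask %mask {
  vector.outerproduct %lhs, %rhs, %acc : vector<[4]xf32>, vector<[4]xf32>
} : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
```

In this shape the two 1-D operand masks could be recovered as vector.create_mask %m : vector<[4]xi1> and vector.create_mask %n : vector<[4]xi1>, with no extra mask manipulation left behind.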
these two have the same name?
If you're referring to the predicates, that's because they are the same: both are all-active.
The only examples I've seen of masking (from grepping around the codebase) are where the mask is applied to the result of the outerproduct, e.g. vector.mask { vector.outerproduct ... }. I just figured we'd need some way to correlate this to the inputs, but hadn't given it much thought.
Appreciate your input. I'll add a custom op so there's more flexibility when it comes to masking, and will also look into how it would be used.
Given the expected complexity of generating correct masks, I am also leaning towards a custom op. Having said that, IMHO this PR is fine as is and we could iterate in follow-up patches.
- In two steps, we pass the single mask of the masked vector.outerproduct operation to both operands, and later run a pass that replaces this mask with the two masks from the operands.
I guess that for this to work, we'd need something like:

    %res = arm_sme.op %rhs, %lhs <optional_mask_for_rhs_or_result> <optional_mask_for_lhs>
So we'd allow two masks, both of which would be optional:
- if only 1 mask is specified, then it masks the result (1 x 2-D),
- if 2 masks are specified, then they mask the input vectors (2 x 1-D),
- if no masks are specified, then ptrue is used (all lanes active).
WDYT?
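Spelled out (purely hypothetical syntax; no such op exists at this point, and the op name and mask placeholders below are illustrative), the three cases might look like:

```mlir
// (a) one mask: a 2-D mask on the result
%r0 = arm_sme.op %lhs, %rhs <result_mask>
// (b) two masks: a 1-D mask per input vector
%r1 = arm_sme.op %lhs, %rhs <lhs_mask> <rhs_mask>
// (c) no masks: all lanes active (ptrue)
%r2 = arm_sme.op %lhs, %rhs
```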
We perhaps could use vector.mask for masking the result, so that we don't have to disambiguate the semantics based on the number of masks...
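That is, something like the following (equally hypothetical, reusing the placeholder op from above), where the result mask stays on vector.mask and the op itself only ever carries operand masks:

```mlir
%res = vector.mask %result_mask {
  arm_sme.op %lhs, %rhs <lhs_mask> <rhs_mask>
} : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
```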
    loc, rewriter.getIntegerType(elementWidth), acc);

// Create all active predicate mask.
auto one = rewriter.create<arith::ConstantOp>(
We should use a ConstantMaskOp here
ConstantMaskOp can only create all-false masks for scalable vectors:

// Only zero sizes are accepted here:
vector.constant_mask [0] : vector<[4]xi1>

Could maybe use CreateMaskOp, but I'm not sure if it's much simpler.
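For reference, an all-active scalable mask can be built with CreateMaskOp by scaling the base vector size by vector.vscale; a sketch (not necessarily simpler than the current approach):

```mlir
// All-active mask for vector<[4]xi1>: 4 * vscale lanes set.
%c4 = arith.constant 4 : index
%vscale = vector.vscale
%size = arith.muli %c4, %vscale : index
%ptrue = vector.create_mask %size : vector<[4]xi1>
```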
Feels like a fairly important gap for us to fill, but not necessarily in this patch.
@llvm/pr-subscribers-mlir-sme @llvm/pr-subscribers-mlir-vector
I'm currently looking at lowering [...]. Any thoughts?
I'm OK with landing this as is. We know that masking is one large outstanding challenge that needs addressing, but there's no need to do it in this PR. We also expect that we will probably introduce a custom op for outer products, but that could be a follow-up step. This change, in its current form, is already a very valuable step towards enabling SME in MLIR.
@dcaballe, is that OK with you?
LGTM, great work!
SG!