Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ArmSME] Lower vector.outerproduct to FMOPA/BFMOPA #65621

Merged
merged 4 commits into from
Sep 14, 2023

Conversation

c-rhodes
Copy link
Collaborator

@c-rhodes c-rhodes commented Sep 7, 2023

This patch adds support for lowering vector.outerproduct to the ArmSME
MOPA intrinsic for the following types:

vector<[8]xf16>, vector<[8]xf16> -> vector<[8]x[8]xf16>
vector<[8]xbf16>, vector<[8]xbf16> -> vector<[8]x[8]xbf16>
vector<[4]xf32>, vector<[4]xf32> -> vector<[4]x[4]xf32>
vector<[2]xf64>, vector<[2]xf64> -> vector<[2]x[2]xf64>

The FP variants are lowered to FMOPA (non-widening) [1] and BFloat to BFMOPA
(non-widening) [2].

Note at the ISA level these variants are implemented by different
architecture features, these are listed below:

FMOPA (non-widening)
* half-precision - +sme2p1,+sme-f16f16
* single-precision - +sme
* double-precision - +sme-f64f64
BFMOPA (non-widening)
* half-precision - +sme2p1,+b16b16

There's currently no way to target different features when lowering to
ArmSME. Integration tests are added for F32 and F64. We use QEMU to run
the integration tests but SME2 support isn't available yet, it's
targeted for 9.0, so integration tests for these variants excluded.

Masking is currently unsupported.

Depends on #65450.

[1] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/FMOPA--non-widening---Floating-point-outer-product-and-accumulate-
[2] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/BFMOPA--non-widening---BFloat16-floating-point-outer-product-and-accumulate-

This patch changes vector type conversion to return failure on n-D
scalable vector types instead of asserting.

This is an alternative approach to llvm#65261 that aims to enable lowering
of Vector ops directly to ArmSME intrinsics where possible, and seems
more consistent with other type conversions. It's trivial to hit the
assert at the moment and it could be interpreted as n-D scalable vector
types being a bug, when they're valid types in the Vector dialect.

By returning failure it will generally fail more gracefully,
particularly for release builds or other builds where assertions are
disabled.
This patch adds support for lowering vector.outerproduct to the ArmSME
MOPA intrinsic for the following types:

  vector<[8]xf16>,  vector<[8]xf16>  -> vector<[8]x[8]xf16>
  vector<[8]xbf16>, vector<[8]xbf16> -> vector<[8]x[8]xbf16>
  vector<[4]xf32>,  vector<[4]xf32>  -> vector<[4]x[4]xf32>
  vector<[2]xf64>,  vector<[2]xf64>  -> vector<[2]x[2]xf64>

The FP variants are lowered to FMOPA (non-widening) [1] and BFloat to BFMOPA
(non-widening) [2].

Note at the ISA level these variants are implemented by different
architecture features, these are listed below:

  FMOPA (non-widening)
    * half-precision   - +sme2p1,+sme-f16f16
    * single-precision - +sme
    * double-precision - +sme-f64f64
  BFMOPA (non-widening)
    * half-precision   - +sme2p1,+b16b16

There's currently no way to target different features when lowering to
ArmSME. Integration tests are added for F32 and F64. We use QEMU to run
the integration tests but SME2 support isn't available yet, it's
targeted for 9.0, so integration tests for these variants excluded.

Masking is currently unsupported.

Depends on llvm#65450.

[1] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/FMOPA--non-widening---Floating-point-outer-product-and-accumulate-
[2] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/BFMOPA--non-widening---BFloat16-floating-point-outer-product-and-accumulate-
@c-rhodes c-rhodes requested review from a team as code owners September 7, 2023 15:35
@github-actions github-actions bot added mlir:core MLIR Core Infrastructure mlir:vectorops mlir labels Sep 7, 2023
@c-rhodes c-rhodes requested a review from a team as a code owner September 8, 2023 08:51
@banach-space
Copy link
Contributor

CC @gaofangfrank

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % the comments below, thanks!

///
/// is converted to:
///
/// "arm_sme.intr.mopa"(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these two have the same name?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This intrinsic takes two operands to mask the inputs in the op itself so to support masking you would have to propagate the mask from the producer ops... That's interesting because it looks like the op knows how to merge both masks without requiring independent mask manipulation operations.

How do we plan to implement proper support for this? I see two options:

  1. In one shot, we search for the two masks in the use-def chain and use them directly in the intrinsic. If there is any mask manipulation operation in-between, it should become dead, hopefully, and go away.
  2. In two steps, we pass the single mask in the masked vector outerproduct operation to both operands and later run a pass that replace this mask with the two masks from the operands, again.

I think doing all of that as part of the lowering (1) might be too much for a lowering, esp. if finding the masks through the use-def chain is not trivial. (2) seems simpler to me but I wouldn't implement that on top of an llvm intrinsic. I think for that we should have a proper sme op, which should be fine now that we have the sme dialect.

Thoughts?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these two have the same name?

If you're referring to the predicate (?) that's because they're the same, both all active

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This intrinsic takes two operands to mask the inputs in the op itself so to support masking you would have to propagate the mask from the producer ops... That's interesting because it looks like the op knows how to merge both masks without requiring independent mask manipulation operations.

How do we plan to implement proper support for this? I see two options:

1. In one shot, we search for the two masks in the use-def chain and use them directly in the intrinsic. If there is any mask manipulation operation in-between, it should become dead, hopefully, and go away.

2. In two steps, we pass the single mask in the masked vector outerproduct operation to both operands and later run a pass that replace this mask with the two masks from the operands, again.

I think doing all of that as part of the lowering (1) might be too much for a lowering, esp. if finding the masks through the use-def chain is not trivial. (2) seems simpler to me but I wouldn't implement that on top of an llvm intrinsic. I think for that we should have a proper sme op, which should be fine now that we have the sme dialect.

Thoughts?

The only examples I've seen of masking (from grepping around the codebase) are where the mask is applied to the result of the outerproduct e.g. vector.mask { vector.outerproduct ... }, I just figured we'd need some way to correlate this to the inputs, but hadn't given it much thought.

Appreciate your input, I'll add a custom op that way there's more flexibility when it comes to masking, and will also look into how it would be used.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the expected complexity with generating correct masks, I am also leaning towards a custom op. Having said that, IMHO this PR is fine as is and we could iterate in the follow-up patches.

  1. In two steps, we pass the single mask in the masked vector outerproduct operation to both operands and later run a pass that replace this mask with the two masks from the operands, again.

I guess that for this to work, we'd need something like `

 %res = arm_sme.op %rhs, %lhs <optional_mask_for_rhs_or_result> <optional_mask_for_lhs>

So, we'd allow 2 optional masks, both of which would be optional:

  • if only 1 mask is specified then this is a mask for the result (1 x 2D),
  • if 2 masks are specified then these are for the input vectors 2 x 1D),
  • if no masks are specified, then use ptrue (all lanes are active).

WDYT?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We perhaps could use vector.mask for masking the result so that we don't have to disambiguate the semantics based on the number of masks...

loc, rewriter.getIntegerType(elementWidth), acc);

// Create all active predicate mask.
auto one = rewriter.create<arith::ConstantOp>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use a ConstantMaskOp here

Copy link
Member

@MacDue MacDue Sep 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ConstantMaskOp can only make all false masks for scalable vectors.

// Only zero sizes are accepted here:
vector.constant_mask [0] : vector<[4]xi1>

Could maybe use CreateMaskOp, but I'm not sure if it's much simpler.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels like a fairly important gap for us to fill, but not necessarily in this patch.

@llvmbot
Copy link
Member

llvmbot commented Sep 13, 2023

@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir-vector

Changes This patch adds support for lowering vector.outerproduct to the ArmSME MOPA intrinsic for the following types:

vector<[8]xf16>, vector<[8]xf16> -> vector<[8]x[8]xf16>
vector<[8]xbf16>, vector<[8]xbf16> -> vector<[8]x[8]xbf16>
vector<[4]xf32>, vector<[4]xf32> -> vector<[4]x[4]xf32>
vector<[2]xf64>, vector<[2]xf64> -> vector<[2]x[2]xf64>

The FP variants are lowered to FMOPA (non-widening) [1] and BFloat to BFMOPA
(non-widening) [2].

Note at the ISA level these variants are implemented by different
architecture features, these are listed below:

FMOPA (non-widening)
* half-precision - +sme2p1,+sme-f16f16
* single-precision - +sme
* double-precision - +sme-f64f64
BFMOPA (non-widening)
* half-precision - +sme2p1,+b16b16

There's currently no way to target different features when lowering to
ArmSME. Integration tests are added for F32 and F64. We use QEMU to run
the integration tests but SME2 support isn't available yet, it's
targeted for 9.0, so integration tests for these variants excluded.

Masking is currently unsupported.

Depends on #65450.

[1] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/FMOPA--non-widening---Floating-point-outer-product-and-accumulate-
[2] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/BFMOPA--non-widening---BFloat16-floating-point-outer-product-and-accumulate-

Patch is 26.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/65621.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+2)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+11-8)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp (+113-4)
  • (modified) mlir/lib/Dialect/ArmSME/Utils/Utils.cpp (-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+4-1)
  • (modified) mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir (+106-1)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir (+116)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir (+77)

<pre>
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index ed174699314e8d9..2a4327535c68750 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -239,7 +239,7 @@ class LLVMTypeConverter : public TypeConverter {
Type convertMemRefToBarePtr(BaseMemRefType type) const;

/// Convert a 1D vector type into an LLVM vector type.

  • Type convertVectorType(VectorType type) const;
  • FailureOr&lt;Type&gt; convertVectorType(VectorType type) const;

    /// Options for customizing the llvm lowering.
    LowerToLLVMOptions options;
    diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
    index 554b9f119230667..9e8ad48b3c2db94 100644
    --- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
    +++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
    @@ -20,6 +20,8 @@
    namespace mlir {
    namespace arm_sme {

+constexpr unsigned MinStreamingVectorLengthInBits = 128;
+
/// Return minimum number of elements for the given element type in
/// a vector of SVL bits.
unsigned getSMETileSliceMinNumElts(Type type);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index a9e7ce9d42848b5..49e0513e629d951 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -61,7 +61,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addConversion([&amp;](MemRefType type) { return convertMemRefType(type); });
addConversion(
[&amp;](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });

  • addConversion([&amp;](VectorType type) { return convertVectorType(type); });
  • addConversion([&amp;](VectorType type) -&gt; std::optional&lt;Type&gt; {

  • FailureOr&lt;Type&gt; llvmType = convertVectorType(type);

  • if (failed(llvmType))

  •  return std::nullopt;
    
  • return llvmType;

  • });

    // LLVM-compatible types are legal, so add a pass-through conversion. Do this
    // before the conversions below since conversions are attempted in reverse
    @@ -490,10 +495,9 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
    /// * 1-D vector&amp;lt;axT&amp;gt; remains as is while,
    /// * n&gt;1 vector&amp;lt;ax...xkxT&amp;gt; convert via an (n-1)-D array type to
    /// !llvm.array&amp;lt;ax...array&amp;lt;jxvector&amp;lt;kxT&amp;gt;&amp;gt;&amp;gt;.
    -/// As LLVM does not support arrays of scalable vectors, it is assumed that
    -/// scalable vectors are always 1-D. This condition could be relaxed once the
    -/// missing functionality is added in LLVM
    -Type LLVMTypeConverter::convertVectorType(VectorType type) const {
    +/// Returns failure for n-D scalable vector types as LLVM does not support
    +/// arrays of scalable vectors.
    +FailureOr&lt;Type&gt; LLVMTypeConverter::convertVectorType(VectorType type) const {
    auto elementType = convertType(type.getElementType());
    if (!elementType)
    return {};
    @@ -503,9 +507,8 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) const {
    type.getScalableDims().back());
    assert(LLVM::isCompatibleVectorType(vectorType) &amp;&amp;
    &quot;expected vector type compatible with the LLVM dialect&quot;);

  • assert(
  •  (!type.isScalable() || (type.getRank() == 1)) &amp;amp;&amp;amp;
    
  •  &amp;quot;expected 1-D scalable vector (n-D scalable vectors are not supported)&amp;quot;);
    
  • if (type.isScalable() &amp;&amp; (type.getRank() &gt; 1))
  • return failure();
    auto shape = type.getShape();
    for (int i = shape.size() - 2; i &gt;= 0; --i)
    vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
    diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
    index 685f8d57f76f52c..6c8843fbb4546e6 100644
    --- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
    +++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
    @@ -361,6 +361,112 @@ struct MoveVectorToTileSliceToArmSMELowering
    }
    };

+/// Lower vector.outerproduct to SME MOPA intrinsics.
+///
+/// Example:
+///
+/// %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind&lt;add&gt;}
+/// : vector&lt;[4]xf32&gt;, vector&lt;[4]xf32&gt;
+///
+/// is converted to:
+///
+/// &quot;arm_sme.intr.mopa&quot;(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs)
+/// : (i32, vector&lt;[4]xi1&gt;, vector&lt;[4]xi1&gt;, vector&lt;[4]xf32&gt;,
+/// vector&lt;[4]xf32&gt;) -&gt; ()
+///
+/// Currently only supports FMOPA and BFMOPA (non-widening).
+struct VectorOuterProductToArmSMELowering

  • : public ConvertOpToLLVMPattern&lt;vector::OuterProductOp&gt; {
  • using ConvertOpToLLVMPattern&lt;vector::OuterProductOp&gt;::ConvertOpToLLVMPattern;
  • LogicalResult
  • matchAndRewrite(vector::OuterProductOp outerProductOp,
  •              vector::OuterProductOp::Adaptor adaptor,
    
  •              ConversionPatternRewriter &amp;amp;rewriter) const override {
    
  • auto isSupportedType = [](VectorType vectorType) {
  •  // TODO: the FP outer product instruction variants are predicated on
    
  •  // different features [1]:
    
  •  //
    
  •  // * FMOPA (non-widening)
    
  •  //   * half-precision   - +sme2p1,+sme-f16f16
    
  •  //   * single-precision - +sme
    
  •  //   * double-precision - +sme-f64f64
    
  •  // * BFMOPA
    
  •  //   * half-precision   - +sme2p1,+b16b16
    
  •  //
    
  •  // It should be possible to control lowering based on target features.
    
  •  // [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
    
  •  if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
    
  •    return false;
    
  •  auto elementType = vectorType.getElementType();
    
  •  if (!elementType.isF16() &amp;amp;&amp;amp; !elementType.isBF16() &amp;amp;&amp;amp;
    
  •      !elementType.isF32() &amp;amp;&amp;amp; !elementType.isF64())
    
  •    return false;
    
  •  unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
    
  •                        vectorType.getElementTypeBitWidth();
    
  •  if (vectorType.getShape() != ArrayRef&amp;lt;int64_t&amp;gt;({minNumElts, minNumElts}))
    
  •    return false;
    
  •  return true;
    
  • };
  • auto resultVectorType = outerProductOp.getResultVectorType();
  • if (!isSupportedType(resultVectorType))
  •  return outerProductOp.emitError(&amp;quot;unsupported type&amp;quot;);
    
  • vector::CombiningKind kind = outerProductOp.getKind();
  • if (kind != vector::CombiningKind::ADD)
  •  // TODO: support subtract.
    
  •  return outerProductOp.emitError(&amp;quot;unsupported kind&amp;quot;);
    
  • auto maskableOp =
  •    cast&amp;lt;vector::MaskableOpInterface&amp;gt;(outerProductOp.getOperation());
    
  • if (maskableOp.isMasked())
  •  // TODO: support masking.
    
  •  return outerProductOp.emitError(&amp;quot;masking is currently unsupported&amp;quot;);
    
  • if (!isa&lt;VectorType&gt;(outerProductOp.getOperandTypeRHS()))
  •  // AXPY operation not suited for SME.
    
  •  return failure();
    
  • auto loc = outerProductOp.getLoc();
  • Value acc = outerProductOp.getAcc();
  • if (!acc)
  •  // Initalize accumulator with zero.
    
  •  acc = rewriter.create&amp;lt;arm_sme::ZeroOp&amp;gt;(loc, resultVectorType);
    
  • unsigned elementWidth = resultVectorType.getElementTypeBitWidth();
  • auto tileId = rewriter.create&lt;arm_sme::CastVectorToTile&gt;(
  •    loc, rewriter.getIntegerType(elementWidth), acc);
    
  • // Create all active predicate mask.
  • auto one = rewriter.create&lt;arith::ConstantOp&gt;(
  •    loc, rewriter.getI1Type(),
    
  •    rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
    
  • auto predTy =
  •    VectorType::get(resultVectorType.getShape()[0], rewriter.getI1Type(),
    
  •                    /*scalableDims=*/{true});
    
  • auto allActiveMask = rewriter.create&lt;vector::SplatOp&gt;(loc, predTy, one);
  • auto tileI32 = castTileIDToI32(tileId, loc, rewriter);
  • // Create &#x27;arm_sme.intr.mopa&#x27; outer product intrinsic.
  • rewriter.create&lt;arm_sme::aarch64_sme_mopa&gt;(
  •    loc, tileI32, allActiveMask, allActiveMask, outerProductOp.getLhs(),
    
  •    outerProductOp.getRhs());
    
  • // Create CastTileToVectorOp to use as the output.
  • rewriter.replaceOpWithNewOp&lt;arm_sme::CastTileToVector&gt;(
  •    outerProductOp, resultVectorType, tileId);
    
  • return success();
  • }
    +};

} // namespace

void mlir::configureArmSMELegalizeForExportTarget(
@@ -374,8 +480,10 @@ void mlir::configureArmSMELegalizeForExportTarget(
arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_write_horiz,

  •  arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable&amp;gt;();
    
  •  arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable,
    
  •  arm_sme::aarch64_sme_za_disable&amp;gt;();
    

    target.addLegalOp&lt;GetTileID&gt;();

  • target.addIllegalOp&lt;vector::OuterProductOp&gt;();

    // Mark &#x27;func.func&#x27; ops as legal if either:
    // 1. no &#x27;arm_za&#x27; function attribute is present.
    @@ -405,7 +513,8 @@ void mlir::configureArmSMELegalizeForExportTarget(
    void mlir::populateArmSMELegalizeForLLVMExportPatterns(
    LLVMTypeConverter &amp;converter, RewritePatternSet &amp;patterns) {
    patterns.add&lt;EnableZAPattern, DisableZAPattern&gt;(patterns.getContext());

  • patterns.add&lt;ZeroOpConversion, StoreTileSliceToArmSMELowering,
  •           LoadTileSliceToArmSMELowering,
    
  •           MoveVectorToTileSliceToArmSMELowering&amp;gt;(converter);
    
  • patterns
  •  .add&amp;lt;ZeroOpConversion, StoreTileSliceToArmSMELowering,
    
  •       LoadTileSliceToArmSMELowering, MoveVectorToTileSliceToArmSMELowering,
    
  •       VectorOuterProductToArmSMELowering&amp;gt;(converter);
    

}
diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
index 8b2be7bc1901b9a..b8a47951cc7bbba 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -17,8 +17,6 @@
using namespace mlir;
using namespace mlir::arm_sme;

-static constexpr unsigned MinStreamingVectorLengthInBits = 128;

unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) {
assert(isValidSMETileElementType(type) &amp;&amp; &quot;invalid tile type!&quot;);
return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index b66077372164e79..95a010dd59d95bc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -1121,11 +1121,14 @@ class OuterProductOpLowering : public OpRewritePattern&lt;vector::OuterProductOp&gt; {

LogicalResult matchAndRewrite(vector::OuterProductOp op,
PatternRewriter &amp;rewriter) const override {

  • VectorType resType = op.getResultVectorType();

  • if ((resType.getShape().size() &gt;= 2) &amp;&amp; resType.allDimsScalable())

  •  return failure();
    
  • auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = dyn_cast&lt;VectorType&gt;(op.getOperandTypeRHS());

  • VectorType resType = op.getResultVectorType();
    Type eltType = resType.getElementType();
    bool isInt = isa&lt;IntegerType, IndexType&gt;(eltType);
    Value acc = op.getAcc();
    diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
    index af528295ef6ee23..687ef79385334cf 100644
    --- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
    +++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
    @@ -1,4 +1,8 @@
    -// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm=&quot;enable-arm-sme&quot; -cse -canonicalize -split-input-file | FileCheck %s
    +// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm=&quot;enable-arm-sme&quot; -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s

+//===----------------------------------------------------------------------===//
+// vector.transfer_write
+//===----------------------------------------------------------------------===//

// CHECK-LABEL: @transfer_write_2d_zero_i8(
// CHECK-SAME: %[[ARG0:.*]]: memref&lt;?x?xi8&gt;)
@@ -33,6 +37,10 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref&lt;?x?xi8&gt;) {
return
}

+//===----------------------------------------------------------------------===//
+// vector.load
+//===----------------------------------------------------------------------===//
+
// -----

// Load an 8-bit tile from a rank 2 memref with a non-zero offset for the first
@@ -232,6 +240,10 @@ func.func @vector_load_i128(%arg0 : memref&lt;?x?xi128&gt;) -&gt; vector&lt;[1]x[1]xi128&gt; {
return %tile : vector&lt;[1]x[1]xi128&gt;
}

+//===----------------------------------------------------------------------===//
+// vector.store
+//===----------------------------------------------------------------------===//
+
// -----

// CHECK-LABEL: @vector_store_i8(
@@ -391,3 +403,96 @@ func.func @vector_store_i128(%tile : vector&lt;[1]x[1]xi128&gt;, %arg0 : memref&lt;?x?xi1
vector.store %tile, %arg0[%c0, %c0] : memref&lt;?x?xi128&gt;, vector&lt;[1]x[1]xi128&gt;
return
}
+
+//===----------------------------------------------------------------------===//
+// vector.outerproduct
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_add_f16
+// CHECK-SAME: (%[[LHS:.]]: vector&lt;[8]xf16&gt;, %[[RHS:.]]: vector&lt;[8]xf16&gt;, %[[ACC:.*]]: vector&lt;[8]x[8]xf16&gt;)
+func.func @vector_outerproduct_add_f16(%lhs : vector&lt;[8]xf16&gt;, %rhs : vector&lt;[8]xf16&gt;, %acc : vector&lt;[8]x[8]xf16&gt;) {

  • // CHECK: %[[PTRUE_ALL:.*]] = arith.constant dense&lt;true&gt; : vector&lt;[8]xi1&gt;
  • // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector&lt;[8]x[8]xf16&gt; to i16
  • // CHECK: %[[CAST_VECTOR_TO_TILE_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
  • // CHECK: &quot;arm_sme.intr.mopa&quot;(%[[CAST_VECTOR_TO_TILE_I32]], %[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) : (i32, vector&lt;[8]xi1&gt;, vector&lt;[8]xi1&gt;, vector&lt;[8]xf16&gt;, vector&lt;[8]xf16&gt;)
  • %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind&lt;add&gt;} : vector&lt;[8]xf16&gt;, vector&lt;[8]xf16&gt;
  • &quot;prevent.dce&quot;(%0) : (vector&lt;[8]x[8]xf16&gt;) -&gt; ()
    +}

+// -----
+
+// CHECK-LABEL: @vector_outerproduct_add_bf16
+func.func @vector_outerproduct_add_bf16(%lhs : vector&lt;[8]xbf16&gt;, %rhs : vector&lt;[8]xbf16&gt;, %acc : vector&lt;[8]x[8]xbf16&gt;) {

  • // CHECK: &quot;arm_sme.intr.mopa&quot;({{.}}, {{.}}, {{.*}}) : (i32, vector&lt;[8]xi1&gt;, vector&lt;[8]xi1&gt;, vector&lt;[8]xbf16&gt;, vector&lt;[8]xbf16&gt;)
  • %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind&lt;add&gt;} : vector&lt;[8]xbf16&gt;, vector&lt;[8]xbf16&gt;
  • &quot;prevent.dce&quot;(%0) : (vector&lt;[8]x[8]xbf16&gt;) -&gt; ()
    +}

+// -----
+
+// CHECK-LABEL: @vector_outerproduct_add_f32
+func.func @vector_outerproduct_add_f32(%lhs : vector&lt;[4]xf32&gt;, %rhs : vector&lt;[4]xf32&gt;, %acc : vector&lt;[4]x[4]xf32&gt;) {

  • // CHECK-NOT: arith.extui
  • // CHECK-NOT: arith.trunci
  • // CHECK: &quot;arm_sme.intr.mopa&quot;({{.}}, {{.}}, {{.*}}) : (i32, vector&lt;[4]xi1&gt;, vector&lt;[4]xi1&gt;, vector&lt;[4]xf32&gt;, vector&lt;[4]xf32&gt;)
  • %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind&lt;add&gt;} : vector&lt;[4]xf32&gt;, vector&lt;[4]xf32&gt;
  • &quot;prevent.dce&quot;(%0) : (vector&lt;[4]x[4]xf32&gt;) -&gt; ()
    +}

+// -----
+
+// CHECK-LABEL: @vector_outerproduct_add_f64
+func.func @vector_outerproduct_add_f64(%lhs : vector&lt;[2]xf64&gt;, %rhs : vector&lt;[2]xf64&gt;, %acc : vector&lt;[2]x[2]xf64&gt;) {

  • // CHECK: arith.trunci {{.*}} : i64 to i32
  • // CHECK: &quot;arm_sme.intr.mopa&quot;({{.}}, {{.}}, {{.*}}) : (i32, vector&lt;[2]xi1&gt;, vector&lt;[2]xi1&gt;, vector&lt;[2]xf64&gt;, vector&lt;[2]xf64&gt;)
  • %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind&lt;add&gt;} : vector&lt;[2]xf64&gt;, vector&lt;[2]xf64&gt;
  • &quot;prevent.dce&quot;(%0) : (vector&lt;[2]x[2]xf64&gt;) -&gt; ()
    +}

+// -----
+
+// CHECK-LABEL: @vector_outerproduct_no_accumulator
+func.func @vector_outerproduct_no_accumulator(%lhs : vector&lt;[2]xf64&gt;, %rhs : vector&lt;[2]xf64&gt;) {

  • // CHECK: &quot;arm_sme.intr.zero&quot;({{.*}}) : (i32) -&gt; ()
  • // CHECK: &quot;arm_sme.intr.mopa&quot;({{.}}, {{.}}, {{.*}}) : (i32, vector&lt;[2]xi1&gt;, vector&lt;[2]xi1&gt;, vector&lt;[2]xf64&gt;, vector&lt;[2]xf64&gt;)
  • %0 = vector.outerproduct %lhs, %rhs {kind = #vector.kind&lt;add&gt;} : vector&lt;[2]xf64&gt;, vector&lt;[2]xf64&gt;
  • &quot;prevent.dce&quot;(%0) : (vector&lt;[2]x[2]xf64&gt;) -&gt; ()
    +}

+// -----
+
+// CHECK-LABEL: @vector_outerproduct_unsupported_axpy
+func.func @vector_outerproduct_unsupported_axpy(%lhs : vector&lt;[2]xf64&gt;, %rhs : f64, %acc : vector&lt;[2]xf64&gt;) -&gt; vector&lt;[2]xf64&gt; {

  • // CHECK-NOT: arm_sme
  • %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind&lt;mul&gt;} : vector&lt;[2]xf64&gt;, f64
  • return %0 : vector&lt;[2]xf64&gt;
    +}

+// -----
+
+func.func @vector_outerproduct_unsupported_type(%lhs : vector&lt;[16]xi8&gt;, %rhs : vector&lt;[16]xi8&gt;, %acc : vector&lt;[16]x[16]xi8&gt;) {

  • // expected-error@+2 {{failed to legalize operation &#x27;vector.outerproduct&#x27;}}
  • // expected-error@+1 {{unsupported type}}
  • %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind&lt;add&gt;} : vector&lt;[16]xi8&gt;, vector&lt;[16]xi8&gt;
  • &quot;prevent.dce&quot;(%0) : (vector&lt;[16]x[16]xi8&gt;) -&gt; ()
    +}

+// -----
+
+func.func @vector_outerproduct_unsupported_kind(%lhs : vector&lt;[2]xf64&gt;, %rhs : vector&lt;[2]xf64&gt;, %acc : vector&lt;[2]x[2]xf64&gt;) {

  • // expected-error@+2 {{failed to legalize operation &#x27;vector.outerproduct&#x27;}}
  • // expected-error@+1 {{unsupported kind}}
  • %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind&lt;mul&gt;} : vector&lt;[2]xf64&gt;, vector&lt;[2]xf64&gt;
  • &quot;prevent.dce&quot;(%0) : (vector&lt;[2]x[2]xf64&gt;) -&gt; ()
    +}

+// -----
+
+func.func @vector_outerproduct_add_masked_f32(%lhs : vector&lt;[4]xf32&gt;, %rhs : vector&lt;[4]xf32&gt;, %acc : vector&lt;[4]x[4]xf32&gt;, %mask : vector&lt;[4]x[4]xi1&gt;) {

  • // expected-error@+2 {{failed to legalize operation &#x27;vector.outerproduct&#x27;}}
  • // expected-error@+1 {{masking is currently unsupported}}
  • %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind&lt;add&gt;} : vector&lt;[4]xf32&gt;, vector&lt;[4]xf32&gt; } : vector&lt;[4]x[4]xi1&gt; -&gt; vector&lt;[4]x[4]xf32&gt;
  • &quot;prevent.dce&quot;(%0) : (vector&lt;[4]x[4]xf32&gt;) -&gt; ()
    +}
    diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
    new file mode 100644
    index 000000000000000..00f1f6fd3fa8e19
    --- /dev/null
    +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
    @@ -0,0 +1,116 @@
    +// DEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32
    +// DEFINE: %{compile} = mlir-opt %s
    +// DEFINE: -enable-arm-streaming=&quot;mode=locally enable-za&quot;
    +// DEFINE: -convert-vector-to-...

@c-rhodes
Copy link
Collaborator Author

I'm currently looking at lowering linalg.matmul to vector.outerproduct ops and how to drive that lowering based on previous investigations from @banach-space. I'll be looking into masking after that, but in the meantime I think it would be ok to land this as is with a plan to add a custom op and support masking later.

Any thoughts?

@banach-space banach-space removed the request for review from a team September 13, 2023 14:55
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I"m OK with landing this as is. We know that masking is one large outstanding challenge that needs addressing, but no need to do it in this PR. We also expect that we will probably introduce a custom op for outer products, but that could be a follow-up step. This change, in its current form, is already a very valuable step towards enabling SME in MLIR.

@dcaballe , is that OK with you?

LGTM, great work!

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SG!

@c-rhodes c-rhodes merged commit f75d46a into llvm:main Sep 14, 2023
kstoimenov pushed a commit to kstoimenov/llvm-project that referenced this pull request Sep 14, 2023
This patch adds support for lowering vector.outerproduct to the ArmSME
MOPA intrinsic for the following types:

  vector<[8]xf16>,  vector<[8]xf16>  -> vector<[8]x[8]xf16>
  vector<[8]xbf16>, vector<[8]xbf16> -> vector<[8]x[8]xbf16>
  vector<[4]xf32>,  vector<[4]xf32>  -> vector<[4]x[4]xf32>
  vector<[2]xf64>,  vector<[2]xf64>  -> vector<[2]x[2]xf64>

The FP variants are lowered to FMOPA (non-widening) [1] and BFloat to
BFMOPA
(non-widening) [2].

Note at the ISA level these variants are implemented by different
architecture features, these are listed below:

  FMOPA (non-widening)
    * half-precision   - +sme2p1,+sme-f16f16
    * single-precision - +sme
    * double-precision - +sme-f64f64
  BFMOPA (non-widening)
    * half-precision   - +sme2p1,+b16b16

There's currently no way to target different features when lowering to
ArmSME. Integration tests are added for F32 and F64. We use QEMU to run
the integration tests but SME2 support isn't available yet, it's
targeted for 9.0, so integration tests for these variants excluded.

Masking is currently unsupported.

Depends on llvm#65450.

[1] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/FMOPA--non-widening---Floating-point-outer-product-and-accumulate-
[2] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/BFMOPA--non-widening---BFloat16-floating-point-outer-product-and-accumulate-
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
This patch adds support for lowering vector.outerproduct to the ArmSME
MOPA intrinsic for the following types:

  vector<[8]xf16>,  vector<[8]xf16>  -> vector<[8]x[8]xf16>
  vector<[8]xbf16>, vector<[8]xbf16> -> vector<[8]x[8]xbf16>
  vector<[4]xf32>,  vector<[4]xf32>  -> vector<[4]x[4]xf32>
  vector<[2]xf64>,  vector<[2]xf64>  -> vector<[2]x[2]xf64>

The FP variants are lowered to FMOPA (non-widening) [1] and BFloat to
BFMOPA
(non-widening) [2].

Note at the ISA level these variants are implemented by different
architecture features, these are listed below:

  FMOPA (non-widening)
    * half-precision   - +sme2p1,+sme-f16f16
    * single-precision - +sme
    * double-precision - +sme-f64f64
  BFMOPA (non-widening)
    * half-precision   - +sme2p1,+b16b16

There's currently no way to target different features when lowering to
ArmSME. Integration tests are added for F32 and F64. We use QEMU to run
the integration tests but SME2 support isn't available yet, it's
targeted for 9.0, so integration tests for these variants excluded.

Masking is currently unsupported.

Depends on llvm#65450.

[1] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/FMOPA--non-widening---Floating-point-outer-product-and-accumulate-
[2] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/BFMOPA--non-widening---BFloat16-floating-point-outer-product-and-accumulate-
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants