Skip to content

Commit

Permalink
[mlir][ArmSME] Lower vector.outerproduct to FMOPA/BFMOPA (llvm#65621)
Browse files Browse the repository at this point in the history
This patch adds support for lowering vector.outerproduct to the ArmSME
MOPA intrinsic for the following types:

  vector<[8]xf16>,  vector<[8]xf16>  -> vector<[8]x[8]xf16>
  vector<[8]xbf16>, vector<[8]xbf16> -> vector<[8]x[8]xbf16>
  vector<[4]xf32>,  vector<[4]xf32>  -> vector<[4]x[4]xf32>
  vector<[2]xf64>,  vector<[2]xf64>  -> vector<[2]x[2]xf64>

The FP variants are lowered to FMOPA (non-widening) [1] and the BFloat
variant is lowered to BFMOPA (non-widening) [2].

Note that at the ISA level these variants are implemented by different
architecture features, which are listed below:

  FMOPA (non-widening)
    * half-precision   - +sme2p1,+sme-f16f16
    * single-precision - +sme
    * double-precision - +sme-f64f64
  BFMOPA (non-widening)
    * half-precision   - +sme2p1,+b16b16

There's currently no way to target different features when lowering to
ArmSME. Integration tests are added for F32 and F64. We use QEMU to run
the integration tests, but SME2 support isn't available yet (it's
targeted for 9.0), so integration tests for these variants are excluded.

Masking is currently unsupported.

Depends on llvm#65450.

[1] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/FMOPA--non-widening---Floating-point-outer-product-and-accumulate-
[2] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/BFMOPA--non-widening---BFloat16-floating-point-outer-product-and-accumulate-
  • Loading branch information
c-rhodes authored and kstoimenov committed Sep 14, 2023
1 parent fc768d3 commit a2e12e8
Show file tree
Hide file tree
Showing 7 changed files with 418 additions and 8 deletions.
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
namespace mlir {
namespace arm_sme {

/// Minimum streaming vector length (SVL), in bits. Dividing this by an
/// element's bit width gives the minimum number of elements of that type in
/// a vector of SVL bits.
constexpr unsigned MinStreamingVectorLengthInBits = 128;

/// Return minimum number of elements for the given element `type` in
/// a vector of SVL bits.
unsigned getSMETileSliceMinNumElts(Type type);
Expand Down
117 changes: 113 additions & 4 deletions mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,112 @@ struct MoveVectorToTileSliceToArmSMELowering
}
};

/// Converts `vector.outerproduct` into the ArmSME MOPA (outer product and
/// accumulate) intrinsic.
///
/// For instance:
///
///   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
///     : vector<[4]xf32>, vector<[4]xf32>
///
/// becomes:
///
///   "arm_sme.intr.mopa"(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs)
///     : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
///        vector<[4]xf32>) -> ()
///
/// Only the non-widening FMOPA and BFMOPA forms are handled for now. The
/// result must be a rank-2, fully scalable vector of f16, bf16, f32 or f64
/// whose two dimensions both equal `arm_sme::MinStreamingVectorLengthInBits`
/// divided by the element bit width. The combining kind must be `add`, and
/// masked outer products are rejected with an error.
struct VectorOuterProductToArmSMELowering
    : public ConvertOpToLLVMPattern<vector::OuterProductOp> {
  using ConvertOpToLLVMPattern<vector::OuterProductOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::OuterProductOp outerProductOp,
                  vector::OuterProductOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // TODO: each FP outer product instruction variant is gated on a different
    // architecture feature [1]:
    //
    //   * FMOPA (non-widening)
    //     * half-precision   - +sme2p1,+sme-f16f16
    //     * single-precision - +sme
    //     * double-precision - +sme-f64f64
    //   * BFMOPA
    //     * half-precision   - +sme2p1,+b16b16
    //
    // The lowering should eventually be controllable via target features.
    // [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
    auto isSupportedType = [](VectorType tileType) {
      // Only 2-D vectors with every dimension scalable are accepted.
      const bool isScalable2D =
          (tileType.getRank() == 2) && tileType.allDimsScalable();
      if (!isScalable2D)
        return false;

      // Restrict element types to the floating-point types handled here.
      auto eltType = tileType.getElementType();
      const bool isSupportedFloat = eltType.isF16() || eltType.isBF16() ||
                                    eltType.isF32() || eltType.isF64();
      if (!isSupportedFloat)
        return false;

      // Both dimensions must hold exactly the minimum number of elements that
      // fit in a vector of the minimum streaming vector length.
      unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
                            tileType.getElementTypeBitWidth();
      const bool shapeMismatch =
          tileType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts});
      return !shapeMismatch;
    };

    auto resultType = outerProductOp.getResultVectorType();
    if (!isSupportedType(resultType)) {
      return outerProductOp.emitError("unsupported type");
    }

    // TODO: support subtract.
    if (outerProductOp.getKind() != vector::CombiningKind::ADD) {
      return outerProductOp.emitError("unsupported kind");
    }

    // TODO: support masking.
    if (cast<vector::MaskableOpInterface>(outerProductOp.getOperation())
            .isMasked()) {
      return outerProductOp.emitError("masking is currently unsupported");
    }

    // A non-vector RHS is the AXPY form, which is not suited for SME.
    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS())) {
      return failure();
    }

    auto loc = outerProductOp.getLoc();

    // Without an explicit accumulator, start from a zero-initialized one.
    Value accumulator = outerProductOp.getAcc();
    if (!accumulator) {
      accumulator = rewriter.create<arm_sme::ZeroOp>(loc, resultType);
    }

    // The tile ID is an integer as wide as the tile's element type.
    auto tileIdType =
        rewriter.getIntegerType(resultType.getElementTypeBitWidth());
    auto tileId = rewriter.create<arm_sme::CastVectorToTile>(loc, tileIdType,
                                                             accumulator);

    // Build an all-active predicate: splat `true` across a scalable i1 vector
    // with as many elements as the first dimension of the result.
    auto i1Type = rewriter.getI1Type();
    auto trueValue = rewriter.create<arith::ConstantOp>(
        loc, i1Type, rewriter.getIntegerAttr(i1Type, 1));
    auto predicateType = VectorType::get(resultType.getShape()[0], i1Type,
                                         /*scalableDims=*/{true});
    auto allActiveMask =
        rewriter.create<vector::SplatOp>(loc, predicateType, trueValue);

    // The intrinsic takes the tile ID as an i32.
    auto tileI32 = castTileIDToI32(tileId, loc, rewriter);

    // Emit the 'arm_sme.intr.mopa' outer product intrinsic; the same
    // all-active mask is used for both predicate operands.
    rewriter.create<arm_sme::aarch64_sme_mopa>(
        loc, tileI32, allActiveMask, allActiveMask, outerProductOp.getLhs(),
        outerProductOp.getRhs());

    // The result of the original op is the tile cast back to a vector.
    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
        outerProductOp, resultType, tileId);

    return success();
  }
};

} // namespace

void mlir::configureArmSMELegalizeForExportTarget(
Expand All @@ -374,8 +480,10 @@ void mlir::configureArmSMELegalizeForExportTarget(
arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_write_horiz,
arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable,
arm_sme::aarch64_sme_za_disable>();
target.addLegalOp<GetTileID>();
target.addIllegalOp<vector::OuterProductOp>();

// Mark 'func.func' ops as legal if either:
// 1. no 'arm_za' function attribute is present.
Expand Down Expand Up @@ -405,7 +513,8 @@ void mlir::configureArmSMELegalizeForExportTarget(
void mlir::populateArmSMELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
patterns.add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
LoadTileSliceToArmSMELowering,
MoveVectorToTileSliceToArmSMELowering>(converter);
patterns
.add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
LoadTileSliceToArmSMELowering, MoveVectorToTileSliceToArmSMELowering,
VectorOuterProductToArmSMELowering>(converter);
}
2 changes: 0 additions & 2 deletions mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
using namespace mlir;
using namespace mlir::arm_sme;

static constexpr unsigned MinStreamingVectorLengthInBits = 128;

unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) {
assert(isValidSMETileElementType(type) && "invalid tile type!");
return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
Expand Down
5 changes: 4 additions & 1 deletion mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1122,11 +1122,14 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {

LogicalResult matchAndRewrite(vector::OuterProductOp op,
PatternRewriter &rewriter) const override {
VectorType resType = op.getResultVectorType();
if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
return failure();

auto loc = op.getLoc();

VectorType lhsType = op.getOperandVectorTypeLHS();
VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
VectorType resType = op.getResultVectorType();
Type eltType = resType.getElementType();
bool isInt = isa<IntegerType, IndexType>(eltType);
Value acc = op.getAcc();
Expand Down
107 changes: 106 additions & 1 deletion mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file | FileCheck %s
// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s

//===----------------------------------------------------------------------===//
// vector.transfer_write
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @transfer_write_2d_zero_i8(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
Expand Down Expand Up @@ -33,6 +37,10 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
return
}

//===----------------------------------------------------------------------===//
// vector.load
//===----------------------------------------------------------------------===//

// -----

// Load an 8-bit tile from a rank 2 memref with a non-zero offset for the first
Expand Down Expand Up @@ -232,6 +240,10 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
return %tile : vector<[1]x[1]xi128>
}

//===----------------------------------------------------------------------===//
// vector.store
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: @vector_store_i8(
Expand Down Expand Up @@ -391,3 +403,96 @@ func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi1
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
return
}

//===----------------------------------------------------------------------===//
// vector.outerproduct
//===----------------------------------------------------------------------===//

// -----

// f16 outer product with an accumulator. Checks the whole lowered sequence:
// an all-true predicate constant, the accumulator cast to an i16 tile ID, the
// tile ID extended to i32, and the MOPA intrinsic with its operands.
// CHECK-LABEL: @vector_outerproduct_add_f16
// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>)
func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) {
// CHECK: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[8]xi1>
// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[8]x[8]xf16> to i16
// CHECK: %[[CAST_VECTOR_TO_TILE_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
// CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE_I32]], %[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16>
"prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
}

// -----

// bf16 outer product with an accumulator. Only the operand types of the
// emitted MOPA intrinsic are checked.
// CHECK-LABEL: @vector_outerproduct_add_bf16
func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) {
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
"prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
}

// -----

// f32 outer product with an accumulator. The tile ID is already 32 bits wide
// here, so no extension or truncation is expected before the MOPA intrinsic.
// CHECK-LABEL: @vector_outerproduct_add_f32
func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
// CHECK-NOT: arith.extui
// CHECK-NOT: arith.trunci
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
}

// -----

// f64 outer product with an accumulator. The 64-bit tile ID is expected to be
// truncated to i32 before being passed to the MOPA intrinsic.
// CHECK-LABEL: @vector_outerproduct_add_f64
func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
// CHECK: arith.trunci {{.*}} : i64 to i32
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
}

// -----

// Outer product without an accumulator operand. A zero intrinsic is expected
// ahead of the MOPA intrinsic, i.e. the accumulator tile is zero-initialized.
// CHECK-LABEL: @vector_outerproduct_no_accumulator
func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
// CHECK: "arm_sme.intr.zero"({{.*}}) : (i32) -> ()
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
%0 = vector.outerproduct %lhs, %rhs {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
}

// -----

// AXPY form of vector.outerproduct (scalar RHS, rank-1 result). No ArmSME
// operations are expected in the output for this form.
// CHECK-LABEL: @vector_outerproduct_unsupported_axpy
func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
// CHECK-NOT: arm_sme
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
return %0 : vector<[2]xf64>
}

// -----

// Negative test: i8 is not one of the supported floating-point element
// types, so legalization is expected to fail with 'unsupported type'.
func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) {
// expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
// expected-error@+1 {{unsupported type}}
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
"prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
}

// -----

// Negative test: the combining kind is `mul` rather than `add`, so
// legalization is expected to fail with 'unsupported kind'.
func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
// expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
// expected-error@+1 {{unsupported kind}}
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
}

// -----

// Negative test: the outer product is wrapped in vector.mask, so
// legalization is expected to fail with 'masking is currently unsupported'.
func.func @vector_outerproduct_add_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
// expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
// expected-error@+1 {{masking is currently unsupported}}
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
}
Loading

0 comments on commit a2e12e8

Please sign in to comment.