[mlir][ArmSME] Lower vector.outerproduct to FMOPA/BFMOPA #65621

Merged (4 commits) on Sep 14, 2023. Showing changes from 3 commits.
2 changes: 1 addition & 1 deletion mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -239,7 +239,7 @@ class LLVMTypeConverter : public TypeConverter {
Type convertMemRefToBarePtr(BaseMemRefType type) const;

/// Convert a 1D vector type into an LLVM vector type.
Type convertVectorType(VectorType type) const;
FailureOr<Type> convertVectorType(VectorType type) const;

/// Options for customizing the llvm lowering.
LowerToLLVMOptions options;
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -20,6 +20,8 @@
namespace mlir {
namespace arm_sme {

constexpr unsigned MinStreamingVectorLengthInBits = 128;

/// Return minimum number of elements for the given element `type` in
/// a vector of SVL bits.
unsigned getSMETileSliceMinNumElts(Type type);
19 changes: 11 additions & 8 deletions mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -61,7 +61,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addConversion([&](MemRefType type) { return convertMemRefType(type); });
addConversion(
[&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
addConversion([&](VectorType type) { return convertVectorType(type); });
addConversion([&](VectorType type) -> std::optional<Type> {
FailureOr<Type> llvmType = convertVectorType(type);
if (failed(llvmType))
return std::nullopt;
return llvmType;
});

// LLVM-compatible types are legal, so add a pass-through conversion. Do this
// before the conversions below since conversions are attempted in reverse
@@ -490,10 +495,9 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
/// * 1-D `vector<axT>` remains as is while,
/// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
/// `!llvm.array<ax...array<jxvector<kxT>>>`.
/// As LLVM does not support arrays of scalable vectors, it is assumed that
/// scalable vectors are always 1-D. This condition could be relaxed once the
/// missing functionality is added in LLVM
Type LLVMTypeConverter::convertVectorType(VectorType type) const {
/// Returns failure for n-D scalable vector types as LLVM does not support
/// arrays of scalable vectors.
FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
auto elementType = convertType(type.getElementType());
if (!elementType)
return {};
@@ -503,9 +507,8 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) const {
type.getScalableDims().back());
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
assert(
(!type.isScalable() || (type.getRank() == 1)) &&
"expected 1-D scalable vector (n-D scalable vectors are not supported)");
if (type.isScalable() && (type.getRank() > 1))
return failure();
auto shape = type.getShape();
for (int i = shape.size() - 2; i >= 0; --i)
vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
116 changes: 112 additions & 4 deletions mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -361,6 +361,111 @@ struct MoveVectorToTileSliceToArmSMELowering
}
};

/// Lower `vector.outerproduct` to SME MOPA intrinsics.
///
/// Example:
///
/// %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
/// : vector<[4]xf32>, vector<[4]xf32>
///
/// is converted to:
///
/// "arm_sme.intr.mopa"(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs)
Contributor: these two have the same name?

Contributor: This intrinsic takes two operands to mask the inputs in the op itself, so to support masking you would have to propagate the mask from the producer ops... That's interesting, because it looks like the op knows how to merge both masks without requiring independent mask manipulation operations.

How do we plan to implement proper support for this? I see two options:

  1. In one shot, we search for the two masks in the use-def chain and use them directly in the intrinsic. If there is any mask manipulation operation in-between, it should become dead, hopefully, and go away.
  2. In two steps, we pass the single mask in the masked vector outerproduct operation to both operands and later run a pass that replaces this mask with the two masks from the operands.

I think doing all of that as part of the lowering (1) might be too much for a lowering, esp. if finding the masks through the use-def chain is not trivial. (2) seems simpler to me, but I wouldn't implement that on top of an LLVM intrinsic. I think for that we should have a proper SME op, which should be fine now that we have the SME dialect.

Thoughts?

Collaborator (author), replying to "these two have the same name?": If you're referring to the predicates, that's because they are the same: both all-active.

Collaborator (author): The only examples of masking I've seen (from grepping around the codebase) are where the mask is applied to the result of the outerproduct, e.g. vector.mask { vector.outerproduct ... }. I just figured we'd need some way to correlate this to the inputs, but hadn't given it much thought.

Appreciate your input. I'll add a custom op so there's more flexibility when it comes to masking, and will also look into how it would be used.

Contributor: Given the expected complexity with generating correct masks, I am also leaning towards a custom op. Having said that, IMHO this PR is fine as is, and we could iterate in follow-up patches.

  > 2. In two steps, we pass the single mask in the masked vector outerproduct operation to both operands and later run a pass that replaces this mask with the two masks from the operands.

I guess that for this to work, we'd need something like:

  %res = arm_sme.op %rhs, %lhs <optional_mask_for_rhs_or_result> <optional_mask_for_lhs>

So, we'd allow two masks, both of which would be optional:

  • if only 1 mask is specified, then this is a mask for the result (1 x 2-D),
  • if 2 masks are specified, then these are for the input vectors (2 x 1-D),
  • if no masks are specified, then use ptrue (all lanes are active).

WDYT?
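A concrete sketch of the three cases in the proposal above. The op name `arm_sme.outerproduct` and its assembly format are hypothetical here, purely for illustration:

```mlir
// No masks: all lanes active (ptrue).
%0 = arm_sme.outerproduct %lhs, %rhs
       : vector<[4]xf32>, vector<[4]xf32>

// One mask: a single 2-D mask for the result tile.
%1 = arm_sme.outerproduct %lhs, %rhs, %res_mask
       : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xi1>

// Two masks: one 1-D mask per input vector.
%2 = arm_sme.outerproduct %lhs, %rhs, %lhs_mask, %rhs_mask
       : vector<[4]xf32>, vector<[4]xf32>, vector<[4]xi1>, vector<[4]xi1>
```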

Contributor: We perhaps could use vector.mask for masking the result, so that we don't have to disambiguate the semantics based on the number of masks...
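For the result-masking route, the standard `vector.mask` wrapper would carry the single 2-D mask, e.g. (again with a hypothetical `arm_sme.outerproduct` op):

```mlir
// The single 2-D mask applies to the result; a later pass or lowering
// would derive the per-operand 1-D masks from it.
%res = vector.mask %mask {
  arm_sme.outerproduct %lhs, %rhs : vector<[4]xf32>, vector<[4]xf32>
} : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
```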

/// : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
/// vector<[4]xf32>) -> ()
///
/// Currently only supports FMOPA and BFMOPA (non-widening).
struct VectorOuterProductToArmSMELowering
: public ConvertOpToLLVMPattern<vector::OuterProductOp> {
using ConvertOpToLLVMPattern<vector::OuterProductOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(vector::OuterProductOp outerProductOp,
vector::OuterProductOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto isSupportedType = [](VectorType vectorType) {
// TODO: the FP outer product instruction variants are predicated on
// different features:
//
// * FMOPA (non-widening)
// * half-precision - +sme2p1,+sme-f16f16
// * single-precision - +sme
// * double-precision - +sme-f64f64
// * BFMOPA
// * half-precision - +sme2p1,+b16b16
//
// It should be possible to control lowering based on target features.
if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
return false;

auto elementType = vectorType.getElementType();

if (!elementType.isF16() && !elementType.isBF16() &&
!elementType.isF32() && !elementType.isF64())
return false;

unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
vectorType.getElementTypeBitWidth();
if (vectorType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
return false;

return true;
};

auto resultVectorType = outerProductOp.getResultVectorType();
if (!isSupportedType(resultVectorType))
return outerProductOp.emitError("unsupported type");

vector::CombiningKind kind = outerProductOp.getKind();
if (kind != vector::CombiningKind::ADD)
// TODO: support subtract.
return outerProductOp.emitError("unsupported kind");

auto maskableOp =
cast<vector::MaskableOpInterface>(outerProductOp.getOperation());
if (maskableOp.isMasked())
// TODO: support masking.
return outerProductOp.emitError("masking is currently unsupported");

if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
// AXPY operation not suited for SME.
return failure();

auto loc = outerProductOp.getLoc();

Value acc = outerProductOp.getAcc();
if (!acc)
// Initialize accumulator with zero.
acc = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);

unsigned elementWidth = resultVectorType.getElementTypeBitWidth();
auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
loc, rewriter.getIntegerType(elementWidth), acc);

// Create all active predicate mask.
auto one = rewriter.create<arith::ConstantOp>(
Contributor: We should use a ConstantMaskOp here.

Member (@MacDue, Sep 11, 2023): ConstantMaskOp can only make all-false masks for scalable vectors.

  // Only zero sizes are accepted here:
  vector.constant_mask [0] : vector<[4]xi1>

Could maybe use CreateMaskOp, but I'm not sure if it's much simpler.
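For reference, a sketch of the CreateMaskOp alternative: since `vector.constant_mask` only accepts the all-false form for scalable dimensions, an all-true mask has to be built from the runtime vector length instead:

```mlir
// vector.constant_mask only accepts zero sizes for scalable dims (all-false):
%all_false = vector.constant_mask [0] : vector<[4]xi1>

// An all-true mask can instead be built from the runtime vector length
// (vscale * 4 lanes for vector<[4]xi1>) with vector.create_mask:
%vscale = vector.vscale
%c4 = arith.constant 4 : index
%vl = arith.muli %vscale, %c4 : index
%all_true = vector.create_mask %vl : vector<[4]xi1>
```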

Contributor: Feels like a fairly important gap for us to fill, but not necessarily in this patch.

loc, rewriter.getI1Type(),
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
auto predTy =
VectorType::get(resultVectorType.getShape()[0], rewriter.getI1Type(),
/*scalableDims=*/{true});
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);

auto tileI32 = castTileIDToI32(tileId, loc, rewriter);

// Create 'arm_sme.intr.mopa' outer product intrinsic.
rewriter.create<arm_sme::aarch64_sme_mopa>(
loc, tileI32, allActiveMask, allActiveMask, outerProductOp.getLhs(),
outerProductOp.getRhs());

// Create `CastTileToVectorOp` to use as the output.
rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
outerProductOp, resultVectorType, tileId);

return success();
}
};

} // namespace

void mlir::configureArmSMELegalizeForExportTarget(
@@ -374,8 +479,10 @@ void mlir::configureArmSMELegalizeForExportTarget(
arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_write_horiz,
arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable,
arm_sme::aarch64_sme_za_disable>();
target.addLegalOp<GetTileID>();
target.addIllegalOp<vector::OuterProductOp>();

// Mark 'func.func' ops as legal if either:
// 1. no 'arm_za' function attribute is present.
@@ -405,7 +512,8 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
void mlir::populateArmSMELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
patterns.add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
LoadTileSliceToArmSMELowering,
MoveVectorToTileSliceToArmSMELowering>(converter);
patterns
.add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
LoadTileSliceToArmSMELowering, MoveVectorToTileSliceToArmSMELowering,
VectorOuterProductToArmSMELowering>(converter);
}
2 changes: 0 additions & 2 deletions mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -17,8 +17,6 @@
using namespace mlir;
using namespace mlir::arm_sme;

static constexpr unsigned MinStreamingVectorLengthInBits = 128;

unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) {
assert(isValidSMETileElementType(type) && "invalid tile type!");
return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
@@ -1121,11 +1121,14 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {

LogicalResult matchAndRewrite(vector::OuterProductOp op,
PatternRewriter &rewriter) const override {
VectorType resType = op.getResultVectorType();
if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
return failure();

auto loc = op.getLoc();

VectorType lhsType = op.getOperandVectorTypeLHS();
VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
VectorType resType = op.getResultVectorType();
Type eltType = resType.getElementType();
bool isInt = isa<IntegerType, IndexType>(eltType);
Value acc = op.getAcc();
107 changes: 106 additions & 1 deletion mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,4 +1,8 @@
// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file | FileCheck %s
// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s

//===----------------------------------------------------------------------===//
// vector.transfer_write
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @transfer_write_2d_zero_i8(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
@@ -33,6 +37,10 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
return
}

//===----------------------------------------------------------------------===//
// vector.load
//===----------------------------------------------------------------------===//

// -----

// Load an 8-bit tile from a rank 2 memref with a non-zero offset for the first
@@ -232,6 +240,10 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
return %tile : vector<[1]x[1]xi128>
}

//===----------------------------------------------------------------------===//
// vector.store
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: @vector_store_i8(
@@ -391,3 +403,96 @@ func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi1
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
return
}

//===----------------------------------------------------------------------===//
// vector.outerproduct
//===----------------------------------------------------------------------===//

// -----

// CHECK-LABEL: @vector_outerproduct_add_f16
// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>)
func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) {
// CHECK: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[8]xi1>
// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[8]x[8]xf16> to i16
// CHECK: %[[CAST_VECTOR_TO_TILE_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
// CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE_I32]], %[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16>
"prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
}

// -----

// CHECK-LABEL: @vector_outerproduct_add_bf16
func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) {
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
"prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
}

// -----

// CHECK-LABEL: @vector_outerproduct_add_f32
func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
// CHECK-NOT: arith.extui
// CHECK-NOT: arith.trunci
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
}

// -----

// CHECK-LABEL: @vector_outerproduct_add_f64
func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
// CHECK: arith.trunci {{.*}} : i64 to i32
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
}

// -----

// CHECK-LABEL: @vector_outerproduct_no_accumulator
func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
// CHECK: "arm_sme.intr.zero"({{.*}}) : (i32) -> ()
// CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
%0 = vector.outerproduct %lhs, %rhs {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
}

// -----

// CHECK-LABEL: @vector_outerproduct_scalar_rhs
func.func @vector_outerproduct_scalar_rhs(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
// CHECK-NOT: arm_sme
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
return %0 : vector<[2]xf64>
}

// -----

func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) {
// expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
// expected-error@+1 {{unsupported type}}
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
"prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
}

// -----

func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
// expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
// expected-error@+1 {{unsupported kind}}
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
}

// -----

func.func @vector_outerproduct_add_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
// expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
// expected-error@+1 {{masking is currently unsupported}}
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
}