diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index ed174699314e8..2a4327535c687 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -239,7 +239,7 @@ class LLVMTypeConverter : public TypeConverter {
   Type convertMemRefToBarePtr(BaseMemRefType type) const;
 
   /// Convert a 1D vector type into an LLVM vector type.
-  Type convertVectorType(VectorType type) const;
+  FailureOr<Type> convertVectorType(VectorType type) const;
 
   /// Options for customizing the llvm lowering.
   LowerToLLVMOptions options;
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 554b9f1192306..9e8ad48b3c2db 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -20,6 +20,8 @@ namespace mlir {
 namespace arm_sme {
 
+constexpr unsigned MinStreamingVectorLengthInBits = 128;
+
 /// Return minimum number of elements for the given element `type` in
 /// a vector of SVL bits.
 unsigned getSMETileSliceMinNumElts(Type type);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index a9e7ce9d42848..49e0513e629d9 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -61,7 +61,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addConversion([&](MemRefType type) { return convertMemRefType(type); });
   addConversion(
       [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
-  addConversion([&](VectorType type) { return convertVectorType(type); });
+  addConversion([&](VectorType type) -> std::optional<Type> {
+    FailureOr<Type> llvmType = convertVectorType(type);
+    if (failed(llvmType))
+      return std::nullopt;
+    return llvmType;
+  });
 
   // LLVM-compatible types are legal, so add a pass-through conversion. Do this
   // before the conversions below since conversions are attempted in reverse
@@ -490,10 +495,9 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
 /// * 1-D `vector<axT>` remains as is while,
 /// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
 ///   `!llvm.array<ax...array<jxvector<kxT>>>`.
-/// As LLVM does not support arrays of scalable vectors, it is assumed that
-/// scalable vectors are always 1-D. This condition could be relaxed once the
-/// missing functionality is added in LLVM
-Type LLVMTypeConverter::convertVectorType(VectorType type) const {
+/// Returns failure for n-D scalable vector types as LLVM does not support
+/// arrays of scalable vectors.
+FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
   auto elementType = convertType(type.getElementType());
   if (!elementType)
     return {};
@@ -503,9 +507,8 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) const {
                                     type.getScalableDims().back());
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
-  assert(
-      (!type.isScalable() || (type.getRank() == 1)) &&
-      "expected 1-D scalable vector (n-D scalable vectors are not supported)");
+  if (type.isScalable() && (type.getRank() > 1))
+    return failure();
   auto shape = type.getShape();
   for (int i = shape.size() - 2; i >= 0; --i)
     vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
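Note on the TypeConverter change above: an n-D vector lowers through an (n-1)-D LLVM array, and LLVM has no arrays of scalable vectors, so only a trailing scalable dimension can be expressed. For illustration (example types chosen here, not taken from the patch), the conversion now behaves as:

    vector<4xf32>      =>  vector<4xf32>                    (1-D, unchanged)
    vector<[4]xf32>    =>  vector<[4]xf32>                  (1-D scalable, unchanged)
    vector<2x4xf32>    =>  !llvm.array<2 x vector<4xf32>>
    vector<2x[4]xf32>  =>  conversion failure (previously an assert in debug builds)

Returning failure lets conversion drivers treat such types as simply illegal instead of crashing.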
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 685f8d57f76f5..6c8843fbb4546 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -361,6 +361,112 @@ struct MoveVectorToTileSliceToArmSMELowering
   }
 };
 
+/// Lower `vector.outerproduct` to SME MOPA intrinsics.
+///
+/// Example:
+///
+///   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
+///     : vector<[4]xf32>, vector<[4]xf32>
+///
+/// is converted to:
+///
+///   "arm_sme.intr.mopa"(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs)
+///     : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
+///        vector<[4]xf32>) -> ()
+///
+/// Currently only supports FMOPA and BFMOPA (non-widening).
+struct VectorOuterProductToArmSMELowering
+    : public ConvertOpToLLVMPattern<vector::OuterProductOp> {
+  using ConvertOpToLLVMPattern<vector::OuterProductOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::OuterProductOp outerProductOp,
+                  vector::OuterProductOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto isSupportedType = [](VectorType vectorType) {
+      // TODO: the FP outer product instruction variants are predicated on
+      // different features [1]:
+      //
+      // * FMOPA (non-widening)
+      //   * half-precision   - +sme2p1,+sme-f16f16
+      //   * single-precision - +sme
+      //   * double-precision - +sme-f64f64
+      // * BFMOPA
+      //   * half-precision   - +sme2p1,+b16b16
+      //
+      // It should be possible to control lowering based on target features.
+      // [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
+      if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
+        return false;
+
+      auto elementType = vectorType.getElementType();
+
+      if (!elementType.isF16() && !elementType.isBF16() &&
+          !elementType.isF32() && !elementType.isF64())
+        return false;
+
+      unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
+                            vectorType.getElementTypeBitWidth();
+      if (vectorType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
+        return false;
+
+      return true;
+    };
+
+    auto resultVectorType = outerProductOp.getResultVectorType();
+    if (!isSupportedType(resultVectorType))
+      return outerProductOp.emitError("unsupported type");
+
+    vector::CombiningKind kind = outerProductOp.getKind();
+    if (kind != vector::CombiningKind::ADD)
+      // TODO: support subtract.
+      return outerProductOp.emitError("unsupported kind");
+
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(outerProductOp.getOperation());
+    if (maskableOp.isMasked())
+      // TODO: support masking.
+      return outerProductOp.emitError("masking is currently unsupported");
+
+    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
+      // AXPY operation not suited for SME.
+      return failure();
+
+    auto loc = outerProductOp.getLoc();
+
+    Value acc = outerProductOp.getAcc();
+    if (!acc)
+      // Initialize accumulator with zero.
+      acc = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
+
+    unsigned elementWidth = resultVectorType.getElementTypeBitWidth();
+    auto tileId = rewriter.create<arm_sme::CastVectorToTileOp>(
+        loc, rewriter.getIntegerType(elementWidth), acc);
+
+    // Create all active predicate mask.
+    auto one = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI1Type(),
+        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
+    auto predTy =
+        VectorType::get(resultVectorType.getShape()[0], rewriter.getI1Type(),
+                        /*scalableDims=*/{true});
+    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+
+    auto tileI32 = castTileIDToI32(tileId, loc, rewriter);
+
+    // Create 'arm_sme.intr.mopa' outer product intrinsic.
+    rewriter.create<arm_sme::aarch64_sme_mopa>(
+        loc, tileI32, allActiveMask, allActiveMask, outerProductOp.getLhs(),
+        outerProductOp.getRhs());
+
+    // Create `CastTileToVectorOp` to use as the output.
+    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVectorOp>(
+        outerProductOp, resultVectorType, tileId);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::configureArmSMELegalizeForExportTarget(
@@ -374,8 +480,10 @@ void mlir::configureArmSMELegalizeForExportTarget(
       arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
       arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
       arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_write_horiz,
-      arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
+      arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable,
+      arm_sme::aarch64_sme_za_disable>();
   target.addLegalOp<GetTileID>();
+  target.addIllegalOp<vector::OuterProductOp>();
 
   // Mark 'func.func' ops as legal if either:
   // 1. no 'arm_za' function attribute is present.
@@ -405,7 +513,8 @@ void mlir::configureArmSMELegalizeForExportTarget(
 void mlir::populateArmSMELegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   patterns.add<DisableZAPattern, EnableZAPattern>(patterns.getContext());
-  patterns.add<LoadTileSliceToArmSMELowering,
-               MoveVectorToTileSliceToArmSMELowering,
-               StoreTileSliceToArmSMELowering, ZeroOpConversion>(converter);
+  patterns
+      .add<LoadTileSliceToArmSMELowering, MoveVectorToTileSliceToArmSMELowering,
+           StoreTileSliceToArmSMELowering, VectorOuterProductToArmSMELowering,
+           ZeroOpConversion>(converter);
 }
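Working through the arithmetic in `isSupportedType` above: the check pins operands to exactly one SME tile at the minimum streaming vector length (`MinStreamingVectorLengthInBits` = 128), i.e. minNumElts = 128 / elementTypeBitWidth, so the accepted result types are:

    f16/bf16 : 128 / 16 = 8  =>  vector<[8]x[8]xf16>, vector<[8]x[8]xbf16>
    f32      : 128 / 32 = 4  =>  vector<[4]x[4]xf32>
    f64      : 128 / 64 = 2  =>  vector<[2]x[2]xf64>

These are exactly the shapes exercised by the tests below.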
diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
index 8b2be7bc1901b..b8a47951cc7bb 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -17,8 +17,6 @@
 using namespace mlir;
 using namespace mlir::arm_sme;
 
-static constexpr unsigned MinStreamingVectorLengthInBits = 128;
-
 unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) {
   assert(isValidSMETileElementType(type) && "invalid tile type!");
   return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index b66077372164e..95a010dd59d95 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -1121,11 +1121,14 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
 
   LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                 PatternRewriter &rewriter) const override {
+    VectorType resType = op.getResultVectorType();
+    if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
+      return failure();
+
     auto loc = op.getLoc();
 
     VectorType lhsType = op.getOperandVectorTypeLHS();
     VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
-    VectorType resType = op.getResultVectorType();
     Type eltType = resType.getElementType();
     bool isInt = isa<IntegerType, IndexType>(eltType);
     Value acc = op.getAcc();
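The early bail-out added to OuterProductOpLowering matters because that pattern unrolls the outer (row) dimension with extract/insert. Roughly, as a sketch of the existing fixed-size lowering (value names hypothetical):

    // vector.outerproduct %lhs, %rhs, %acc : vector<2xf32>, vector<3xf32>
    %e0 = vector.extract %lhs[0] : vector<2xf32>
    %b0 = vector.broadcast %e0 : f32 to vector<3xf32>
    %a0 = vector.extract %acc[0] : vector<2x3xf32>
    %r0 = vector.fma %b0, %rhs, %a0 : vector<3xf32>
    %t0 = vector.insert %r0, %init [0] : vector<3xf32> into vector<2x3xf32>
    // ... and likewise for row 1.

A scalable row count such as [4] has no static trip count to unroll, so the pattern now returns failure and leaves the op for the ArmSME lowering above.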
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index af528295ef6ee..687ef79385334 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,4 +1,8 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// vector.transfer_write
+//===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: @transfer_write_2d_zero_i8(
 // CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
@@ -33,6 +37,10 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
   return
 }
 
+//===----------------------------------------------------------------------===//
+// vector.load
+//===----------------------------------------------------------------------===//
+
 // -----
 
 // Load an 8-bit tile from a rank 2 memref with a non-zero offset for the first
@@ -232,6 +240,10 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
   return %tile : vector<[1]x[1]xi128>
 }
 
+//===----------------------------------------------------------------------===//
+// vector.store
+//===----------------------------------------------------------------------===//
+
 // -----
 
 // CHECK-LABEL: @vector_store_i8(
@@ -391,3 +403,96 @@ func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi128>) {
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
   return
 }
+
+//===----------------------------------------------------------------------===//
+// vector.outerproduct
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_add_f16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>)
+func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) {
+  // CHECK: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[8]xi1>
+  // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[8]x[8]xf16> to i16
+  // CHECK: %[[CAST_VECTOR_TO_TILE_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
+  // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE_I32]], %[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_add_bf16
+func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) {
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_add_f32
+func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
+  // CHECK-NOT: arith.extui
+  // CHECK-NOT: arith.trunci
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_add_f64
+func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
+  // CHECK: arith.trunci {{.*}} : i64 to i32
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_no_accumulator
+func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
+  // CHECK: "arm_sme.intr.zero"({{.*}}) : (i32) -> ()
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  %0 = vector.outerproduct %lhs, %rhs {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_unsupported_axpy
+func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
+  // CHECK-NOT: arm_sme
+  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, f64
+  return %0 : vector<[2]xf64>
+}
+
+// -----
+
+func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) {
+  // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
+  // expected-error@+1 {{unsupported type}}
+  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
+  "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
+}
+
+// -----
+
+func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
+  // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
+  // expected-error@+1 {{unsupported kind}}
+  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
+func.func @vector_outerproduct_add_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
+  // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
+  // expected-error@+1 {{masking is currently unsupported}}
+  %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
new file mode 100644
index 0000000000000..00f1f6fd3fa8e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -0,0 +1,116 @@
+// DEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:   -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:   -e %{entry_point} -entry-point-result=void \
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITHOUT-ACC
+
+// REDEFINE: %{entry_point} = test_outerproduct_with_accumulator_4x4xf32
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-ACC
+
+llvm.func @printCString(!llvm.ptr<i8>)
+
+func.func @printTileBegin() {
+  %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
+  %1 = llvm.mlir.constant(0 : index) : i64
+  %2 = llvm.getelementptr %0[%1, %1]
+    : (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+func.func @printTileEnd() {
+  %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
+  %1 = llvm.mlir.constant(0 : index) : i64
+  %2 = llvm.getelementptr %0[%1, %1]
+    : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+func.func @test_outerproduct_no_accumulator_4x4xf32() {
+  %c0 = arith.constant 0 : index
+
+  %vector_i32 = llvm.intr.experimental.stepvector : vector<[4]xi32>
+  %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
+  %tile = vector.outerproduct %vector, %vector : vector<[4]xf32>, vector<[4]xf32>
+
+  // Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
+  %vscale = vector.vscale
+  %min_elts_s = arith.constant 4 : index
+  %svl_s = arith.muli %min_elts_s, %vscale : index
+  %za_s_size = arith.muli %svl_s, %svl_s : index
+
+  // Allocate memory.
+  %mem = memref.alloca(%za_s_size) : memref<?xf32>
+
+  // Store the tile to memory.
+  vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
+
+  // Reload and print. The smallest SVL is 128-bits so the tile will be at
+  // least 4x4xf32.
+  //
+  // WITHOUT-ACC: TILE BEGIN
+  // WITHOUT-ACC-NEXT: ( 0, 0, 0, 0
+  // WITHOUT-ACC-NEXT: ( 0, 1, 2, 3
+  // WITHOUT-ACC-NEXT: ( 0, 2, 4, 6
+  // WITHOUT-ACC-NEXT: ( 0, 3, 6, 9
+  // WITHOUT-ACC: TILE END
+  func.call @printTileBegin() : () -> ()
+  scf.for %i = %c0 to %za_s_size step %svl_s {
+    %tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
+    vector.print %tileslice : vector<[4]xf32>
+  }
+  func.call @printTileEnd() : () -> ()
+
+  return
+}
+
-convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \ +// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm +// DEFINE: %{run} = %mcr_aarch64_cmd \ +// DEFINE: -march=aarch64 -mattr=+sve,+sme \ +// DEFINE: -e %{entry_point} -entry-point-result=void \ +// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils + +// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITHOUT-ACC + +// REDEFINE: %{entry_point} = test_outerproduct_with_accumulator_4x4xf32 +// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-ACC + +llvm.func @printCString(!llvm.ptr) + +func.func @printTileBegin() { + %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr> + %1 = llvm.mlir.constant(0 : index) : i64 + %2 = llvm.getelementptr %0[%1, %1] + : (!llvm.ptr>, i64, i64) -> !llvm.ptr + llvm.call @printCString(%2) : (!llvm.ptr) -> () + return +} + +func.func @printTileEnd() { + %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr> + %1 = llvm.mlir.constant(0 : index) : i64 + %2 = llvm.getelementptr %0[%1, %1] + : (!llvm.ptr>, i64, i64) -> !llvm.ptr + llvm.call @printCString(%2) : (!llvm.ptr) -> () + return +} + +func.func @test_outerproduct_no_accumulator_4x4xf32() { + %c0 = arith.constant 0 : index + + %vector_i32 = llvm.intr.experimental.stepvector : vector<[4]xi32> + %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32> + %tile = vector.outerproduct %vector, %vector : vector<[4]xf32>, vector<[4]xf32> + + // Calculate the size of a 32-bit tile, e.g. ZA{n}.s. + %vscale = vector.vscale + %min_elts_s = arith.constant 4 : index + %svl_s = arith.muli %min_elts_s, %vscale : index + %za_s_size = arith.muli %svl_s, %svl_s : index + + // Allocate memory. + %mem = memref.alloca(%za_s_size) : memref + + // Store the tile to memory. + vector.store %tile, %mem[%c0] : memref, vector<[4]x[4]xf32> + + // Reload and print. The smallest SVL is 128-bits so the tile will be at + // least 4x4xf32. + // + // WITHOUT-ACC: TILE BEGIN + // WITHOUT-ACC-NEXT: ( 0, 0, 0, 0 + // WITHOUT-ACC-NEXT: ( 0, 1, 2, 3 + // WITHOUT-ACC-NEXT: ( 0, 2, 4, 6 + // WITHOUT-ACC-NEXT: ( 0, 3, 6, 9 + // WITHOUT-ACC: TILE END + func.call @printTileBegin() : () -> () + scf.for %i = %c0 to %za_s_size step %svl_s { + %tileslice = vector.load %mem[%i] : memref, vector<[4]xf32> + vector.print %tileslice : vector<[4]xf32> + } + func.call @printTileEnd() : () -> () + + return +} + +func.func @test_outerproduct_with_accumulator_4x4xf32() { + %c0 = arith.constant 0 : index + %f10 = arith.constant 10.0 : f32 + + %acc = vector.broadcast %f10 : f32 to vector<[4]x[4]xf32> + %vector_i32 = llvm.intr.experimental.stepvector : vector<[4]xi32> + %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32> + %tile = vector.outerproduct %vector, %vector, %acc : vector<[4]xf32>, vector<[4]xf32> + + // Calculate the size of a 32-bit tile, e.g. ZA{n}.s. + %vscale = vector.vscale + %min_elts_s = arith.constant 4 : index + %svl_s = arith.muli %min_elts_s, %vscale : index + %za_s_size = arith.muli %svl_s, %svl_s : index + + // Allocate memory. + %mem = memref.alloca(%za_s_size) : memref + + // Store the tile to memory. + vector.store %tile, %mem[%c0] : memref, vector<[4]x[4]xf32> + + // Reload and print. The smallest SVL is 128-bits so the tile will be at + // least 4x4xf32. 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
new file mode 100644
index 0000000000000..2c2a06fa8db26
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -0,0 +1,77 @@
+// DEFINE: %{entry_point} = test_outerproduct_with_accumulator_2x2xf64
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:   -march=aarch64 -mattr=+sve,+sme-f64f64 \
+// DEFINE:   -e %{entry_point} -entry-point-result=void \
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+llvm.func @printCString(!llvm.ptr<i8>)
+
+func.func @printTileBegin() {
+  %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
+  %1 = llvm.mlir.constant(0 : index) : i64
+  %2 = llvm.getelementptr %0[%1, %1]
+    : (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+func.func @printTileEnd() {
+  %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
+  %1 = llvm.mlir.constant(0 : index) : i64
+  %2 = llvm.getelementptr %0[%1, %1]
+    : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+func.func @test_outerproduct_with_accumulator_2x2xf64() {
+  %c0 = arith.constant 0 : index
+  %f1 = arith.constant 1.0 : f64
+  %f2 = arith.constant 2.0 : f64
+  %f10 = arith.constant 10.0 : f64
+
+  %a = vector.splat %f1 : vector<[2]xf64>
+  %b = vector.splat %f2 : vector<[2]xf64>
+  // TODO: vector.splat doesn't support ArmSME.
+  %c = vector.broadcast %f10 : f64 to vector<[2]x[2]xf64>
+
+  %tile = vector.outerproduct %a, %b, %c : vector<[2]xf64>, vector<[2]xf64>
+
+  // Calculate the size of a 64-bit tile, e.g. ZA{n}.d.
+  %vscale = vector.vscale
+  %min_elts_d = arith.constant 2 : index
+  %svl_d = arith.muli %min_elts_d, %vscale : index
+  %za_d_size = arith.muli %svl_d, %svl_d : index
+
+  // Allocate memory.
+  %mem = memref.alloca(%za_d_size) : memref<?xf64>
+
+  // Store the tile to memory.
+  vector.store %tile, %mem[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
+
+  // Reload and print. The smallest SVL is 128-bits so the tile will be at
+  // least 2x2xf64.
+  //
+  // CHECK: TILE BEGIN
+  // CHECK-NEXT: ( 12, 12
+  // CHECK-NEXT: ( 12, 12
+  // CHECK: TILE END
+  func.call @printTileBegin() : () -> ()
+  scf.for %i = %c0 to %za_d_size step %svl_d {
+    %tileslice = vector.load %mem[%i] : memref<?xf64>, vector<[2]xf64>
+    vector.print %tileslice : vector<[2]xf64>
+  }
+  func.call @printTileEnd() : () -> ()
+
+  return
+}
+
+llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
+llvm.mlir.global internal constant @str_tile_end("TILE END\0A")
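Cross-check for the f64 case: every lane computes 1.0 * 2.0 + 10.0 = 12.0, hence the 2x2 tile of 12s at the minimum SVL.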