From e566113fa4022a06b973b811e3305a34a47bba8f Mon Sep 17 00:00:00 2001 From: Cullen Rhodes Date: Thu, 14 Sep 2023 08:31:52 +0100 Subject: [PATCH] [mlir][ArmSME] Lower vector.outerproduct to FMOPA/BFMOPA (#65621) This patch adds support for lowering vector.outerproduct to the ArmSME MOPA intrinsic for the following types: vector<[8]xf16>, vector<[8]xf16> -> vector<[8]x[8]xf16> vector<[8]xbf16>, vector<[8]xbf16> -> vector<[8]x[8]xbf16> vector<[4]xf32>, vector<[4]xf32> -> vector<[4]x[4]xf32> vector<[2]xf64>, vector<[2]xf64> -> vector<[2]x[2]xf64> The FP variants are lowered to FMOPA (non-widening) [1] and BFloat to BFMOPA (non-widening) [2]. Note at the ISA level these variants are implemented by different architecture features, these are listed below: FMOPA (non-widening) * half-precision - +sme2p1,+sme-f16f16 * single-precision - +sme * double-precision - +sme-f64f64 BFMOPA (non-widening) * half-precision - +sme2p1,+b16b16 There's currently no way to target different features when lowering to ArmSME. Integration tests are added for F32 and F64. We use QEMU to run the integration tests but SME2 support isn't available yet, it's targeted for 9.0, so integration tests for these variants excluded. Masking is currently unsupported. Depends on #65450. [1] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/FMOPA--non-widening---Floating-point-outer-product-and-accumulate- [2] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/BFMOPA--non-widening---BFloat16-floating-point-outer-product-and-accumulate- --- .../include/mlir/Dialect/ArmSME/Utils/Utils.h | 2 + .../Transforms/LegalizeForLLVMExport.cpp | 117 +++++++++++++++++- mlir/lib/Dialect/ArmSME/Utils/Utils.cpp | 2 - .../Vector/Transforms/LowerVectorContract.cpp | 5 +- .../Dialect/ArmSME/vector-ops-to-llvm.mlir | 107 +++++++++++++++- .../CPU/ArmSME/test-outerproduct-f32.mlir | 116 +++++++++++++++++ .../CPU/ArmSME/test-outerproduct-f64.mlir | 77 ++++++++++++ 7 files changed, 418 insertions(+), 8 deletions(-) create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h index 554b9f1192306..9e8ad48b3c2db 100644 --- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h @@ -20,6 +20,8 @@ namespace mlir { namespace arm_sme { +constexpr unsigned MinStreamingVectorLengthInBits = 128; + /// Return minimum number of elements for the given element `type` in /// a vector of SVL bits. unsigned getSMETileSliceMinNumElts(Type type); diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp index 685f8d57f76f5..6c8843fbb4546 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -361,6 +361,112 @@ struct MoveVectorToTileSliceToArmSMELowering } }; +/// Lower `vector.outerproduct` to SME MOPA intrinsics. +/// +/// Example: +/// +/// %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind} +/// : vector<[4]xf32>, vector<[4]xf32> +/// +/// is converted to: +/// +/// "arm_sme.intr.mopa"(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs) +/// : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, +/// vector<[4]xf32>) -> () +/// +/// Currently only supports FMOPA and BFMOPA (non-widening). +struct VectorOuterProductToArmSMELowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(vector::OuterProductOp outerProductOp, + vector::OuterProductOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto isSupportedType = [](VectorType vectorType) { + // TODO: the FP outer product instruction variants are predicated on + // different features [1]: + // + // * FMOPA (non-widening) + // * half-precision - +sme2p1,+sme-f16f16 + // * single-precision - +sme + // * double-precision - +sme-f64f64 + // * BFMOPA + // * half-precision - +sme2p1,+b16b16 + // + // It should be possible to control lowering based on target features. + // [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile + if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable()) + return false; + + auto elementType = vectorType.getElementType(); + + if (!elementType.isF16() && !elementType.isBF16() && + !elementType.isF32() && !elementType.isF64()) + return false; + + unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits / + vectorType.getElementTypeBitWidth(); + if (vectorType.getShape() != ArrayRef({minNumElts, minNumElts})) + return false; + + return true; + }; + + auto resultVectorType = outerProductOp.getResultVectorType(); + if (!isSupportedType(resultVectorType)) + return outerProductOp.emitError("unsupported type"); + + vector::CombiningKind kind = outerProductOp.getKind(); + if (kind != vector::CombiningKind::ADD) + // TODO: support subtract. + return outerProductOp.emitError("unsupported kind"); + + auto maskableOp = + cast(outerProductOp.getOperation()); + if (maskableOp.isMasked()) + // TODO: support masking. + return outerProductOp.emitError("masking is currently unsupported"); + + if (!isa(outerProductOp.getOperandTypeRHS())) + // AXPY operation not suited for SME. + return failure(); + + auto loc = outerProductOp.getLoc(); + + Value acc = outerProductOp.getAcc(); + if (!acc) + // Initalize accumulator with zero. + acc = rewriter.create(loc, resultVectorType); + + unsigned elementWidth = resultVectorType.getElementTypeBitWidth(); + auto tileId = rewriter.create( + loc, rewriter.getIntegerType(elementWidth), acc); + + // Create all active predicate mask. + auto one = rewriter.create( + loc, rewriter.getI1Type(), + rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); + auto predTy = + VectorType::get(resultVectorType.getShape()[0], rewriter.getI1Type(), + /*scalableDims=*/{true}); + auto allActiveMask = rewriter.create(loc, predTy, one); + + auto tileI32 = castTileIDToI32(tileId, loc, rewriter); + + // Create 'arm_sme.intr.mopa' outer product intrinsic. + rewriter.create( + loc, tileI32, allActiveMask, allActiveMask, outerProductOp.getLhs(), + outerProductOp.getRhs()); + + // Create `CastTileToVectorOp` to use as the output. + rewriter.replaceOpWithNewOp( + outerProductOp, resultVectorType, tileId); + + return success(); + } +}; + } // namespace void mlir::configureArmSMELegalizeForExportTarget( @@ -374,8 +480,10 @@ void mlir::configureArmSMELegalizeForExportTarget( arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_write_horiz, - arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>(); + arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable, + arm_sme::aarch64_sme_za_disable>(); target.addLegalOp(); + target.addIllegalOp(); // Mark 'func.func' ops as legal if either: // 1. no 'arm_za' function attribute is present. @@ -405,7 +513,8 @@ void mlir::configureArmSMELegalizeForExportTarget( void mlir::populateArmSMELegalizeForLLVMExportPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.add(patterns.getContext()); - patterns.add(converter); + patterns + .add(converter); } diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp index 8b2be7bc1901b..b8a47951cc7bb 100644 --- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp +++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp @@ -17,8 +17,6 @@ using namespace mlir; using namespace mlir::arm_sme; -static constexpr unsigned MinStreamingVectorLengthInBits = 128; - unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) { assert(isValidSMETileElementType(type) && "invalid tile type!"); return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth(); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index 3ab3c6ad8a3e2..64ab0abda26e6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -1122,11 +1122,14 @@ class OuterProductOpLowering : public OpRewritePattern { LogicalResult matchAndRewrite(vector::OuterProductOp op, PatternRewriter &rewriter) const override { + VectorType resType = op.getResultVectorType(); + if ((resType.getShape().size() >= 2) && resType.allDimsScalable()) + return failure(); + auto loc = op.getLoc(); VectorType lhsType = op.getOperandVectorTypeLHS(); VectorType rhsType = dyn_cast(op.getOperandTypeRHS()); - VectorType resType = op.getResultVectorType(); Type eltType = resType.getElementType(); bool isInt = isa(eltType); Value acc = op.getAcc(); diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir index af528295ef6ee..687ef79385334 100644 --- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir +++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir @@ -1,4 +1,8 @@ -// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file | FileCheck %s +// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s + +//===----------------------------------------------------------------------===// +// vector.transfer_write +//===----------------------------------------------------------------------===// // CHECK-LABEL: @transfer_write_2d_zero_i8( // CHECK-SAME: %[[ARG0:.*]]: memref) @@ -33,6 +37,10 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref) { return } +//===----------------------------------------------------------------------===// +// vector.load +//===----------------------------------------------------------------------===// + // ----- // Load an 8-bit tile from a rank 2 memref with a non-zero offset for the first @@ -232,6 +240,10 @@ func.func @vector_load_i128(%arg0 : memref) -> vector<[1]x[1]xi128> { return %tile : vector<[1]x[1]xi128> } +//===----------------------------------------------------------------------===// +// vector.store +//===----------------------------------------------------------------------===// + // ----- // CHECK-LABEL: @vector_store_i8( @@ -391,3 +403,96 @@ func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref, vector<[1]x[1]xi128> return } + +//===----------------------------------------------------------------------===// +// vector.outerproduct +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @vector_outerproduct_add_f16 +// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>) +func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) { + // CHECK: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[8]xi1> + // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[8]x[8]xf16> to i16 + // CHECK: %[[CAST_VECTOR_TO_TILE_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32 + // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE_I32]], %[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) + %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind} : vector<[8]xf16>, vector<[8]xf16> + "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> () +} + +// ----- + +// CHECK-LABEL: @vector_outerproduct_add_bf16 +func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) { + // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) + %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind} : vector<[8]xbf16>, vector<[8]xbf16> + "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> () +} + +// ----- + +// CHECK-LABEL: @vector_outerproduct_add_f32 +func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) { + // CHECK-NOT: arith.extui + // CHECK-NOT: arith.trunci + // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) + %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind} : vector<[4]xf32>, vector<[4]xf32> + "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> () +} + +// ----- + +// CHECK-LABEL: @vector_outerproduct_add_f64 +func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) { + // CHECK: arith.trunci {{.*}} : i64 to i32 + // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) + %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind} : vector<[2]xf64>, vector<[2]xf64> + "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () +} + +// ----- + +// CHECK-LABEL: @vector_outerproduct_no_accumulator +func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) { + // CHECK: "arm_sme.intr.zero"({{.*}}) : (i32) -> () + // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) + %0 = vector.outerproduct %lhs, %rhs {kind = #vector.kind} : vector<[2]xf64>, vector<[2]xf64> + "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () +} + +// ----- + +// CHECK-LABEL: @vector_outerproduct_unsupported_axpy +func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> { + // CHECK-NOT: arm_sme + %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind} : vector<[2]xf64>, f64 + return %0 : vector<[2]xf64> +} + +// ----- + +func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) { + // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}} + // expected-error@+1 {{unsupported type}} + %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind} : vector<[16]xi8>, vector<[16]xi8> + "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> () +} + +// ----- + +func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) { + // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}} + // expected-error@+1 {{unsupported kind}} + %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind} : vector<[2]xf64>, vector<[2]xf64> + "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () +} + +// ----- + +func.func @vector_outerproduct_add_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) { + // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}} + // expected-error@+1 {{masking is currently unsupported}} + %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32> + "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> () +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir new file mode 100644 index 0000000000000..00f1f6fd3fa8e --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir @@ -0,0 +1,116 @@ +// DEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32 +// DEFINE: %{compile} = mlir-opt %s \ +// DEFINE: -enable-arm-streaming="mode=locally enable-za" \ +// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ +// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \ +// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm +// DEFINE: %{run} = %mcr_aarch64_cmd \ +// DEFINE: -march=aarch64 -mattr=+sve,+sme \ +// DEFINE: -e %{entry_point} -entry-point-result=void \ +// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils + +// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITHOUT-ACC + +// REDEFINE: %{entry_point} = test_outerproduct_with_accumulator_4x4xf32 +// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-ACC + +llvm.func @printCString(!llvm.ptr) + +func.func @printTileBegin() { + %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr> + %1 = llvm.mlir.constant(0 : index) : i64 + %2 = llvm.getelementptr %0[%1, %1] + : (!llvm.ptr>, i64, i64) -> !llvm.ptr + llvm.call @printCString(%2) : (!llvm.ptr) -> () + return +} + +func.func @printTileEnd() { + %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr> + %1 = llvm.mlir.constant(0 : index) : i64 + %2 = llvm.getelementptr %0[%1, %1] + : (!llvm.ptr>, i64, i64) -> !llvm.ptr + llvm.call @printCString(%2) : (!llvm.ptr) -> () + return +} + +func.func @test_outerproduct_no_accumulator_4x4xf32() { + %c0 = arith.constant 0 : index + + %vector_i32 = llvm.intr.experimental.stepvector : vector<[4]xi32> + %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32> + %tile = vector.outerproduct %vector, %vector : vector<[4]xf32>, vector<[4]xf32> + + // Calculate the size of a 32-bit tile, e.g. ZA{n}.s. + %vscale = vector.vscale + %min_elts_s = arith.constant 4 : index + %svl_s = arith.muli %min_elts_s, %vscale : index + %za_s_size = arith.muli %svl_s, %svl_s : index + + // Allocate memory. + %mem = memref.alloca(%za_s_size) : memref + + // Store the tile to memory. + vector.store %tile, %mem[%c0] : memref, vector<[4]x[4]xf32> + + // Reload and print. The smallest SVL is 128-bits so the tile will be at + // least 4x4xf32. + // + // WITHOUT-ACC: TILE BEGIN + // WITHOUT-ACC-NEXT: ( 0, 0, 0, 0 + // WITHOUT-ACC-NEXT: ( 0, 1, 2, 3 + // WITHOUT-ACC-NEXT: ( 0, 2, 4, 6 + // WITHOUT-ACC-NEXT: ( 0, 3, 6, 9 + // WITHOUT-ACC: TILE END + func.call @printTileBegin() : () -> () + scf.for %i = %c0 to %za_s_size step %svl_s { + %tileslice = vector.load %mem[%i] : memref, vector<[4]xf32> + vector.print %tileslice : vector<[4]xf32> + } + func.call @printTileEnd() : () -> () + + return +} + +func.func @test_outerproduct_with_accumulator_4x4xf32() { + %c0 = arith.constant 0 : index + %f10 = arith.constant 10.0 : f32 + + %acc = vector.broadcast %f10 : f32 to vector<[4]x[4]xf32> + %vector_i32 = llvm.intr.experimental.stepvector : vector<[4]xi32> + %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32> + %tile = vector.outerproduct %vector, %vector, %acc : vector<[4]xf32>, vector<[4]xf32> + + // Calculate the size of a 32-bit tile, e.g. ZA{n}.s. + %vscale = vector.vscale + %min_elts_s = arith.constant 4 : index + %svl_s = arith.muli %min_elts_s, %vscale : index + %za_s_size = arith.muli %svl_s, %svl_s : index + + // Allocate memory. + %mem = memref.alloca(%za_s_size) : memref + + // Store the tile to memory. + vector.store %tile, %mem[%c0] : memref, vector<[4]x[4]xf32> + + // Reload and print. The smallest SVL is 128-bits so the tile will be at + // least 4x4xf32. + // + // WITH-ACC: TILE BEGIN + // WITH-ACC-NEXT: ( 10, 10, 10, 10 + // WITH-ACC-NEXT: ( 10, 11, 12, 13 + // WITH-ACC-NEXT: ( 10, 12, 14, 16 + // WITH-ACC-NEXT: ( 10, 13, 16, 19 + // WITH-ACC: TILE END + func.call @printTileBegin() : () -> () + scf.for %i = %c0 to %za_s_size step %svl_s { + %tileslice = vector.load %mem[%i] : memref, vector<[4]xf32> + vector.print %tileslice : vector<[4]xf32> + } + func.call @printTileEnd() : () -> () + + return +} + +llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A") +llvm.mlir.global internal constant @str_tile_end("TILE END\0A") diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir new file mode 100644 index 0000000000000..2c2a06fa8db26 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir @@ -0,0 +1,77 @@ +// DEFINE: %{entry_point} = test_outerproduct_with_accumulator_2x2xf64 +// DEFINE: %{compile} = mlir-opt %s \ +// DEFINE: -enable-arm-streaming="mode=locally enable-za" \ +// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ +// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \ +// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm +// DEFINE: %{run} = %mcr_aarch64_cmd \ +// DEFINE: -march=aarch64 -mattr=+sve,+sme-f64f64 \ +// DEFINE: -e %{entry_point} -entry-point-result=void \ +// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils + +// RUN: %{compile} | %{run} | FileCheck %s + +llvm.func @printCString(!llvm.ptr) + +func.func @printTileBegin() { + %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr> + %1 = llvm.mlir.constant(0 : index) : i64 + %2 = llvm.getelementptr %0[%1, %1] + : (!llvm.ptr>, i64, i64) -> !llvm.ptr + llvm.call @printCString(%2) : (!llvm.ptr) -> () + return +} + +func.func @printTileEnd() { + %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr> + %1 = llvm.mlir.constant(0 : index) : i64 + %2 = llvm.getelementptr %0[%1, %1] + : (!llvm.ptr>, i64, i64) -> !llvm.ptr + llvm.call @printCString(%2) : (!llvm.ptr) -> () + return +} + +func.func @test_outerproduct_with_accumulator_2x2xf64() { + %c0 = arith.constant 0 : index + %f1 = arith.constant 1.0 : f64 + %f2 = arith.constant 2.0 : f64 + %f10 = arith.constant 10.0 : f64 + + %a = vector.splat %f1 : vector<[2]xf64> + %b = vector.splat %f2 : vector<[2]xf64> + // TODO: vector.splat doesn't support ArmSME. + %c = vector.broadcast %f10 : f64 to vector<[2]x[2]xf64> + + %tile = vector.outerproduct %a, %b, %c : vector<[2]xf64>, vector<[2]xf64> + + // Calculate the size of a 64-bit tile, e.g. ZA{n}.d. + %vscale = vector.vscale + %min_elts_d = arith.constant 2 : index + %svl_d = arith.muli %min_elts_d, %vscale : index + %za_d_size = arith.muli %svl_d, %svl_d : index + + // Allocate memory. + %mem = memref.alloca(%za_d_size) : memref + + // Store the tile to memory. + vector.store %tile, %mem[%c0] : memref, vector<[2]x[2]xf64> + + // Reload and print. The smallest SVL is 128-bits so the tile will be at + // least 2x2xf64. + // + // CHECK: TILE BEGIN + // CHECK-NEXT: ( 12, 12 + // CHECK-NEXT: ( 12, 12 + // CHECK: TILE END + func.call @printTileBegin() : () -> () + scf.for %i = %c0 to %za_d_size step %svl_d { + %tileslice = vector.load %mem[%i] : memref, vector<[2]xf64> + vector.print %tileslice : vector<[2]xf64> + } + func.call @printTileEnd() : () -> () + + return +} + +llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A") +llvm.mlir.global internal constant @str_tile_end("TILE END\0A")