diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp index 005dd546bf1632..5491f7dd30629a 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -60,15 +60,30 @@ getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc, namespace { -/// Conversion pattern for vector.transfer_read op with transpose permutation -/// map to vertical arm_sme.tile_load (in-flight transpose). +/// Conversion pattern for vector.transfer_read. +/// +/// --- +/// +/// Example 1: op with identity permutation map to horizontal +/// arm_sme.tile_load: +/// +/// vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1) +/// +/// is converted to: +/// +/// arm_sme.tile_load ... +/// +/// --- +/// +/// Example 2: op with transpose permutation map to vertical arm_sme.tile_load +/// (in-flight transpose): /// /// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0) /// /// is converted to: /// /// arm_sme.tile_load ... layout -struct TransferReadPermutationToArmSMELowering +struct TransferReadToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -79,15 +94,6 @@ struct TransferReadPermutationToArmSMELowering return rewriter.notifyMatchFailure(transferReadOp, "not a 2 result permutation map"); - AffineMap map = transferReadOp.getPermutationMap(); - - // Permutation map doesn't perform permutation, can be lowered to - // vector.load by TransferReadToVectorLoadLowering and then - // arm_sme.tile_load by VectorLoadToArmSMELowering. - if (map.isIdentity()) - return rewriter.notifyMatchFailure( - transferReadOp, "map is an identity, apply another pattern"); - auto vectorType = transferReadOp.getVectorType(); if (!arm_sme::isValidSMETileVectorType(vectorType)) return rewriter.notifyMatchFailure(transferReadOp, @@ -96,26 +102,33 @@ struct TransferReadPermutationToArmSMELowering if (!llvm::isa(transferReadOp.getSource().getType())) return rewriter.notifyMatchFailure(transferReadOp, "not a memref source"); - if (transferReadOp.getMask()) - // TODO: support masking. - return rewriter.notifyMatchFailure(transferReadOp, - "masking not yet supported"); - // Out-of-bounds dims are not supported. if (transferReadOp.hasOutOfBoundsDim()) return rewriter.notifyMatchFailure(transferReadOp, "not inbounds transfer read"); + arm_sme::TileSliceLayout layout; + AffineExpr d0, d1; bindDims(transferReadOp.getContext(), d0, d1); - if (map != AffineMap::get(map.getNumDims(), 0, {d1, d0}, - transferReadOp.getContext())) + AffineMap map = transferReadOp.getPermutationMap(); + if (map.isIdentity()) + layout = arm_sme::TileSliceLayout::Horizontal; + else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0}, + transferReadOp.getContext())) + layout = arm_sme::TileSliceLayout::Vertical; + else return rewriter.notifyMatchFailure(transferReadOp, - "not true 2-D matrix transpose"); + "unsupported permutation map"); + // Padding isn't optional for transfer_read, but is only used in the case + // of out-of-bounds accesses (not supported here) and/or masking. Mask is + // optional, if it's not present don't pass padding. + auto mask = transferReadOp.getMask(); + auto padding = mask ? transferReadOp.getPadding() : nullptr; rewriter.replaceOpWithNewOp( transferReadOp, vectorType, transferReadOp.getSource(), - transferReadOp.getIndices(), arm_sme::TileSliceLayout::Vertical); + transferReadOp.getIndices(), padding, mask, layout); return success(); } @@ -531,7 +544,7 @@ struct VectorOuterProductToArmSMELowering void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx) { patterns.add(&ctx); diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir index 5f41313fc6ac78..ed33f8508dba0b 100644 --- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir +++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir @@ -1,181 +1,189 @@ // RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s //===----------------------------------------------------------------------===// -// vector.transfer_read (with in-flight transpose) +// vector.transfer_read //===----------------------------------------------------------------------===// -// CHECK-LABEL: @transfer_read_2d_transpose_i8 -// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[16]x[16]xi8> -func.func @transfer_read_2d_transpose_i8(%src : memref) { +// CHECK-LABEL: @transfer_read_2d_i8 +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref, vector<[16]x[16]xi8> +func.func @transfer_read_2d_i8(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0 : i8 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[16]x[16]xi8> + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref, vector<[16]x[16]xi8> "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d_transpose_i16 -// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[8]x[8]xi16> -func.func @transfer_read_2d_transpose_i16(%src : memref) { +// CHECK-LABEL: @transfer_read_2d_i16 +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref, vector<[8]x[8]xi16> +func.func @transfer_read_2d_i16(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0 : i16 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[8]x[8]xi16> + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref, vector<[8]x[8]xi16> "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d_transpose_i32 -// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[4]x[4]xi32> -func.func @transfer_read_2d_transpose_i32(%src : memref) { +// CHECK-LABEL: @transfer_read_2d_i32 +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref, vector<[4]x[4]xi32> +func.func @transfer_read_2d_i32(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0 : i32 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[4]x[4]xi32> + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref, vector<[4]x[4]xi32> "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d_transpose_i64 -// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[2]x[2]xi64> -func.func @transfer_read_2d_transpose_i64(%src : memref) { +// CHECK-LABEL: @transfer_read_2d_i64 +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref, vector<[2]x[2]xi64> +func.func @transfer_read_2d_i64(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0 : i64 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[2]x[2]xi64> + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref, vector<[2]x[2]xi64> "prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d_transpose_i128 -// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[1]x[1]xi128> -func.func @transfer_read_2d_transpose_i128(%src : memref) { +// CHECK-LABEL: @transfer_read_2d_i128 +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref, vector<[1]x[1]xi128> +func.func @transfer_read_2d_i128(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0 : i128 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[1]x[1]xi128> + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref, vector<[1]x[1]xi128> "prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d_transpose_f16 -// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[8]x[8]xf16> -func.func @transfer_read_2d_transpose_f16(%src : memref) { +// CHECK-LABEL: @transfer_read_2d_f16 +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref, vector<[8]x[8]xf16> +func.func @transfer_read_2d_f16(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : f16 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[8]x[8]xf16> + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref, vector<[8]x[8]xf16> "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d_transpose_bf16 -// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[8]x[8]xbf16> -func.func @transfer_read_2d_transpose_bf16(%src : memref) { +// CHECK-LABEL: @transfer_read_2d_bf16 +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref, vector<[8]x[8]xbf16> +func.func @transfer_read_2d_bf16(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : bf16 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[8]x[8]xbf16> + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref, vector<[8]x[8]xbf16> "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d_transpose_f32 -// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[4]x[4]xf32> -func.func @transfer_read_2d_transpose_f32(%src : memref) { +// CHECK-LABEL: @transfer_read_2d_f32 +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref, vector<[4]x[4]xf32> +func.func @transfer_read_2d_f32(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : f32 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[4]x[4]xf32> + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref, vector<[4]x[4]xf32> "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d_transpose_f64 -// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[2]x[2]xf64> -func.func @transfer_read_2d_transpose_f64(%src : memref) { +// CHECK-LABEL: @transfer_read_2d_f64 +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref, vector<[2]x[2]xf64> +func.func @transfer_read_2d_f64(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : f64 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[2]x[2]xf64> + %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref, vector<[2]x[2]xf64> "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d__bad_type -// CHECK-NOT: arm_sme.tile_load -// CHECK: vector.transfer_read -func.func @transfer_read_2d__bad_type(%src : memref) { +// CHECK-LABEL: @transfer_read_2d_with_mask_i16 +// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref, vector<[8]x[8]xi16> +func.func @transfer_read_2d_with_mask_i16(%src : memref, %mask : vector<[8]x[8]xi1>) { %c0 = arith.constant 0 : index - %pad = arith.constant 0.0 : f64 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref, vector<[4]x[4]xf64> - "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> () + %pad = arith.constant 0 : i16 + %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref, vector<[8]x[8]xi16> + "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d__non_memref_type -// CHECK-NOT: arm_sme.tile_load -// CHECK: vector.transfer_read -func.func @transfer_read_2d__non_memref_type(%src : tensor) { +/// in-flight transpose + +// CHECK-LABEL: @transfer_read_2d_transpose_i8 +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[16]x[16]xi8> +func.func @transfer_read_2d_transpose_i8(%src : memref) { %c0 = arith.constant 0 : index - %pad = arith.constant 0.0 : f64 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor, vector<[2]x[2]xf64> - "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () + %pad = arith.constant 0 : i8 + %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[16]x[16]xi8> + "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank +// CHECK-LABEL: @transfer_read_2d_transpose_with_mask_f32 +// CHECK: arm_sme.tile_load {{.*}} layout : memref, vector<[4]x[4]xf32> +func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref, %mask : vector<[4]x[4]xi1>) { + %c0 = arith.constant 0 : index + %pad = arith.constant 0.0 : f32 + %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[4]x[4]xf32> + "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: @transfer_read_2d__bad_type // CHECK-NOT: arm_sme.tile_load // CHECK: vector.transfer_read -func.func @transfer_read_2d__bad_transfer_rank(%src : memref) { +func.func @transfer_read_2d__bad_type(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : f64 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref, vector<[2]xf64> - "prevent.dce"(%0) : (vector<[2]xf64>) -> () + %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref, vector<[4]x[4]xf64> + "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> () return } // ----- -// CHECK-LABEL: @transfer_read_2d__unsupported_mask +// CHECK-LABEL: @transfer_read_2d__non_memref_type // CHECK-NOT: arm_sme.tile_load // CHECK: vector.transfer_read -func.func @transfer_read_2d__unsupported_mask(%src : memref, %mask : vector<[2]x[2]xi1>) { +func.func @transfer_read_2d__non_memref_type(%src : tensor) { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : f64 - %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref, vector<[2]x[2]xf64> + %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor, vector<[2]x[2]xf64> "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () return } // ----- -/// transfer_read with identity map should be lowered to vector.load by -/// TransferReadToVectorLoadLowering and then arm_sme.tile_load by -/// VectorLoadToArmSMELowering. - -// CHECK-LABEL: @transfer_read_2d__non_permuting_map +// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank // CHECK-NOT: arm_sme.tile_load // CHECK: vector.transfer_read -func.func @transfer_read_2d__non_permuting_map(%src : memref) { +func.func @transfer_read_2d__bad_transfer_rank(%src : memref) { %c0 = arith.constant 0 : index %pad = arith.constant 0.0 : f64 - %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0, d1)>, in_bounds = [true, true]} : memref, vector<[2]x[2]xf64> - "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () + %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref, vector<[2]xf64> + "prevent.dce"(%0) : (vector<[2]xf64>) -> () return }