[mlir][ArmSME] Propagate pad and mask in vector.transfer_read lowering (llvm#70814)

This extends the vector.transfer_read -> arm_sme.tile_load lowering to propagate the pad and mask operands.

The restriction that the transfer_read must be a transposition is also removed; identity maps are now lowered to normal horizontal tile loads.
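
For illustration, a masked read of an SME-tile-sized vector now lowers directly to a masked tile load, with the pad and mask operands forwarded. This is a rough sketch based on the updated tests below (SSA value names are illustrative; the exact output is what the FileCheck patterns verify):

  // Input: masked 2-D read of a scalable SME tile.
  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>

  // After -convert-vector-to-arm-sme: pad and mask are propagated to the tile load.
  %0 = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi16>, vector<[8]x[8]xi16>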
c-rhodes authored Nov 2, 2023
1 parent c1b55ae commit 22f1159
Showing 2 changed files with 109 additions and 88 deletions.
57 changes: 35 additions & 22 deletions mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -60,15 +60,30 @@ getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc,
 
 namespace {
 
-/// Conversion pattern for vector.transfer_read op with transpose permutation
-/// map to vertical arm_sme.tile_load (in-flight transpose).
+/// Conversion pattern for vector.transfer_read.
+///
+/// ---
+///
+/// Example 1: op with identity permutation map to horizontal
+/// arm_sme.tile_load:
+///
+///   vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1)
+///
+/// is converted to:
+///
+///   arm_sme.tile_load ...
+///
+/// ---
+///
+/// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
+/// (in-flight transpose):
 ///
 ///   vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0)
 ///
 /// is converted to:
 ///
 ///   arm_sme.tile_load ... layout<vertical>
-struct TransferReadPermutationToArmSMELowering
+struct TransferReadToArmSMELowering
     : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
 
@@ -79,15 +94,6 @@ struct TransferReadPermutationToArmSMELowering
       return rewriter.notifyMatchFailure(transferReadOp,
                                          "not a 2 result permutation map");
 
-    AffineMap map = transferReadOp.getPermutationMap();
-
-    // Permutation map doesn't perform permutation, can be lowered to
-    // vector.load by TransferReadToVectorLoadLowering and then
-    // arm_sme.tile_load by VectorLoadToArmSMELowering.
-    if (map.isIdentity())
-      return rewriter.notifyMatchFailure(
-          transferReadOp, "map is an identity, apply another pattern");
-
     auto vectorType = transferReadOp.getVectorType();
     if (!arm_sme::isValidSMETileVectorType(vectorType))
       return rewriter.notifyMatchFailure(transferReadOp,
@@ -96,26 +102,33 @@
     if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
       return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
 
-    if (transferReadOp.getMask())
-      // TODO: support masking.
-      return rewriter.notifyMatchFailure(transferReadOp,
-                                         "masking not yet supported");
-
     // Out-of-bounds dims are not supported.
     if (transferReadOp.hasOutOfBoundsDim())
       return rewriter.notifyMatchFailure(transferReadOp,
                                          "not inbounds transfer read");
 
+    arm_sme::TileSliceLayout layout;
+
     AffineExpr d0, d1;
     bindDims(transferReadOp.getContext(), d0, d1);
-    if (map != AffineMap::get(map.getNumDims(), 0, {d1, d0},
-                              transferReadOp.getContext()))
+    AffineMap map = transferReadOp.getPermutationMap();
+    if (map.isIdentity())
+      layout = arm_sme::TileSliceLayout::Horizontal;
+    else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
+                                   transferReadOp.getContext()))
+      layout = arm_sme::TileSliceLayout::Vertical;
+    else
       return rewriter.notifyMatchFailure(transferReadOp,
-                                         "not true 2-D matrix transpose");
+                                         "unsupported permutation map");
 
+    // Padding isn't optional for transfer_read, but is only used in the case
+    // of out-of-bounds accesses (not supported here) and/or masking. Mask is
+    // optional, if it's not present don't pass padding.
+    auto mask = transferReadOp.getMask();
+    auto padding = mask ? transferReadOp.getPadding() : nullptr;
     rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
         transferReadOp, vectorType, transferReadOp.getSource(),
-        transferReadOp.getIndices(), arm_sme::TileSliceLayout::Vertical);
+        transferReadOp.getIndices(), padding, mask, layout);
 
     return success();
   }
@@ -531,7 +544,7 @@ struct VectorOuterProductToArmSMELowering
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
   patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
-               SplatOpToArmSMELowering, TransferReadPermutationToArmSMELowering,
+               SplatOpToArmSMELowering, TransferReadToArmSMELowering,
                TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
                VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
                VectorOuterProductToArmSMELowering>(&ctx);
140 changes: 74 additions & 66 deletions mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -1,181 +1,189 @@
 // RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
 
 //===----------------------------------------------------------------------===//
-// vector.transfer_read (with in-flight transpose)
+// vector.transfer_read
 //===----------------------------------------------------------------------===//
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i8
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
-func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) {
+// CHECK-LABEL: @transfer_read_2d_i8
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @transfer_read_2d_i8(%src : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i8
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
   "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i16
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
-func.func @transfer_read_2d_transpose_i16(%src : memref<?x?xi16>) {
+// CHECK-LABEL: @transfer_read_2d_i16
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @transfer_read_2d_i16(%src : memref<?x?xi16>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i16
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
   "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i32
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
-func.func @transfer_read_2d_transpose_i32(%src : memref<?x?xi32>) {
+// CHECK-LABEL: @transfer_read_2d_i32
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @transfer_read_2d_i32(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i32
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
   "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i64
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
-func.func @transfer_read_2d_transpose_i64(%src : memref<?x?xi64>) {
+// CHECK-LABEL: @transfer_read_2d_i64
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @transfer_read_2d_i64(%src : memref<?x?xi64>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi64>, vector<[2]x[2]xi64>
   "prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i128
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
-func.func @transfer_read_2d_transpose_i128(%src : memref<?x?xi128>) {
+// CHECK-LABEL: @transfer_read_2d_i128
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @transfer_read_2d_i128(%src : memref<?x?xi128>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i128
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi128>, vector<[1]x[1]xi128>
   "prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_f16
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
-func.func @transfer_read_2d_transpose_f16(%src : memref<?x?xf16>) {
+// CHECK-LABEL: @transfer_read_2d_f16
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @transfer_read_2d_f16(%src : memref<?x?xf16>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f16
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf16>, vector<[8]x[8]xf16>
   "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_bf16
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
-func.func @transfer_read_2d_transpose_bf16(%src : memref<?x?xbf16>) {
+// CHECK-LABEL: @transfer_read_2d_bf16
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @transfer_read_2d_bf16(%src : memref<?x?xbf16>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : bf16
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_f32
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
-func.func @transfer_read_2d_transpose_f32(%src : memref<?x?xf32>) {
+// CHECK-LABEL: @transfer_read_2d_f32
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @transfer_read_2d_f32(%src : memref<?x?xf32>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f32
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_f64
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
-func.func @transfer_read_2d_transpose_f64(%src : memref<?x?xf64>) {
+// CHECK-LABEL: @transfer_read_2d_f64
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @transfer_read_2d_f64(%src : memref<?x?xf64>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d__bad_type
-// CHECK-NOT: arm_sme.tile_load
-// CHECK: vector.transfer_read
-func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
+// CHECK-LABEL: @transfer_read_2d_with_mask_i16
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @transfer_read_2d_with_mask_i16(%src : memref<?x?xi16>, %mask : vector<[8]x[8]xi1>) {
   %c0 = arith.constant 0 : index
-  %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[4]x[4]xf64>
-  "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
+  %pad = arith.constant 0 : i16
+  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d__non_memref_type
-// CHECK-NOT: arm_sme.tile_load
-// CHECK: vector.transfer_read
-func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
+/// in-flight transpose
+
+// CHECK-LABEL: @transfer_read_2d_transpose_i8
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
-  %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64>
-  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  %pad = arith.constant 0 : i8
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
+// CHECK-LABEL: @transfer_read_2d_transpose_with_mask_f32
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref<?x?xf32>, %mask : vector<[4]x[4]xi1>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__bad_type
 // CHECK-NOT: arm_sme.tile_load
 // CHECK: vector.transfer_read
-func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
+func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64>
-  "prevent.dce"(%0) : (vector<[2]xf64>) -> ()
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[4]x[4]xf64>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d__unsupported_mask
+// CHECK-LABEL: @transfer_read_2d__non_memref_type
 // CHECK-NOT: arm_sme.tile_load
 // CHECK: vector.transfer_read
-func.func @transfer_read_2d__unsupported_mask(%src : memref<?x?xf64>, %mask : vector<[2]x[2]xi1>) {
+func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
   return
 }
 
 // -----
 
-/// transfer_read with identity map should be lowered to vector.load by
-/// TransferReadToVectorLoadLowering and then arm_sme.tile_load by
-/// VectorLoadToArmSMELowering.
-
-// CHECK-LABEL: @transfer_read_2d__non_permuting_map
+// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
 // CHECK-NOT: arm_sme.tile_load
 // CHECK: vector.transfer_read
-func.func @transfer_read_2d__non_permuting_map(%src : memref<?x?xf64>) {
+func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0, d1)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
-  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]xf64>) -> ()
   return
 }
 
