Skip to content

Commit

Permalink
[mlir][ArmSME] Lower transfer_write + transpose to vertical store
Browse files Browse the repository at this point in the history
This patch extends the lowering of vector.transfer_write in
VectorToArmSME to support in-flight transpose via SME vertical store.
  • Loading branch information
c-rhodes committed Oct 16, 2023
1 parent f5bbd5a commit 14aac43
Show file tree
Hide file tree
Showing 3 changed files with 260 additions and 3 deletions.
47 changes: 44 additions & 3 deletions mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,31 @@ struct TransferReadToArmSMELowering

/// Conversion pattern for vector.transfer_write.
///
/// vector.transfer_write %vector, %source[%c0, %c0] : vector<[16]x[16]xi8>,
/// memref<?x?xi8>
/// ---
///
/// Example 1: op with identity permutation map to horizontal
/// arm_sme.tile_store:
///
/// vector.transfer_write %vector, %source[%c0, %c0]
/// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
///
/// is converted to:
///
/// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
/// vector<[16]x[16]xi8>
/// ---
///
/// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
/// (in-flight transpose):
///
/// vector.transfer_write %vector, %source[%c0, %c0]
/// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
/// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
///
/// is converted to:
///
/// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
/// : memref<?x?xi8>, vector<[16]x[16]xi8>
struct TransferWriteToArmSMELowering
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
Expand All @@ -153,12 +171,35 @@ struct TransferWriteToArmSMELowering
if (!arm_sme::isValidSMETileVectorType(vType))
return failure();

assert(writeOp.getTransferRank() == 2 &&
"expected a permutation_map with result dims of the same rank as "
"the vector type");

if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
return failure();

// Out-of-bounds dims are not supported.
if (writeOp.hasOutOfBoundsDim())
return rewriter.notifyMatchFailure(writeOp,
"not inbounds transfer write");

arm_sme::TileSliceLayout layout;

AffineExpr d0, d1;
bindDims(writeOp.getContext(), d0, d1);
AffineMap map = writeOp.getPermutationMap();
if (map.isIdentity())
layout = arm_sme::TileSliceLayout::Horizontal;
else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
writeOp.getContext()))
layout = arm_sme::TileSliceLayout::Vertical;
else
return rewriter.notifyMatchFailure(writeOp,
"unsupported permutation map");

rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
writeOp.getMask());
writeOp.getMask(), layout);
return success();
}
};
Expand Down
42 changes: 42 additions & 0 deletions mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,37 @@ func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest

// -----

/// in-flight transpose via vertical store.

// CHECK-LABEL: func.func @transfer_write_2d_transpose_i64(
// CHECK-SAME: %[[VECTOR:.*]]: vector<[2]x[2]xi64>,
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi64>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
func.func @transfer_write_2d_transpose_i64(%vector : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vector, %dest[%c0, %c0] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[2]x[2]xi64>, memref<?x?xi64>
return
}

// -----

/// in-flight transpose via vertical store with mask.

// CHECK-LABEL: func.func @transfer_write_2d_transpose_with_mask_bf16(
// CHECK-SAME: %[[VECTOR:.*]]: vector<[8]x[8]xbf16>,
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xbf16>,
// CHECK-SAME: %[[MASK:.*]]: vector<[8]x[8]xi1>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], %[[MASK]] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>, %mask : vector<[8]x[8]xi1>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vector, %dest[%c0, %c0], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[8]x[8]xbf16>, memref<?x?xbf16>
return
}

// -----

// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
// lowering only occurs for vector types of correct rank, shape, element size
// and number of scalable dims.
Expand Down Expand Up @@ -398,6 +429,17 @@ func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?
return
}

// -----

// CHECK-LABEL: @transfer_write_2d__out_of_bounds
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vector, %dest[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
return
}

//===----------------------------------------------------------------------===//
// vector.broadcast
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// DEFINE: %{entry_point} = entry
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="mode=locally enable-za" \
// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=void \
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils

// RUN: %{compile} | %{run} | FileCheck %s

llvm.func @printCString(!llvm.ptr<i8>)

// TODO: replace with vector.print <str> once #68695 lands.
func.func @print_str(%str: !llvm.ptr<array<17 x i8>>) attributes { enable_arm_streaming_ignore } {
%c0 = llvm.mlir.constant(0 : index) : i64
%str_bytes = llvm.getelementptr %str[%c0, %c0]
: (!llvm.ptr<array<17 x i8>>, i64, i64) -> !llvm.ptr<i8>
llvm.call @printCString(%str_bytes) : (!llvm.ptr<i8>) -> ()
return
}

// Vector store.
func.func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%c0 = arith.constant 0.0 : f32
%zero = vector.splat %c0 : vector<[4]x[4]xf32>
vector.transfer_write %zero, %A[%base1, %base2] {in_bounds=[true, true]} :
vector<[4]x[4]xf32>, memref<?x?xf32>
return
}

// Masked vector store.
func.func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%c0 = arith.constant 0.0 : f32
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
%zero = vector.splat %c0 : vector<[4]x[4]xf32>
vector.transfer_write %zero, %A[%base1, %base2], %mask {in_bounds=[true, true]} :
vector<[4]x[4]xf32>, memref<?x?xf32>
return
}

// Vector store + transpose.
func.func @transfer_write_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
vector.transfer_write %0, %A[%base1, %base2] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
vector<[4]x[4]xf32>, memref<?x?xf32>
return
}

// Masked vector store + transpose.
func.func @transfer_write_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%mask = vector.create_mask %c4, %c2 : vector<[4]x[4]xi1>
%0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
vector.transfer_write %0, %A[%base1, %base2], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
vector<[4]x[4]xf32>, memref<?x?xf32>
return
}

// Vector load + print.
func.func @load_and_print(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr<array<17 x i8>>

%0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>

func.call @print_str(%tile_begin_str) : (!llvm.ptr<array<17 x i8>>) -> ()
vector.print %0: vector<[4]x[4]xf32>

return
}

// Allocate heap memory of size 'd0' x 'd1' and initialize.
//
// Example:
//
// initialize_memory(%c4, %c5)
//
// 0, 1, 2, 3, 4
// 10, 11, 12, 13, 14
// 20, 21, 22, 23, 24
// 30, 31, 32, 33, 34
//
// Returns dynamic memref. It's the callers responsiblity to free the returned
// memref.
func.func @initialize_memory(%d0 : index, %d1 : index) -> memref<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c1_f32 = arith.constant 1.0 : f32
%c10_f32 = arith.constant 10.0 : f32

%A = memref.alloc(%d0, %d1) : memref<?x?xf32>

%init = arith.constant 0.0 : f32
scf.for %i = %c0 to %d0 step %c1 iter_args(%val = %init) -> f32 {
scf.for %j = %c0 to %d1 step %c1 iter_args(%inner_val = %val) -> f32 {
memref.store %inner_val, %A[%i, %j] : memref<?x?xf32>
%inner_val_next = arith.addf %inner_val, %c1_f32 : f32
scf.yield %inner_val_next : f32
}
%val_next = arith.addf %val, %c10_f32 : f32
scf.yield %val_next : f32
}

return %A : memref<?x?xf32>
}

func.func @entry() {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index

// Allocate enough memory to load a 32-bit tile plus a tiny bit more to test
// non-zero offsets while remaining inbounds.
%vscale = vector.vscale
%svl_s = arith.muli %c4, %vscale : index
%svl_s_plus_two = arith.addi %svl_s, %c2 : index

// 1. Initialize memory
// CHECK-LABEL: TILE BEGIN:
// CHECK-NEXT: ( 0, 1, 2, 3
// CHECK-NEXT: ( 10, 11, 12, 13
// CHECK-NEXT: ( 20, 21, 22, 23
// CHECK-NEXT: ( 30, 31, 32, 33
%A = call @initialize_memory(%svl_s_plus_two, %svl_s_plus_two) : (index, index) -> memref<?x?xf32>
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()

// 2. Write 2-D vector of zeroes to 1. at offset [2, 2].
// CHECK-LABEL: TILE BEGIN:
// CHECK-NEXT: ( 0, 1, 2, 3
// CHECK-NEXT: ( 10, 11, 12, 13
// CHECK-NEXT: ( 20, 21, 0, 0
// CHECK-NEXT: ( 30, 31, 0, 0
call @transfer_write_2d(%A, %c2, %c2) : (memref<?x?xf32>, index, index) -> ()
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()

// 3. Write 2-D vector of zeroes to 2. but with mask (nrows=2, ncols=3).
// CHECK-LABEL: TILE BEGIN:
// CHECK-NEXT: ( 0, 0, 0, 3
// CHECK-NEXT: ( 0, 0, 0, 13
// CHECK-NEXT: ( 20, 21, 0, 0
// CHECK-NEXT: ( 30, 31, 0, 0
call @transfer_write_2d_mask(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()

// 4. Reload 3. + store + transpose.
// CHECK-LABEL: TILE BEGIN:
// CHECK-NEXT: ( 0, 0, 20, 30
// CHECK-NEXT: ( 0, 0, 21, 31
// CHECK-NEXT: ( 0, 0, 0, 0
// CHECK-NEXT: ( 3, 13, 0, 0
call @transfer_write_2d_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()

// 5. Reload 4. + store + transpose but with mask (nrows=4, ncols=2).
// The mask applies after permutation
// CHECK-LABEL: TILE BEGIN:
// CHECK-NEXT: ( 0, 0, 20, 30
// CHECK-NEXT: ( 0, 0, 21, 31
// CHECK-NEXT: ( 20, 21, 0, 0
// CHECK-NEXT: ( 30, 31, 0, 0
call @transfer_write_2d_mask_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()

memref.dealloc %A : memref<?x?xf32>

return
}

llvm.mlir.global internal constant @tile_begin("TILE BEGIN: \0A\00")

0 comments on commit 14aac43

Please sign in to comment.