Skip to content

Commit

Permalink
Merge pull request #2 from FloatingcloudKnight/vector
Browse files Browse the repository at this point in the history
[Vectorization/pooling] update Pooling Nhwc Max Vectorization pass and relevant case
  • Loading branch information
FloatingcloudKnight authored Aug 28, 2024
2 parents f7fd94b + 76706e7 commit 69600d8
Show file tree
Hide file tree
Showing 5 changed files with 333 additions and 181 deletions.
110 changes: 70 additions & 40 deletions examples/MLIRLinalg/linalg-pooling-nhwc-max.mlir
Original file line number Diff line number Diff line change
@@ -1,50 +1,80 @@
// RUN: buddy-opt %s \
// RUN: -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \
// RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \
// RUN: -convert-func-to-llvm -reconcile-unrealized-casts \
// RUN: -convert-linalg-to-loops \
// RUN: -lower-affine \
// RUN: -convert-scf-to-cf \
// RUN: -convert-vector-to-llvm \
// RUN: -finalize-memref-to-llvm \
// RUN: -convert-arith-to-llvm \
// RUN: -convert-func-to-llvm \
// RUN: -reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

#map0 = affine_map<(d0, d1) -> (d0 + d1 - 1)>

module{
memref.global "private" @input : memref<1x4x4x1xf32> =
dense<[[[[1.], [1.], [1.], [1.]],
[[1.], [1.], [1.], [1.]],
[[1.], [1.], [1.], [1.]],
[[1.], [1.], [1.], [1.]]]]>
memref.global "private" @kernel : memref<2x2xf32> = dense<0.0>
memref.global "private" @output : memref<1x3x3x1xf32> = dense<0.0>

func.func private @printMemrefF32(memref<*xf32>)

func.func @pooling_nhwc_max(%a : memref<?x?x?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?x?x?xf32>) {
linalg.pooling_nhwc_max
ins(%a, %b : memref<?x?x?x?xf32>, memref<?x?xf32>)
outs(%c : memref<?x?x?x?xf32>)
return
}
memref.global "private" @kernel : memref<3x3xf32> = dense<0.0>

func.func private @rtclock() -> f64
func.func private @printMemrefF32(memref<*xf32>)

func.func @main(){
%input = memref.get_global @input : memref<1x4x4x1xf32>
%kernel = memref.get_global @kernel : memref<2x2xf32>
%output = memref.get_global @output : memref<1x3x3x1xf32>

%a = memref.cast %input : memref<1x4x4x1xf32> to memref<?x?x?x?xf32>
%b = memref.cast %kernel : memref<2x2xf32> to memref<?x?xf32>
%c = memref.cast %output : memref<1x3x3x1xf32> to memref<?x?x?x?xf32>

call @pooling_nhwc_max(%a, %b, %c) : (memref<?x?x?x?xf32>, memref<?x?xf32>, memref<?x?x?x?xf32>) -> ()
// All the elements of the MemRef are the same,
// only check the first line to verify the correctness.
// CHECK: Unranked Memref
// CHECK: [
// CHECK: [
// CHECK: [
// CHECK: [1],
%print_c = memref.cast %c : memref<?x?x?x?xf32> to memref<*xf32>
call @printMemrefF32(%print_c) : (memref<*xf32>) -> ()

return
func.func @pooling_nhwc_max(%a : memref<?x?x?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?x?x?xf32>) {
linalg.pooling_nhwc_max
ins(%a, %b : memref<?x?x?x?xf32>, memref<?x?xf32>)
outs(%c : memref<?x?x?x?xf32>)
return
}

func.func @alloc_f32(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: f32) -> memref<?x?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
scf.for %arg5 = %c0 to %arg0 step %c1 {
scf.for %arg6 = %c0 to %arg1 step %c1 {
scf.for %arg7 = %c0 to %arg2 step %c1 {
scf.for %arg8 = %c0 to %arg3 step %c1 {
memref.store %arg4, %0[%arg5, %arg6, %arg7, %arg8] : memref<?x?x?x?xf32>
}
}
}
}
return %0 : memref<?x?x?x?xf32>
}

func.func @main(){
%N = arith.constant 1 : index
%current_v1 = arith.constant 3 : index
%current_v2 = arith.constant 126 : index
%current_v0 = affine.apply #map0(%current_v2, %current_v1)
%c0 = arith.constant 0.000000e+00 : f32
%c1 = arith.constant 1.000000e+00 : f32
%kernel = memref.get_global @kernel : memref<3x3xf32>

%a = call @alloc_f32(%N, %current_v0, %current_v0, %N, %c1) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>
%b = memref.cast %kernel : memref<3x3xf32> to memref<?x?xf32>
%c = call @alloc_f32(%N, %current_v2, %current_v2, %N, %c0) : (index, index, index, index, f32) -> memref<?x?x?x?xf32>

%t0 = call @rtclock() : () -> f64
call @pooling_nhwc_max(%a, %b, %c) : (memref<?x?x?x?xf32>, memref<?x?xf32>, memref<?x?x?x?xf32>) -> ()
%t1 = call @rtclock() : () -> f64
// All the elements of the MemRef are the same,
// only check the first line to verify the correctness.
// CHECK: Unranked Memref
// CHECK: [
// CHECK: [
// CHECK: [
// CHECK: [1],
%print_c = memref.cast %c : memref<?x?x?x?xf32> to memref<*xf32>
call @printMemrefF32(%print_c) : (memref<*xf32>) -> ()

%time = arith.subf %t1, %t0 : f64
vector.print %time : f64

memref.dealloc %a : memref<?x?x?x?xf32>
memref.dealloc %c : memref<?x?x?x?xf32>

return
}
}
21 changes: 14 additions & 7 deletions examples/MLIRLinalg/makefile
Original file line number Diff line number Diff line change
Expand Up @@ -418,14 +418,21 @@ linalg-matmul-vectorization-lower:
-matmul-vectorization \
-o log.mlir

linalg-pooling-nhwc-max-lower:
@${BUDDY_OPT} linalg-pooling-nhwc-max.mlir ${MLIR_OPT_OPTIONS} \
linalg-pooling-nhwc-max-vectorization-lower:
@${BUDDY_OPT} linalg-pooling-nhwc-max.mlir \
-pooling-nhwc-max-vectorization \
-o log.mlir

linalg-pooling-nhwc-max-run:
@${MLIR_OPT} linalg-pooling-nhwc-max.mlir ${MLIR_OPT_OPTIONS} \
-convert-linalg-to-loops -lower-affine -convert-scf-to-cf \
-convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \
-convert-func-to-llvm -reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
@${MLIR_OPT} linalg-pooling-nhwc-max.mlir \
-convert-linalg-to-loops \
-lower-affine -convert-scf-to-cf \
-convert-vector-to-llvm \
-finalize-memref-to-llvm \
-convert-arith-to-llvm \
-convert-func-to-llvm \
-reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main \
-entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} \
-shared-libs=${MLIR_C_RUNNER_UTILS}
34 changes: 25 additions & 9 deletions examples/MLIRVector/makefile
Original file line number Diff line number Diff line change
Expand Up @@ -729,12 +729,28 @@ vector-iteration-run:

vector-pooling-nhwc-max-run:
@${MLIR_OPT} ./vector-pooling-nhwc-max.mlir \
-convert-vector-to-scf \
-lower-affine \
-convert-scf-to-cf \
-convert-vector-to-llvm \
-finalize-memref-to-llvm \
-convert-func-to-llvm \
-reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
--lower-affine \
--convert-scf-to-cf \
--convert-vector-to-llvm \
--finalize-memref-to-llvm \
--convert-arith-to-llvm \
--convert-func-to-llvm \
--reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main \
-entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} \
-shared-libs=${MLIR_C_RUNNER_UTILS}

vector-pooling-nhwc-max-test:
@${MLIR_OPT} ./vector-pooling-nhwc-max-1.mlir \
--lower-affine \
--convert-scf-to-cf \
--convert-vector-to-llvm \
--finalize-memref-to-llvm \
--convert-arith-to-llvm \
--convert-func-to-llvm \
--reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main \
-entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} \
-shared-libs=${MLIR_C_RUNNER_UTILS}
181 changes: 108 additions & 73 deletions examples/MLIRVector/vector-pooling-nhwc-max.mlir
Original file line number Diff line number Diff line change
@@ -1,93 +1,128 @@
// RUN: buddy-opt %s \
// RUN: -convert-vector-to-scf -lower-affine -convert-scf-to-cf \
// RUN: -convert-vector-to-llvm -finalize-memref-to-llvm \
// RUN: -convert-func-to-llvm -reconcile-unrealized-casts \
// RUN: -convert-vector-to-scf \
// RUN: -lower-affine \
// RUN: -convert-scf-to-cf \
// RUN: -convert-vector-to-llvm \
// RUN: -finalize-memref-to-llvm \
// RUN: -llvm-request-c-wrappers \
// RUN: -convert-func-to-llvm \
// RUN: -reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

#map = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0) -> (d0 ceildiv 32)>
#map2 = affine_map<(d0, d1) -> (d0 + d1 - 1)>

module{
memref.global "private" @input : memref<1x4x4x1xf32> =
dense<[[[[1.], [1.], [1.], [1.]],
[[1.], [1.], [1.], [1.]],
[[1.], [1.], [1.], [1.]],
[[1.], [1.], [1.], [1.]]]]>
memref.global "private" @kernel : memref<2x2xf32> = dense<0.0>
memref.global "private" @output : memref<1x3x3x1xf32> = dense<0.0>
module {
memref.global "private" @kernel : memref<3x3xi32> = dense<0>

func.func private @printMemrefF32(memref<*xf32>)

func.func @pooling_nhwc_max(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c32 = arith.constant 32 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = vector.splat %cst : vector<32xf32>
%dim = memref.dim %arg1, %c0 : memref<?x?xf32>
%dim_0 = memref.dim %arg1, %c1 : memref<?x?xf32>
%dim_1 = memref.dim %arg2, %c0 : memref<?x?x?x?xf32>
%dim_2 = memref.dim %arg2, %c1 : memref<?x?x?x?xf32>
%dim_3 = memref.dim %arg2, %c2 : memref<?x?x?x?xf32>
%dim_4 = memref.dim %arg2, %c3 : memref<?x?x?x?xf32>
affine.for %tmp0 = #map(%c0) to #map(%dim_1) {
affine.for %tmp1 = #map(%c0) to #map(%dim_4) {
affine.for %tmp2 = #map(%c0) to #map(%dim_2) {
affine.for %tmp3 = #map(%c0) to #map(%dim) {
affine.for %tmp4 = #map(%c0) to #map(%dim_0) {
affine.for %tmp5 = #map(%c0) to #map1(%dim_3) {
%1 = arith.muli %tmp5, %c32 : index
%2 = arith.subi %dim_3, %1 : index
%3 = arith.cmpi sge, %2, %c32 : index
scf.if %3 {
%4 = affine.vector_load %arg0[%tmp0, %tmp2 + %tmp3, %tmp4 + %tmp5 * 32, %tmp1] : memref<?x?x?x?xf32>, vector<32xf32>
%5 = affine.vector_load %arg2[%tmp0, %tmp2, %tmp5 * 32, %tmp1] : memref<?x?x?x?xf32>, vector<32xf32>
%6 = arith.maximumf %4, %5 : vector<32xf32>
affine.vector_store %6, %arg2[%tmp0, %tmp2, %tmp5 * 32, %tmp1] : memref<?x?x?x?xf32>, vector<32xf32>
} else {
%7 = vector.create_mask %2 : vector<32xi1>
%8 = arith.addi %tmp2, %tmp3 : index
%9 = arith.muli %tmp5, %c32 : index
%10 = arith.addi %tmp4, %9 : index
%11 = vector.maskedload %arg0[%tmp0, %8, %10, %tmp1], %7, %0 : memref<?x?x?x?xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32>
%12 = vector.maskedload %arg2[%tmp0, %tmp2, %9, %tmp1], %7, %0 : memref<?x?x?x?xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32>
%13 = arith.maximumf %11, %12 : vector<32xf32>
vector.maskedstore %arg2[%tmp0, %tmp2, %9, %tmp1], %7, %13 : memref<?x?x?x?xf32>, vector<32xi1>, vector<32xf32>
}
func.func private @rtclock() -> f64
func.func private @printMemrefI32(memref<*xi32>)
func.func @pooling_nhwc_max(%arg0: memref<?x?x?x?xi32>, %arg1: memref<?x?xi32>, %arg2: memref<?x?x?x?xi32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c32 = arith.constant 32 : index
%c0_i32 = arith.constant 0 : i32
%0 = vector.splat %c0_i32 : vector<32xi32>
%dim = memref.dim %arg1, %c0 : memref<?x?xi32>
%dim_0 = memref.dim %arg1, %c1 : memref<?x?xi32>
%dim_1 = memref.dim %arg2, %c0 : memref<?x?x?x?xi32>
%dim_2 = memref.dim %arg2, %c1 : memref<?x?x?x?xi32>
%dim_3 = memref.dim %arg2, %c2 : memref<?x?x?x?xi32>
%dim_4 = memref.dim %arg2, %c3 : memref<?x?x?x?xi32>
affine.for %arg3 = #map(%c0) to #map(%dim_1) {
affine.for %arg4 = #map(%c0) to #map(%dim_4) {
affine.for %arg5 = #map(%c0) to #map(%dim_2) {
affine.for %arg6 = #map(%c0) to #map1(%dim_3) {
%1 = arith.muli %arg6, %c32 : index
%2 = arith.subi %dim_3, %1 : index
%3 = arith.cmpi sge, %2, %c32 : index
scf.if %3 {
%4 = affine.vector_load %arg2[%arg3, %arg5, %arg6 * 32, %arg4] : memref<?x?x?x?xi32>, vector<32xi32>
%5 = affine.for %arg7 = #map(%c0) to #map(%dim) iter_args(%arg8 = %4) -> (vector<32xi32>) {
%6 = affine.for %arg9 = #map(%c0) to #map(%dim_0) iter_args(%arg10 = %arg8) -> (vector<32xi32>) {
%7 = affine.vector_load %arg0[%arg3, %arg7 + %arg5, %arg9 + %arg6 * 32, %arg4] : memref<?x?x?x?xi32>, vector<32xi32>
%8 = arith.maxsi %7, %arg10 : vector<32xi32>
affine.yield %8 : vector<32xi32>
}
affine.yield %6 : vector<32xi32>
}
affine.vector_store %5, %arg2[%arg3, %arg5, %arg6 * 32, %arg4] : memref<?x?x?x?xi32>, vector<32xi32>
} else {
%4 = vector.create_mask %2 : vector<32xi1>
%5 = vector.maskedload %arg2[%arg3, %arg5, %1, %arg4], %4, %0 : memref<?x?x?x?xi32>, vector<32xi1>, vector<32xi32> into vector<32xi32>
%6 = affine.for %arg7 = #map(%c0) to #map(%dim) iter_args(%arg8 = %5) -> (vector<32xi32>) {
%7 = affine.for %arg9 = #map(%c0) to #map(%dim_0) iter_args(%arg10 = %arg8) -> (vector<32xi32>) {
%8 = arith.addi %arg5, %arg7 : index
%9 = arith.addi %arg9, %1 : index
%10 = vector.maskedload %arg0[%arg3, %8, %9, %arg4], %4, %0 : memref<?x?x?x?xi32>, vector<32xi1>, vector<32xi32> into vector<32xi32>
%11 = arith.maxsi %10, %arg10 : vector<32xi32>
affine.yield %11 : vector<32xi32>
}
affine.yield %7 : vector<32xi32>
}
vector.maskedstore %arg2[%arg3, %arg5, %1, %arg4], %4, %6 : memref<?x?x?x?xi32>, vector<32xi1>, vector<32xi32>
}
}
}
}
return
}
return
}

func.func @main(){
%input = memref.get_global @input : memref<1x4x4x1xf32>
%kernel = memref.get_global @kernel : memref<2x2xf32>
%output = memref.get_global @output : memref<1x3x3x1xf32>
func.func @alloc_i32(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: i32) -> memref<?x?x?x?xi32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xi32>
scf.for %arg5 = %c0 to %arg0 step %c1 {
scf.for %arg6 = %c0 to %arg1 step %c1 {
scf.for %arg7 = %c0 to %arg2 step %c1 {
scf.for %arg8 = %c0 to %arg3 step %c1 {
memref.store %arg4, %0[%arg5, %arg6, %arg7, %arg8] : memref<?x?x?x?xi32>
}
}
}
}
return %0 : memref<?x?x?x?xi32>
}

%a = memref.cast %input : memref<1x4x4x1xf32> to memref<?x?x?x?xf32>
%b = memref.cast %kernel : memref<2x2xf32> to memref<?x?xf32>
%c = memref.cast %output : memref<1x3x3x1xf32> to memref<?x?x?x?xf32>
func.func @main() {
%N = arith.constant 1 : index
%current_v1 = arith.constant 3 : index
%current_v2 = arith.constant 126 : index
%current_v0 = affine.apply #map2(%current_v2, %current_v1)
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%kernel = memref.get_global @kernel : memref<3x3xi32>

call @pooling_nhwc_max(%a, %b, %c) : (memref<?x?x?x?xf32>, memref<?x?xf32>, memref<?x?x?x?xf32>) -> ()
// All the elements of the MemRef are the same,
// only check the first line to verify the correctness.
// CHECK: Unranked Memref
// CHECK: [
// CHECK: [
// CHECK: [
// CHECK: [1],
%print_c = memref.cast %c : memref<?x?x?x?xf32> to memref<*xf32>
call @printMemrefF32(%print_c) : (memref<*xf32>) -> ()
%a = call @alloc_i32(%N, %current_v0, %current_v0, %N, %c1) : (index, index, index, index, i32) -> memref<?x?x?x?xi32>
%b = memref.cast %kernel : memref<3x3xi32> to memref<?x?xi32>
%c = call @alloc_i32(%N, %current_v2, %current_v2, %N, %c0) : (index, index, index, index, i32) -> memref<?x?x?x?xi32>

return
}
%t0 = call @rtclock() : () -> f64
call @pooling_nhwc_max(%a, %b, %c) : (memref<?x?x?x?xi32>, memref<?x?xi32>, memref<?x?x?x?xi32>) -> ()
%t1 = call @rtclock() : () -> f64
// All the elements of the MemRef are the same,
// only check the first line to verify the correctness.
// CHECK: Unranked Memref
// CHECK: [
// CHECK: [
// CHECK: [
// CHECK: [1],
%print_c = memref.cast %c : memref<?x?x?x?xi32> to memref<*xi32>
call @printMemrefI32(%print_c) : (memref<*xi32>) -> ()

%time = arith.subf %t1, %t0 : f64
vector.print %time : f64

memref.dealloc %a : memref<?x?x?x?xi32>
memref.dealloc %c : memref<?x?x?x?xi32>

return
}
}
Loading

0 comments on commit 69600d8

Please sign in to comment.