Skip to content

Commit

Permalink
Fix for lstm and gru.
Browse files (browse the repository at this point in the history)
Signed-off-by: Haruki Imai <imaihal@jp.ibm.com>
  • Loading branch information
imaihal committed Aug 27, 2024
1 parent 2f23ff8 commit 7a5fb6d
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 22 deletions.
3 changes: 1 addition & 2 deletions src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -696,8 +696,7 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
affine::normalizeMemRefType(mlir::cast<MemRefType>(zMemRefType.value));
ArrayRef<int64_t> normalizedShape = normalizedType.getShape();

- auto valueAttr = mlir::dyn_cast_or_null<mlir::DenseElementsAttr>(
- zhighStickifiedConstOp.getValueAttr());
+ auto valueAttr = zhighStickifiedConstOp.getValueAttr();

// Create a ZLowStickifiedConstantOp.
// Set nullptr in the valueAttr when it is initialized with zero later.
Expand Down
4 changes: 2 additions & 2 deletions test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/gru.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ func.func @gru_no_input_and_hidden_biases(%input : tensor<?x?x7xf16, #zhigh.layo
// CHECK: krnl.store [[VAR_1_]], [[RES_1_]]{{.}}[[VAR_c2_]]{{.}} : memref<5xi64>
// CHECK: krnl.store [[VAR_c7_i64_]], [[RES_1_]]{{.}}[[VAR_c3_]]{{.}} : memref<5xi64>
// CHECK: krnl.store [[VAR_c9_i64_]], [[RES_1_]]{{.}}[[VAR_c4_]]{{.}} : memref<5xi64>
- // CHECK-DAG: [[VAR_2_:%.+]] = "krnl.global"() {alignment = 4096 : i64, name = "constant_stickify_0", shape = [1, 3, 1, 1, 32, 64], value = dense_resource<zhigh> : tensor<12288xi8>} : () -> memref<1x3x1x1x32x64xf16>
- // CHECK-DAG: [[VAR_3_:%.+]] = "krnl.global"() {alignment = 4096 : i64, name = "constant_stickify_1", shape = [1, 3, 1, 1, 32, 64], value = dense_resource<zhigh_1> : tensor<12288xi8>} : () -> memref<1x3x1x1x32x64xf16>
+ // CHECK-DAG: [[VAR_2_:%.+]] = "zlow.stickifiedConstant"() {alignment = 4096 : i64, name = "constant_stickify_0", offset = 0 : i64, shape = [1, 3, 1, 1, 32, 64]} : () -> memref<1x3x1x1x32x64xf16>
+ // CHECK-DAG: [[VAR_3_:%.+]] = "zlow.stickifiedConstant"() {alignment = 4096 : i64, name = "constant_stickify_1", offset = 0 : i64, shape = [1, 3, 1, 1, 32, 64]} : () -> memref<1x3x1x1x32x64xf16>
// CHECK-DAG: [[VAR_dim_2_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_]] : memref<?x?x7xf16, [[MAP_0_]]>
// CHECK-DAG: [[VAR_dim_3_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_]] : memref<?x?x7xf16, [[MAP_0_]]>
// CHECK-NOT: separator of consecutive DAGs
Expand Down
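4 changes: 2 additions & 2 deletions test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/lstm.mlir
Original file line number Diff line number Diff line change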
Expand Up @@ -357,8 +357,8 @@ func.func @lstm_no_input_and_hidden_biases(%input : tensor<?x?x7xf16, #zhigh.lay
// CHECK: krnl.store [[VAR_1_]], [[RES_2_]]{{.}}[[VAR_c2_]]{{.}} : memref<5xi64>
// CHECK: krnl.store [[VAR_c7_i64_]], [[RES_2_]]{{.}}[[VAR_c3_]]{{.}} : memref<5xi64>
// CHECK: krnl.store [[VAR_c9_i64_]], [[RES_2_]]{{.}}[[VAR_c4_]]{{.}} : memref<5xi64>
- // CHECK-DAG: [[VAR_2_:%.+]] = "krnl.global"() {alignment = 4096 : i64, name = "constant_stickify_0", shape = [1, 4, 1, 1, 32, 64], value = dense_resource<zhigh> : tensor<16384xi8>} : () -> memref<1x4x1x1x32x64xf16>
- // CHECK-DAG: [[VAR_3_:%.+]] = "krnl.global"() {alignment = 4096 : i64, name = "constant_stickify_1", shape = [1, 4, 1, 1, 32, 64], value = dense_resource<zhigh_1> : tensor<16384xi8>} : () -> memref<1x4x1x1x32x64xf16>
+ // CHECK-DAG: [[VAR_2_:%.+]] = "zlow.stickifiedConstant"() {alignment = 4096 : i64, name = "constant_stickify_0", offset = 0 : i64, shape = [1, 4, 1, 1, 32, 64]} : () -> memref<1x4x1x1x32x64xf16>
+ // CHECK-DAG: [[VAR_3_:%.+]] = "zlow.stickifiedConstant"() {alignment = 4096 : i64, name = "constant_stickify_1", offset = 0 : i64, shape = [1, 4, 1, 1, 32, 64]} : () -> memref<1x4x1x1x32x64xf16>
// CHECK-DAG: [[VAR_dim_3_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_]] : memref<?x?x7xf16, [[MAP_0_]]>
// CHECK-DAG: [[VAR_dim_4_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_]] : memref<?x?x7xf16, [[MAP_0_]]>
// CHECK-NOT: separator of consecutive DAGs
Expand Down
28 changes: 12 additions & 16 deletions test/mlir/accelerators/nnpa/driver/ccfd.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
// COM: It is the necessary condition to get the best performance.

CHECK-LABEL: func.func @main_graph
- CHECK-DAG: krnl.global
- CHECK-DAG: krnl.global
+ CHECK-DAG: zlow.stickifiedConstant
+ CHECK-DAG: zlow.stickifiedConstant
CHECK-DAG: memref.alloc
CHECK-NEXT: zlow.stick
+ CHECK-DAG: zlow.stickifiedConstant

- CHECK-DAG: krnl.global
- CHECK-DAG: krnl.global
- CHECK-DAG: krnl.global
- CHECK-DAG: krnl.global
+ CHECK-DAG: zlow.stickifiedConstant
+ CHECK-DAG: zlow.stickifiedConstant
CHECK-DAG: memref.alloc
CHECK-DAG: memref.alloc
CHECK-DAG: krnl.global
Expand All @@ -24,12 +22,10 @@ CHECK-NEXT: zlow.lstm
CHECK-NOT: zlow.stick
CHECK-NOT: zlow.unstick

- CHECK-DAG: krnl.global
- CHECK-DAG: krnl.global
- CHECK-DAG: krnl.global
- CHECK-DAG: krnl.global
- CHECK-DAG: krnl.global
- CHECK-DAG: krnl.global
+ CHECK-DAG: zlow.stickifiedConstant
+ CHECK-DAG: zlow.stickifiedConstant
+ CHECK-DAG: zlow.stickifiedConstant
+ CHECK-DAG: zlow.stickifiedConstant
CHECK-DAG: memref.alloc
CHECK-DAG: memref.alloc
CHECK-DAG: krnl.global
Expand All @@ -40,17 +36,17 @@ CHECK-NEXT: zlow.lstm
CHECK-NOT: zlow.stick
CHECK-NOT: zlow.unstick

- CHECK-DAG: krnl.global
+ CHECK-DAG: zlow.stickifiedConstant
CHECK-DAG: memref.alloc
CHECK-DAG: krnl.global
- CHECK-DAG: krnl.global
+ CHECK-DAG: zlow.stickifiedConstant
CHECK-NEXT: zlow.matmul

// No stick and unstick in between.
CHECK-NOT: zlow.stick
CHECK-NOT: zlow.unstick

- CHECK-DAG: krnl.global
+ CHECK-DAG: zlow.stickifiedConstant
CHECK-DAG: memref.alloc
CHECK-DAG: krnl.global
CHECK-NEXT: zlow.add
Expand Down

0 comments on commit 7a5fb6d

Please sign in to comment.