Introduce attribute delayed in krnl.memset and enable N-D Matmul to 3-D MatMul for dynamic shape (#1618)

* Enable N-D Matmul to 3-D MatMul for dynamic shape
Signed-off-by: Tung D. Le <tungld@gmail.com>
tungld authored Aug 23, 2022
1 parent 87be313 commit 53a11a9
Showing 16 changed files with 153 additions and 91 deletions.
2 changes: 1 addition & 1 deletion docs/SupportedONNXOps-NNPA.md
@@ -26,7 +26,7 @@ NNPA has hardware limitations in dimension index size and tensor size, which are
| **LSTM** |14 |- `direction` and `hidden_size` in `W` must have static dimensions.<br>- `R` must have static dimensions.<br>- `B` and `initial_h` have static dimensions if given. `B`'s direction dim must be 1 or 2.<br>- `P`(peepholes), `activation_alpha`, and `activation_beta` are not supported.<br>- `activations` must be `["Sigmoid", "Tanh", "Tanh"]`.<br>- `clip` is not supported.<br>- `input_forget` must be default value(0).<br>- `layout` is not supported. | |
| **Log** |13 |Input tensor must have 4 dimensions. | |
| **LogSoftmax** |13 | | |
-| **MatMul** |13 |Ranks of input tensors must be (Rank of A, Rank of B) = (M, N), where M >= 2 and N >= 2. If M or N > 3, only supports static shape at this moment. | |
+| **MatMul** |13 |Ranks of input tensors must be (Rank of A, Rank of B) = (M, N), where M >= 2 and N >= 2. | |
| **Max** |13 |- Shape of input tensors must be the same since broadcasting is not supported.<br>- Input tensors must have static dimensions. | |
| **MaxPool** |12 |- `auto_pad` must be `NOTSET`, `VALID`, and `SAME_UPPER`. If `NOTSET` is used, `pads` must be set so that the padding valid type or same upper.<br>- `ceil_mode` must be default value(0) <br>- Input and output tensors must be 4D tensors(N x C x H x W).<br>- `kernel_shape` must be static.<br>- `ceil_mode` must be default value(0).<br>- `dilations` must be default value(1). | |
| **Min** |13 |- Shape of input tensors must be the same since broadcasting is not supported.<br>- Input tensors must have static dimensions. | |
4 changes: 4 additions & 0 deletions src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
@@ -140,6 +140,7 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
optLevel = OptLevel::O2;
else if (optStr == "-O3")
optLevel = OptLevel::O3;
+// Lower ONNX to Krnl, ZHigh to ZLow.
addONNXToKrnlPasses(pm, optLevel, /*enableCSE*/ true,
instrumentONNXSignature, ONNXOpStats);

@@ -150,6 +151,9 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
addKrnlToAffinePasses(pm);
// Normalize MemRefs.
normalizeMemRefsPasses(pm);
+// Some Krnl ops, e.g. KrnlMemset, potentially exist and will be lowered
+// to Affine when their operands are normalized.
+addKrnlToAffinePasses(pm);
// Optimizations at ZLow.
pm.addPass(zlow::createZLowRewritePass());
pm.addPass(mlir::createCanonicalizerPass());
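To see why the second KrnlToAffine run is needed, here is a minimal sketch under assumed names (the `#tile` map is the illustrative example from Krnl.td below, not NNPA's actual layout): the first run skips a delayed `krnl.memset` whose MemRef still carries a non-identity layout, and only after MemRef normalization rewrites the type can the op be lowered.

// Before normalization: non-identity layout, so the first KrnlToAffine run
// leaves the delayed memset in place (its lowering pattern returns failure).
#tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
%cst = arith.constant 0.0 : f32
%buf = memref.alloc() : memref<5xf32, #tile>
krnl.memset %buf, %cst {delayed = true} : memref<5xf32, #tile>

// After normalization: the type becomes memref<2x4xf32> with the identity
// layout, so the second KrnlToAffine run can lower the memset.
%nbuf = memref.alloc() : memref<2x4xf32>
krnl.memset %nbuf, %cst {delayed = true} : memref<2x4xf32>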
@@ -318,25 +318,6 @@ void RewriteONNXForZHighPass::runOnOperation() {
if (!isRankedShapedType(aType) || !isRankedShapedType(bType))
return true;

-// Only support static shape at this moment though the code supports dynamic
-// shape as well.
-//
-// The reason is that lowering (3Dx3D) ONNXMatMul of dynamic shape to NNPA
-// led to wrong results for the bertsquad-12 model. In particular, the final
-// output values were shifted, e.g.
-// clang-format off
-// at (0, 252) mismatch -0.7646484375 (actual) vs -6.084146022796631 (reference)
-// at (0, 253) mismatch -0.7646484375 (actual) vs -6.100776195526123 (reference)
-// at (0, 254) mismatch -0.7646484375 (actual) vs -6.13942813873291 (reference)
-// at (0, 255) mismatch -0.7646484375 (actual) vs -6.0835771560668945 (reference)
-// clang-format on
-//
-// It is unclear why this happened with dynamic shapes.
-// There is no accuracy issue if (3Dx3D) ONNXMatMul runs on CPU or has
-// static shape.
-if (!hasStaticShape(aType) || !hasStaticShape(bType))
-  return true;

int64_t aRank = getRank(aType);
int64_t bRank = getRank(bType);
if (aRank == 2 && bRank > 3)
@@ -227,7 +227,7 @@ Value insertAllocOrEmitZeroConstant(ArrayRef<IndexExpr> dims,
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(rewriter, loc);
res = insertAllocAndDeallocZMemRefByDim(dims, layout, op, rewriter);
Value initValue = create.math.constant(rewriter.getF16Type(), 0);
-create.krnl.memset(res, initValue);
+create.krnl.memset(res, initValue, /*delayed=*/true);
}
return res;
}
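For reference, a hedged sketch of the IR this path now emits for a dynamically shaped zTensor (the `#map` layout, `%dim`, and SSA names are placeholders; compare the `krnl.memset ... {delayed = true}` CHECK lines in the zlow tests further down):

%cst = arith.constant 0.000000e+00 : f16
%res = memref.alloc(%dim) : memref<1x?x9xf16, #map>
krnl.memset %res, %cst {delayed = true} : memref<1x?x9xf16, #map>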
1 change: 0 additions & 1 deletion src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
@@ -679,7 +679,6 @@ void ConvertKrnlToAffinePass::runOnOperation() {
target.addIllegalOp<KrnlMatMulOp>();
target.addIllegalOp<KrnlCopyToBufferOp>();
target.addIllegalOp<KrnlCopyFromBufferOp>();
-target.addIllegalOp<KrnlMemsetOp>();
target.addLegalOp<AffineYieldOp>();
target.addLegalOp<AffineLoadOp>();
target.addLegalOp<AffineStoreOp>();
7 changes: 7 additions & 0 deletions src/Conversion/KrnlToAffine/KrnlMemset.cpp
@@ -36,10 +36,17 @@ class KrnlMemsetLowering : public ConversionPattern {
ConversionPatternRewriter &rewriter) const override {
// Get info from operands.
auto memsetOp = cast<KrnlMemsetOp>(op);
+bool delayed = memsetOp.delayed();
KrnlMemsetOpAdaptor operandAdaptor(memsetOp);
Value destMemRef(operandAdaptor.dest());
Value destVal(operandAdaptor.value());
Location loc = memsetOp.getLoc();

+// If delayed but the input memref has not been normalized yet, do nothing.
+if (delayed &&
+    !destMemRef.getType().cast<MemRefType>().getLayout().isIdentity())
+  return failure();

AffineBuilderKrnlMem createAffine(rewriter, loc);
IndexExprScope indexScope(createAffine);
MemRefBoundsIndexCapture destBounds(destMemRef);
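Once normalization has produced an identity layout, the pattern does fire. As a rough sketch of what the remainder of this pattern builds (reusing the 5-element `#tile` example, so the normalized type is `memref<2x4xf32>`; names illustrative), the memset becomes an affine store nest over the buffer's bounds:

// krnl.memset %buf, %cst {delayed = true} : memref<2x4xf32>
// is rewritten to roughly:
affine.for %i = 0 to 2 {
  affine.for %j = 0 to 4 {
    affine.store %cst, %buf[%i, %j] : memref<2x4xf32>
  }
}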
4 changes: 2 additions & 2 deletions src/Dialect/Krnl/DialectBuilder.cpp
@@ -218,8 +218,8 @@ void KrnlBuilder::memcpy(Value dest, Value src, Value size) const {
b.create<KrnlMemcpyOp>(loc, dest, src, size);
}

-void KrnlBuilder::memset(Value dest, Value val) const {
-  b.create<KrnlMemsetOp>(loc, dest, val);
+void KrnlBuilder::memset(Value dest, Value val, bool delayed) const {
+  b.create<KrnlMemsetOp>(loc, dest, val, b.getBoolAttr(delayed));
}

Value KrnlBuilder::strncmp(Value str1, Value str2, Value len) const {
2 changes: 1 addition & 1 deletion src/Dialect/Krnl/DialectBuilder.hpp
@@ -135,7 +135,7 @@ struct KrnlBuilder : public DialectBuilder {

// C library functions.
void memcpy(mlir::Value dest, mlir::Value src, mlir::Value size) const;
-void memset(mlir::Value dest, mlir::Value val) const;
+void memset(mlir::Value dest, mlir::Value val, bool delayed = false) const;
mlir::Value strncmp(
mlir::Value str1, mlir::Value str2, mlir::Value len) const;
mlir::Value strlen(mlir::Value str) const;
24 changes: 22 additions & 2 deletions src/Dialect/Krnl/Krnl.td
@@ -1075,10 +1075,30 @@ def KrnlMemsetOp : Op<Krnl_Dialect, "memset", [MemRefsNormalizable,
"$_self.cast<MemRefType>().getElementType()">]> {
let summary = "Set buffer to a given value.";
let description = [{
-Krnl operation that set buffer to a given value.
+Krnl operation that sets a buffer to a given value.
+In case the buffer is a MemRef with an affine_map, `delayed` indicates
+whether values are set along the original or the extended iteration space.

+For example, given
+- an affine_map `#tile = affine_map < (i)->(i floordiv 4, i mod 4) >`, and
+- a buffer of type `memref<5xf32, #tile>`

+the original iteration space is along the first axis, which has 5 elements.

+If we do normalization, the memref becomes `memref<2x4xf32>`. Now we have
+an extended iteration space along two axes of sizes 2 and 4, respectively.
+This extended iteration space has 8 elements in total.

+If `delayed = false`, the original iteration space is used to set values.
+In the above example, only 5 out of 8 elements will be set to the given value.

+If `delayed = true`, the extended iteration space is used to set values.
+In the above example, all 8 elements will be set to the given value.
}];

-let arguments = (ins AnyMemRef:$dest, AnyType: $value);
+let arguments = (ins AnyMemRef:$dest, AnyType: $value,
+    DefaultValuedAttr<BoolAttr, "false">:$delayed);

let assemblyFormat = [{ $dest `,` $value attr-dict `:` type($dest) }];
}
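A small worked example of the two modes, written as IR (constant and buffer names are illustrative; with `delayed = false` the attribute would normally be omitted, since it is the default):

#tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
%cst = arith.constant 0.0 : f32
%buf = memref.alloc() : memref<5xf32, #tile>

// delayed = false (default): iterate the original space; 5 elements are set.
krnl.memset %buf, %cst : memref<5xf32, #tile>

// delayed = true: lowered only after normalization to memref<2x4xf32>, then
// iterate the extended space; all 8 elements are set.
krnl.memset %buf, %cst {delayed = true} : memref<5xf32, #tile>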
2 changes: 1 addition & 1 deletion test/accelerators/NNPA/backend/CMakeLists.txt
@@ -186,7 +186,7 @@ set(NNPA_TEST_LIST
# test_lstm_with_peepholes_cpu

# ==OP== MatMul
-# ==LIM== Ranks of input tensors must be (Rank of A, Rank of B) = (M, N), where M >= 2 and N >= 2. If M or N > 3, only supports static shape at this moment.
+# ==LIM== Ranks of input tensors must be (Rank of A, Rank of B) = (M, N), where M >= 2 and N >= 2.
test_matmul_2d_cpu,zdnn_matmul_op
test_matmul_3d_cpu,zdnn_matmul_op
test_matmul_4d_cpu,zdnn_matmul_op
69 changes: 34 additions & 35 deletions test/mlir/accelerators/nnpa/conversion/rewrite-onnx-for-zhigh.mlir
@@ -404,38 +404,37 @@ func.func @test_matmul_broadcast_2(%arg0: tensor<256x256xf32>, %arg1: tensor<4x1

// -----

-// COM: enable this when the bug about dynamic shape is fixed.
-// COM: func.func @test_matmul_broadcast_dyn_dims(%arg0: tensor<256x?xf32>, %arg1: tensor<4x12x?x?xf32>) -> (tensor<4x12x256x?xf32>) {
-// COM: %0= "onnx.MatMul"(%arg0, %arg1) : (tensor<256x?xf32>, tensor<4x12x?x?xf32>) -> tensor<4x12x256x?xf32>
-// COM: return %0 : tensor<4x12x256x?xf32>
-// COM:
-// COM: // MATMUL-LABEL: func.func @test_matmul_broadcast_dyn_dims
-// COM: // MATMUL-SAME: ([[PARAM_0_:%.+]]: tensor<256x?xf32>, [[PARAM_1_:%.+]]: tensor<4x12x?x?xf32>) -> tensor<?x?x?x?xf32> {
-// COM: // MATMUL-DAG: [[VAR_0_:%.+]] = "onnx.Shape"([[PARAM_1_]]) : (tensor<4x12x?x?xf32>) -> tensor<4xi64>
-// COM: // MATMUL-DAG: [[VAR_1_:%.+]] = "onnx.Constant"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64>
-// COM: // MATMUL-DAG: [[VAR_2_:%.+]] = "onnx.Constant"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
-// COM: // MATMUL-DAG: [[VAR_3_:%.+]] = "onnx.Constant"() {value = dense<-1> : tensor<1xi64>} : () -> tensor<1xi64>
-// COM: // MATMUL-DAG: [[VAR_4_:%.+]] = "onnx.Constant"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64>
-// COM: // MATMUL-DAG: [[VAR_5_:%.+]] = "onnx.Constant"() {value = dense<4> : tensor<1xi64>} : () -> tensor<1xi64>
-// COM: // MATMUL: [[VAR_6_:%.+]] = "onnx.Slice"([[VAR_0_]], [[VAR_4_]], [[VAR_5_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
-// COM: // MATMUL: [[VAR_7_:%.+]] = "onnx.Concat"([[VAR_3_]], [[VAR_6_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<2xi64>) -> tensor<3xi64>
-// COM: // MATMUL: [[VAR_8_:%.+]] = "onnx.Reshape"([[PARAM_1_]], [[VAR_7_]]) {allowzero = 0 : si64} : (tensor<4x12x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
-// COM: // MATMUL-DAG: [[VAR_9_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_8_]]) : (tensor<256x?xf32>, tensor<?x?x?xf32>) -> tensor<?x256x?xf32>
-// COM: // MATMUL-DAG: [[VAR_10_:%.+]] = "onnx.Shape"([[PARAM_0_]]) : (tensor<256x?xf32>) -> tensor<2xi64>
-// COM: // MATMUL-DAG: [[VAR_11_:%.+]] = "onnx.Shape"([[PARAM_1_]]) : (tensor<4x12x?x?xf32>) -> tensor<4xi64>
-// COM: // MATMUL-DAG: [[VAR_12_:%.+]] = "onnx.Constant"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64>
-// COM: // MATMUL-DAG: [[VAR_13_:%.+]] = "onnx.Constant"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
-// COM: // MATMUL-DAG: [[VAR_14_:%.+]] = "onnx.Constant"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
-// COM: // MATMUL-DAG: [[VAR_15_:%.+]] = "onnx.Constant"() {value = dense<4> : tensor<1xi64>} : () -> tensor<1xi64>
-// COM: // MATMUL-DAG: [[VAR_16_:%.+]] = "onnx.Constant"() {value = dense<3> : tensor<1xi64>} : () -> tensor<1xi64>
-// COM: // MATMUL-DAG: [[VAR_17_:%.+]] = "onnx.Constant"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64>
-// COM: // MATMUL-DAG: [[VAR_18_:%.+]] = "onnx.Constant"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64>
-// COM: // MATMUL-NOT: separator of consecutive DAGs
-// COM: // MATMUL-DAG: [[VAR_19_:%.+]] = "onnx.Slice"([[VAR_11_]], [[VAR_12_]], [[VAR_18_]], [[VAR_12_]], [[VAR_13_]]) : (tensor<4xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
-// COM: // MATMUL-DAG: [[VAR_20_:%.+]] = "onnx.Slice"([[VAR_10_]], [[VAR_17_]], [[VAR_14_]], [[VAR_12_]], [[VAR_13_]]) : (tensor<2xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64>
-// COM: // MATMUL-DAG: [[VAR_21_:%.+]] = "onnx.Slice"([[VAR_11_]], [[VAR_16_]], [[VAR_15_]], [[VAR_12_]], [[VAR_13_]]) : (tensor<4xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64>
-// COM: // MATMUL: [[VAR_22_:%.+]] = "onnx.Concat"([[VAR_19_]], [[VAR_20_]], [[VAR_21_]]) {axis = 0 : si64} : (tensor<2xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
-// COM: // MATMUL: [[VAR_23_:%.+]] = "onnx.Reshape"([[VAR_9_]], [[VAR_22_]]) {allowzero = 0 : si64} : (tensor<?x256x?xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
-// COM: // MATMUL: return [[VAR_23_]] : tensor<?x?x?x?xf32>
-// COM: // MATMUL: }
-// COM: }
+func.func @test_matmul_broadcast_dyn_dims(%arg0: tensor<256x?xf32>, %arg1: tensor<4x12x?x?xf32>) -> (tensor<4x12x256x?xf32>) {
+%0= "onnx.MatMul"(%arg0, %arg1) : (tensor<256x?xf32>, tensor<4x12x?x?xf32>) -> tensor<4x12x256x?xf32>
+return %0 : tensor<4x12x256x?xf32>

+// MATMUL-LABEL: func.func @test_matmul_broadcast_dyn_dims
+// MATMUL-SAME: ([[PARAM_0_:%.+]]: tensor<256x?xf32>, [[PARAM_1_:%.+]]: tensor<4x12x?x?xf32>) -> tensor<?x?x?x?xf32> {
+// MATMUL-DAG: [[VAR_0_:%.+]] = "onnx.Shape"([[PARAM_1_]]) : (tensor<4x12x?x?xf32>) -> tensor<4xi64>
+// MATMUL-DAG: [[VAR_1_:%.+]] = "onnx.Constant"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64>
+// MATMUL-DAG: [[VAR_2_:%.+]] = "onnx.Constant"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
+// MATMUL-DAG: [[VAR_3_:%.+]] = "onnx.Constant"() {value = dense<-1> : tensor<1xi64>} : () -> tensor<1xi64>
+// MATMUL-DAG: [[VAR_4_:%.+]] = "onnx.Constant"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64>
+// MATMUL-DAG: [[VAR_5_:%.+]] = "onnx.Constant"() {value = dense<4> : tensor<1xi64>} : () -> tensor<1xi64>
+// MATMUL: [[VAR_6_:%.+]] = "onnx.Slice"([[VAR_0_]], [[VAR_4_]], [[VAR_5_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
+// MATMUL: [[VAR_7_:%.+]] = "onnx.Concat"([[VAR_3_]], [[VAR_6_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<2xi64>) -> tensor<3xi64>
+// MATMUL: [[VAR_8_:%.+]] = "onnx.Reshape"([[PARAM_1_]], [[VAR_7_]]) {allowzero = 0 : si64} : (tensor<4x12x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
+// MATMUL-DAG: [[VAR_9_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_8_]]) : (tensor<256x?xf32>, tensor<?x?x?xf32>) -> tensor<?x256x?xf32>
+// MATMUL-DAG: [[VAR_10_:%.+]] = "onnx.Shape"([[PARAM_0_]]) : (tensor<256x?xf32>) -> tensor<2xi64>
+// MATMUL-DAG: [[VAR_11_:%.+]] = "onnx.Shape"([[PARAM_1_]]) : (tensor<4x12x?x?xf32>) -> tensor<4xi64>
+// MATMUL-DAG: [[VAR_12_:%.+]] = "onnx.Constant"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64>
+// MATMUL-DAG: [[VAR_13_:%.+]] = "onnx.Constant"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
+// MATMUL-DAG: [[VAR_14_:%.+]] = "onnx.Constant"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
+// MATMUL-DAG: [[VAR_15_:%.+]] = "onnx.Constant"() {value = dense<4> : tensor<1xi64>} : () -> tensor<1xi64>
+// MATMUL-DAG: [[VAR_16_:%.+]] = "onnx.Constant"() {value = dense<3> : tensor<1xi64>} : () -> tensor<1xi64>
+// MATMUL-DAG: [[VAR_17_:%.+]] = "onnx.Constant"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64>
+// MATMUL-DAG: [[VAR_18_:%.+]] = "onnx.Constant"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64>
+// MATMUL-NOT: separator of consecutive DAGs
+// MATMUL-DAG: [[VAR_19_:%.+]] = "onnx.Slice"([[VAR_11_]], [[VAR_12_]], [[VAR_18_]], [[VAR_12_]], [[VAR_13_]]) : (tensor<4xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
+// MATMUL-DAG: [[VAR_20_:%.+]] = "onnx.Slice"([[VAR_10_]], [[VAR_17_]], [[VAR_14_]], [[VAR_12_]], [[VAR_13_]]) : (tensor<2xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64>
+// MATMUL-DAG: [[VAR_21_:%.+]] = "onnx.Slice"([[VAR_11_]], [[VAR_16_]], [[VAR_15_]], [[VAR_12_]], [[VAR_13_]]) : (tensor<4xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64>
+// MATMUL: [[VAR_22_:%.+]] = "onnx.Concat"([[VAR_19_]], [[VAR_20_]], [[VAR_21_]]) {axis = 0 : si64} : (tensor<2xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64>
+// MATMUL: [[VAR_23_:%.+]] = "onnx.Reshape"([[VAR_9_]], [[VAR_22_]]) {allowzero = 0 : si64} : (tensor<?x256x?xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
+// MATMUL: return [[VAR_23_]] : tensor<?x?x?x?xf32>
+// MATMUL: }
+}
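In plain terms, the rewrite exercised above lowers an N-D MatMul with broadcasting to a 3-D MatMul: it flattens the batch dimensions of the 4-D operand into one dynamic dimension, multiplies, and reshapes the result back. A simplified sketch, where `%minus1`, `%lastTwoDimsOfB`, and `%outShape` are placeholders for the onnx.Shape/Slice/Concat computations visible in the CHECK lines:

%b3dShape = "onnx.Concat"(%minus1, %lastTwoDimsOfB) {axis = 0 : si64} : (tensor<1xi64>, tensor<2xi64>) -> tensor<3xi64>
%b3d = "onnx.Reshape"(%B, %b3dShape) {allowzero = 0 : si64} : (tensor<4x12x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
%mm = "onnx.MatMul"(%A, %b3d) : (tensor<256x?xf32>, tensor<?x?x?xf32>) -> tensor<?x256x?xf32>
%res = "onnx.Reshape"(%mm, %outShape) {allowzero = 0 : si64} : (tensor<?x256x?xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>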
@@ -165,7 +165,7 @@ func.func @gru_no_intial_h(%input : tensor<?x?x7xf32, #zhigh.encoding<{dataLayou
// CHECK: krnl.store [[VAR_c7_i64_]], [[RES_1_]]{{.}}[[VAR_c3_]]{{.}} : memref<5xi64>
// CHECK: krnl.store [[VAR_c9_i64_]], [[RES_1_]]{{.}}[[VAR_c4_]]{{.}} : memref<5xi64>
// CHECK: [[RES_2_:%.+]] = memref.alloc([[VAR_1_]]) {{.*}}: memref<1x?x9xf16, #map0>
-// CHECK: krnl.memset [[RES_2_]], [[VAR_cst_]] : memref<1x?x9xf16, #map0>
+// CHECK: krnl.memset [[RES_2_]], [[VAR_cst_]] {delayed = true} : memref<1x?x9xf16, #map0>
// CHECK-DAG: [[VAR_7_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_]] : memref<?x?x7xf16, #map0>
// CHECK-DAG: [[VAR_8_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_]] : memref<?x?x7xf16, #map0>
// CHECK-NOT: separator of consecutive DAGs
@@ -299,9 +299,9 @@ func.func @lstm_no_intial_h_and_c(%input : tensor<?x?x7xf32, #zhigh.encoding<{da
// CHECK: krnl.store [[VAR_c7_i64_]], [[RES_2_]]{{.}}[[VAR_c3_]]{{.}} : memref<5xi64>
// CHECK: krnl.store [[VAR_c9_i64_]], [[RES_2_]]{{.}}[[VAR_c4_]]{{.}} : memref<5xi64>
// CHECK: [[RES_3_:%.+]] = memref.alloc([[VAR_1_]]) {{.*}}: memref<1x?x9xf16, #map0>
-// CHECK: krnl.memset [[RES_3_]], [[VAR_cst_]] : memref<1x?x9xf16, #map0>
+// CHECK: krnl.memset [[RES_3_]], [[VAR_cst_]] {delayed = true} : memref<1x?x9xf16, #map0>
// CHECK: [[RES_4_:%.+]] = memref.alloc([[VAR_1_]]) {{.*}}: memref<1x?x9xf16, #map0>
-// CHECK: krnl.memset [[RES_4_]], [[VAR_cst_]] : memref<1x?x9xf16, #map0>
+// CHECK: krnl.memset [[RES_4_]], [[VAR_cst_]] {delayed = true} : memref<1x?x9xf16, #map0>
// CHECK-DAG: [[VAR_9_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_]] : memref<?x?x7xf16, #map0>
// CHECK-DAG: [[VAR_10_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_]] : memref<?x?x7xf16, #map0>
// CHECK-NOT: separator of consecutive DAGs