From 9af6728974e2f326ee746972af086f5401c281d4 Mon Sep 17 00:00:00 2001 From: zzp_miracle Date: Wed, 14 Dec 2022 11:40:39 +0800 Subject: [PATCH] bufferize arith constant with rank 0 (#868) --- tao_compiler/mlir/disc/transforms/disc_std_bufferize.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tao_compiler/mlir/disc/transforms/disc_std_bufferize.cc b/tao_compiler/mlir/disc/transforms/disc_std_bufferize.cc index 5cfe9d2e0f6..6c0d626978f 100644 --- a/tao_compiler/mlir/disc/transforms/disc_std_bufferize.cc +++ b/tao_compiler/mlir/disc/transforms/disc_std_bufferize.cc @@ -53,7 +53,7 @@ LogicalResult ConstantOpConverter::matchAndRewrite( auto resultType = op.getType().dyn_cast(); if (!resultType) return failure(); - if (resultType.getRank() != 1) return failure(); + if (resultType.getRank() > 1) return failure(); auto elemType = resultType.getElementType(); if (!elemType.isIndex() && !elemType.isa()) return failure(); @@ -68,7 +68,12 @@ LogicalResult ConstantOpConverter::matchAndRewrite( rewriter.create(loc, en.value().getSExtValue()); if (!elemType.isIndex()) val = rewriter.create(loc, elemType, val); - rewriter.create(loc, val, result, idx); + if (resultType.getRank() == 0) { + rewriter.create(loc, val, result); + } else { + Value idx = rewriter.create(loc, en.index()); + rewriter.create(loc, val, result, idx); + } } rewriter.replaceOp(op, {result});