diff --git a/tao_compiler/mlir/disc/transforms/disc_std_bufferize.cc b/tao_compiler/mlir/disc/transforms/disc_std_bufferize.cc index 5cfe9d2e0f6..6c0d626978f 100644 --- a/tao_compiler/mlir/disc/transforms/disc_std_bufferize.cc +++ b/tao_compiler/mlir/disc/transforms/disc_std_bufferize.cc @@ -53,7 +53,7 @@ LogicalResult ConstantOpConverter::matchAndRewrite( auto resultType = op.getType().dyn_cast(); if (!resultType) return failure(); - if (resultType.getRank() != 1) return failure(); + if (resultType.getRank() > 1) return failure(); auto elemType = resultType.getElementType(); if (!elemType.isIndex() && !elemType.isa()) return failure(); @@ -68,7 +68,12 @@ LogicalResult ConstantOpConverter::matchAndRewrite( rewriter.create(loc, en.value().getSExtValue()); if (!elemType.isIndex()) val = rewriter.create(loc, elemType, val); - rewriter.create(loc, val, result, idx); + if (resultType.getRank() == 0) { + rewriter.create(loc, val, result); + } else { + Value idx = rewriter.create(loc, en.index()); + rewriter.create(loc, val, result, idx); + } } rewriter.replaceOp(op, {result});