Skip to content

Commit

Permalink
bufferize arith constant with rank 0 (#868)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzpmiracle committed Dec 14, 2022
1 parent ae8954a commit 9af6728
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tao_compiler/mlir/disc/transforms/disc_std_bufferize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ LogicalResult ConstantOpConverter::matchAndRewrite(
auto resultType = op.getType().dyn_cast<RankedTensorType>();
if (!resultType) return failure();

if (resultType.getRank() != 1) return failure();
if (resultType.getRank() > 1) return failure();

auto elemType = resultType.getElementType();
if (!elemType.isIndex() && !elemType.isa<IntegerType>()) return failure();
Expand All @@ -68,7 +68,12 @@ LogicalResult ConstantOpConverter::matchAndRewrite(
rewriter.create<arith::ConstantIndexOp>(loc, en.value().getSExtValue());
if (!elemType.isIndex())
val = rewriter.create<arith::IndexCastOp>(loc, elemType, val);
rewriter.create<memref::StoreOp>(loc, val, result, idx);
if (resultType.getRank() == 0) {
rewriter.create<memref::StoreOp>(loc, val, result);
} else {
Value idx = rewriter.create<arith::ConstantIndexOp>(loc, en.index());
rewriter.create<memref::StoreOp>(loc, val, result, idx);
}
}

rewriter.replaceOp(op, {result});
Expand Down

0 comments on commit 9af6728

Please sign in to comment.