Skip to content

Commit

Permalink
add check (llvm#1098)
Browse files Browse the repository at this point in the history
Signed-off-by: Tong Chen <chentong@us.ibm.com>

Co-authored-by: Alexandre Eichenberger <alexe@us.ibm.com>
  • Loading branch information
chentong319 and AlexandreEichenberger authored Jan 14, 2022
1 parent f6f127c commit 6a19ed8
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions src/Dialect/ONNX/ONNXOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -884,18 +884,33 @@ LogicalResult ONNXSequenceInsertOp::inferShapes(
onnxmlir::SeqType seqType =
input_sequence().getType().dyn_cast<mlir::onnxmlir::SeqType>();
ShapedType tensorType = tensor().getType().dyn_cast<ShapedType>();
ShapedType seqTensorType = seqType.getElementType().cast<ShapedType>();

// Merge the tensor type for the seq and the inserted tensor
// Pick the weaker attr: known dim > unknown dim > unranked
// If inference gets an unranked tensor, no need to update the result

// When the input seq is empty, inherit the tensor type
if (seqType.getLength() == 0) {
getResult().setType(onnxmlir::SeqType::get(tensorType, 1));
return success();
}

// Merge the tensor type for the seq and the inserted tensor
// Pick the weaker attr: known dim > unknown dim > unranked tensor
// If inference gets an unranked tensor, no need to update the result
auto seqShape = seqType.getElementType().cast<ShapedType>().getShape();
auto seqRank = seqType.getElementType().cast<ShapedType>().getRank();
auto newLength = seqType.getLength() == -1 ? -1 : seqType.getLength() + 1;

// When one of the tensor is unranked
if (!tensorType.hasRank()) {
getResult().setType(onnxmlir::SeqType::get(tensorType, newLength));
return success();
}
if (!seqTensorType.hasRank()) {
getResult().setType(onnxmlir::SeqType::get(seqTensorType, newLength));
return success();
}

// Merge when both are ranked
auto seqShape = seqTensorType.getShape();
auto seqRank = seqTensorType.getRank();
if (seqRank == -1)
return success();

Expand All @@ -909,7 +924,7 @@ LogicalResult ONNXSequenceInsertOp::inferShapes(
}
getResult().setType(onnxmlir::SeqType::get(
mlir::RankedTensorType::get(dims, tensorType.getElementType()),
seqType.getLength() == -1 ? -1 : seqType.getLength() + 1));
newLength));

return success();
}
Expand Down

0 comments on commit 6a19ed8

Please sign in to comment.