Skip to content

Commit

Permalink
[mlir][NFC] update code to use mlir::dyn_cast/cast/isa (#90633)
Browse files Browse the repository at this point in the history
Fix compiler warning caused by using deprecated interface
(#90413)
  • Loading branch information
Peiming Liu authored Apr 30, 2024
1 parent 49bb993 commit d235369
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 55 deletions.
4 changes: 2 additions & 2 deletions mlir/include/mlir/Tools/PDLL/AST/Nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ class RangeExpr final : public Node::NodeBase<RangeExpr, Expr>,
}

/// Return the range result type of this expression.
RangeType getType() const { return Base::getType().cast<RangeType>(); }
RangeType getType() const { return mlir::cast<RangeType>(Base::getType()); }

private:
RangeExpr(SMRange loc, RangeType type, unsigned numElements)
Expand Down Expand Up @@ -630,7 +630,7 @@ class TupleExpr final : public Node::NodeBase<TupleExpr, Expr>,
}

/// Return the tuple result type of this expression.
TupleType getType() const { return Base::getType().cast<TupleType>(); }
TupleType getType() const { return mlir::cast<TupleType>(Base::getType()); }

private:
TupleExpr(SMRange loc, TupleType type) : Base(loc, type) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();

auto expandTy = expandOp.getType().dyn_cast<RankedTensorType>();
auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
if (!expandTy)
return failure();
ArrayRef<int64_t> dstShape = expandTy.getShape();
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2256,7 +2256,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<OpFoldResult> outputShape) {
auto [staticOutputShape, dynamicOutputShape] =
decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
build(builder, result, resultType.cast<MemRefType>(), src,
build(builder, result, llvm::cast<MemRefType>(resultType), src,
getReassociationIndicesAttribute(builder, reassociation),
dynamicOutputShape, staticOutputShape);
}
Expand All @@ -2266,7 +2266,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<ReassociationIndices> reassociation) {
SmallVector<OpFoldResult> inputShape =
getMixedSizes(builder, result.location, src);
MemRefType memrefResultTy = resultType.cast<MemRefType>();
MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
builder, result.location, memrefResultTy, reassociation, inputShape);
// Failure of this assertion usually indicates presence of multiple
Expand Down Expand Up @@ -2867,7 +2867,8 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
/// marked as dropped in `droppedDims`.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
const llvm::SmallBitVector &droppedDims) {
assert(size_t(t1.getRank()) == droppedDims.size() && "incorrect number of bits");
assert(size_t(t1.getRank()) == droppedDims.size() &&
"incorrect number of bits");
assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
"incorrect number of dropped dims");
int64_t t1Offset, t2Offset;
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1663,7 +1663,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<OpFoldResult> outputShape) {
auto [staticOutputShape, dynamicOutputShape] =
decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
build(builder, result, resultType.cast<RankedTensorType>(), src,
build(builder, result, cast<RankedTensorType>(resultType), src,
getReassociationIndicesAttribute(builder, reassociation),
dynamicOutputShape, staticOutputShape);
}
Expand All @@ -1673,7 +1673,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<ReassociationIndices> reassociation) {
SmallVector<OpFoldResult> inputShape =
getMixedSizes(builder, result.location, src);
auto tensorResultTy = resultType.cast<RankedTensorType>();
auto tensorResultTy = cast<RankedTensorType>(resultType);
FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
builder, result.location, tensorResultTy, reassociation, inputShape);
// Failure of this assertion usually indicates presence of multiple
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ static bool hasZeroDimension(ShapedType shapedType) {
return false;
}

template <typename T> static LogicalResult verifyConvOp(T op) {
template <typename T>
static LogicalResult verifyConvOp(T op) {
// All TOSA conv ops have an input() and weight().
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
Expand Down Expand Up @@ -962,7 +963,7 @@ mlir::LogicalResult tosa::ReshapeOp::verify() {
return emitOpError() << "tensor has a dimension with size zero. Each "
"dimension of a tensor must have size >= 1";

if ((int64_t) getNewShape().size() != outputType.getRank())
if ((int64_t)getNewShape().size() != outputType.getRank())
return emitOpError() << "new shape does not match result rank";

for (auto [newShapeDim, outputShapeDim] :
Expand Down Expand Up @@ -1127,7 +1128,7 @@ LogicalResult TransposeOp::reifyResultShapes(
return failure();

Value input = getInput1();
auto inputType = input.getType().cast<TensorType>();
auto inputType = cast<TensorType>(input.getType());

SmallVector<OpFoldResult> returnedDims(inputType.getRank());
for (auto dim : transposePerms) {
Expand Down
19 changes: 10 additions & 9 deletions mlir/lib/Tools/PDLL/AST/Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ Type Type::refineWith(Type other) const {
return *this;

// Operation types are compatible if the operation names don't conflict.
if (auto opTy = dyn_cast<OperationType>()) {
auto otherOpTy = other.dyn_cast<ast::OperationType>();
if (auto opTy = mlir::dyn_cast<OperationType>(*this)) {
auto otherOpTy = mlir::dyn_cast<ast::OperationType>(other);
if (!otherOpTy)
return nullptr;
if (!otherOpTy.getName())
Expand Down Expand Up @@ -105,25 +105,26 @@ Type RangeType::getElementType() const {
// TypeRangeType

bool TypeRangeType::classof(Type type) {
RangeType range = type.dyn_cast<RangeType>();
return range && range.getElementType().isa<TypeType>();
RangeType range = mlir::dyn_cast<RangeType>(type);
return range && mlir::isa<TypeType>(range.getElementType());
}

TypeRangeType TypeRangeType::get(Context &context) {
return RangeType::get(context, TypeType::get(context)).cast<TypeRangeType>();
return mlir::cast<TypeRangeType>(
RangeType::get(context, TypeType::get(context)));
}

//===----------------------------------------------------------------------===//
// ValueRangeType

bool ValueRangeType::classof(Type type) {
RangeType range = type.dyn_cast<RangeType>();
return range && range.getElementType().isa<ValueType>();
RangeType range = mlir::dyn_cast<RangeType>(type);
return range && mlir::isa<ValueType>(range.getElementType());
}

ValueRangeType ValueRangeType::get(Context &context) {
return RangeType::get(context, ValueType::get(context))
.cast<ValueRangeType>();
return mlir::cast<ValueRangeType>(
RangeType::get(context, ValueType::get(context)));
}

//===----------------------------------------------------------------------===//
Expand Down
24 changes: 12 additions & 12 deletions mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,13 +337,13 @@ Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl,
// Generate a value based on the type of the variable.
ast::Type type = varDecl->getType();
Type mlirType = genType(type);
if (type.isa<ast::ValueType>())
if (isa<ast::ValueType>(type))
return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint());
if (type.isa<ast::TypeType>())
if (isa<ast::TypeType>(type))
return builder.create<pdl::TypeOp>(loc, mlirType, /*type=*/TypeAttr());
if (type.isa<ast::AttributeType>())
if (isa<ast::AttributeType>(type))
return builder.create<pdl::AttributeOp>(loc, getTypeConstraint());
if (ast::OperationType opType = type.dyn_cast<ast::OperationType>()) {
if (ast::OperationType opType = dyn_cast<ast::OperationType>(type)) {
Value operands = builder.create<pdl::OperandsOp>(
loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()),
/*type=*/Value());
Expand All @@ -354,12 +354,12 @@ Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl,
loc, opType.getName(), operands, std::nullopt, ValueRange(), results);
}

if (ast::RangeType rangeTy = type.dyn_cast<ast::RangeType>()) {
if (ast::RangeType rangeTy = dyn_cast<ast::RangeType>(type)) {
ast::Type eleTy = rangeTy.getElementType();
if (eleTy.isa<ast::ValueType>())
if (isa<ast::ValueType>(eleTy))
return builder.create<pdl::OperandsOp>(loc, mlirType,
getTypeConstraint());
if (eleTy.isa<ast::TypeType>())
if (isa<ast::TypeType>(eleTy))
return builder.create<pdl::TypesOp>(loc, mlirType, /*types=*/ArrayAttr());
}

Expand Down Expand Up @@ -440,7 +440,7 @@ Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
ast::Type parentType = expr->getParentExpr()->getType();

// Handle operation based member access.
if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) {
if (isa<ast::AllResultsMemberAccessExpr>(expr)) {
Type mlirType = genType(expr->getType());
if (isa<pdl::ValueType>(mlirType))
Expand Down Expand Up @@ -480,7 +480,7 @@ Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
}

// Handle tuple based member access.
if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
if (auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
auto elementNames = tupleType.getElementNames();

// The index is either a numeric index, or a name.
Expand Down Expand Up @@ -581,14 +581,14 @@ CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc,
if (!cstBody) {
ast::Type declResultType = decl->getResultType();
SmallVector<Type> resultTypes;
if (ast::TupleType tupleType = declResultType.dyn_cast<ast::TupleType>()) {
if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(declResultType)) {
for (ast::Type type : tupleType.getElementTypes())
resultTypes.push_back(genType(type));
} else {
resultTypes.push_back(genType(declResultType));
}
PDLOpT pdlOp = builder.create<PDLOpT>(
loc, resultTypes, decl->getName().getName(), inputs);
PDLOpT pdlOp = builder.create<PDLOpT>(loc, resultTypes,
decl->getName().getName(), inputs);
if (isNegated && std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>)
cast<pdl::ApplyNativeConstraintOp>(pdlOp).setIsNegated(true);
return pdlOp->getResults();
Expand Down
32 changes: 16 additions & 16 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ LogicalResult Parser::convertExpressionTo(
return diag;
};

if (auto exprOpType = exprType.dyn_cast<ast::OperationType>())
if (auto exprOpType = dyn_cast<ast::OperationType>(exprType))
return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);

// FIXME: Decide how to allow/support converting a single result to multiple,
Expand All @@ -638,7 +638,7 @@ LogicalResult Parser::convertExpressionTo(
return success();

// Handle tuple types.
if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>())
if (auto exprTupleType = dyn_cast<ast::TupleType>(exprType))
return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
noteAttachFn);

Expand All @@ -650,7 +650,7 @@ LogicalResult Parser::convertOpExpressionTo(
function_ref<ast::InFlightDiagnostic()> emitErrorFn) {
// Two operation types are compatible if they have the same name, or if the
// expected type is more general.
if (auto opType = type.dyn_cast<ast::OperationType>()) {
if (auto opType = dyn_cast<ast::OperationType>(type)) {
if (opType.getName())
return emitErrorFn();
return success();
Expand Down Expand Up @@ -702,7 +702,7 @@ LogicalResult Parser::convertTupleExpressionTo(
function_ref<ast::InFlightDiagnostic()> emitErrorFn,
function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
// Handle conversions between tuples.
if (auto tupleType = type.dyn_cast<ast::TupleType>()) {
if (auto tupleType = dyn_cast<ast::TupleType>(type)) {
if (tupleType.size() != exprType.size())
return emitErrorFn();

Expand Down Expand Up @@ -2568,7 +2568,7 @@ Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
}

// Constraint types cannot be used when defining variables.
if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
if (isa<ast::ConstraintType, ast::RewriteType>(type)) {
return emitError(
loc, llvm::formatv("unable to define variable of `{0}` type", type));
}
Expand Down Expand Up @@ -2782,7 +2782,7 @@ Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
StringRef name, SMRange loc) {
ast::Type parentType = parentExpr->getType();
if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType)) {
if (name == ast::AllResultsMemberAccessExpr::getMemberName())
return valueRangeTy;

Expand All @@ -2808,7 +2808,7 @@ FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
// operations. It returns a single value.
return valueTy;
}
} else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
} else if (auto tupleType = dyn_cast<ast::TupleType>(parentType)) {
// Handle indexed results.
unsigned index = 0;
if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
Expand Down Expand Up @@ -2845,7 +2845,7 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
for (ast::NamedAttributeDecl *attr : attributes) {
// Check for an attribute type, or a type awaiting resolution.
ast::Type attrType = attr->getValue()->getType();
if (!attrType.isa<ast::AttributeType>()) {
if (!isa<ast::AttributeType>(attrType)) {
return emitError(
attr->getValue()->getLoc(),
llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
Expand Down Expand Up @@ -3024,7 +3024,7 @@ LogicalResult Parser::validateOperationOperandsOrResults(
// ValueRange. This situations arises quite often with nested operation
// expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
if (singleTy == valueTy) {
if (valueExprType.isa<ast::OperationType>()) {
if (isa<ast::OperationType>(valueExprType)) {
valueExpr = convertOpToValue(valueExpr);
continue;
}
Expand All @@ -3048,7 +3048,7 @@ Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
ArrayRef<StringRef> elementNames) {
for (const ast::Expr *element : elements) {
ast::Type eleTy = element->getType();
if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
if (isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>(eleTy)) {
return emitError(
element->getLoc(),
llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
Expand All @@ -3064,7 +3064,7 @@ FailureOr<ast::EraseStmt *> Parser::createEraseStmt(SMRange loc,
ast::Expr *rootOp) {
// Check that root is an Operation.
ast::Type rootType = rootOp->getType();
if (!rootType.isa<ast::OperationType>())
if (!isa<ast::OperationType>(rootType))
return emitError(rootOp->getLoc(), "expected `Op` expression");

return ast::EraseStmt::create(ctx, loc, rootOp);
Expand All @@ -3075,7 +3075,7 @@ Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
MutableArrayRef<ast::Expr *> replValues) {
// Check that root is an Operation.
ast::Type rootType = rootOp->getType();
if (!rootType.isa<ast::OperationType>()) {
if (!isa<ast::OperationType>(rootType)) {
return emitError(
rootOp->getLoc(),
llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
Expand All @@ -3088,7 +3088,7 @@ Parser::createReplaceStmt(SMRange loc, ast::Expr *rootOp,
ast::Type replType = replExpr->getType();

// Check that replExpr is an Operation, Value, or ValueRange.
if (replType.isa<ast::OperationType>()) {
if (isa<ast::OperationType>(replType)) {
if (shouldConvertOpToValues)
replExpr = convertOpToValue(replExpr);
continue;
Expand All @@ -3110,7 +3110,7 @@ Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,
ast::CompoundStmt *rewriteBody) {
// Check that root is an Operation.
ast::Type rootType = rootOp->getType();
if (!rootType.isa<ast::OperationType>()) {
if (!isa<ast::OperationType>(rootType)) {
return emitError(
rootOp->getLoc(),
llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
Expand All @@ -3125,9 +3125,9 @@ Parser::createRewriteStmt(SMRange loc, ast::Expr *rootOp,

LogicalResult Parser::codeCompleteMemberAccess(ast::Expr *parentExpr) {
ast::Type parentType = parentExpr->getType();
if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>())
if (ast::OperationType opType = dyn_cast<ast::OperationType>(parentType))
codeCompleteContext->codeCompleteOperationMemberAccess(opType);
else if (ast::TupleType tupleType = parentType.dyn_cast<ast::TupleType>())
else if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(parentType))
codeCompleteContext->codeCompleteTupleMemberAccess(tupleType);
return failure();
}
Expand Down
Loading

0 comments on commit d235369

Please sign in to comment.