diff --git a/mlir/examples/transform/Ch4/lib/MyExtension.cpp b/mlir/examples/transform/Ch4/lib/MyExtension.cpp
index 26e348f2a30ec..83e2dcd750bb3 100644
--- a/mlir/examples/transform/Ch4/lib/MyExtension.cpp
+++ b/mlir/examples/transform/Ch4/lib/MyExtension.cpp
@@ -142,7 +142,7 @@ mlir::transform::HasOperandSatisfyingOp::apply(
     transform::detail::prepareValueMappings(
         yieldedMappings, getBody().front().getTerminator()->getOperands(), state);
-    results.setParams(getPosition().cast(),
+    results.setParams(cast(getPosition()),
                       {rewriter.getI32IntegerAttr(operand.getOperandNumber())});
     for (auto &&[result, mapping] : llvm::zip(getResults(), yieldedMappings))
       results.setMappedValues(result, mapping);
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
index ab4df2ab028d4..5e4b4f3a66af9 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -87,7 +87,7 @@ struct IndependentParallelIteratorDomainShardingInterface
   void populateIteratorTypes(Type t,
                              SmallVector &iterTypes) const {
-    RankedTensorType rankedTensorType = t.dyn_cast();
+    RankedTensorType rankedTensorType = dyn_cast(t);
     if (!rankedTensorType) {
       return;
     }
@@ -106,7 +106,7 @@ struct ElementwiseShardingInterface
     ElementwiseShardingInterface, ElemwiseOp> {
   SmallVector getLoopIteratorTypes(Operation *op) const {
     Value val = op->getOperand(0);
-    auto type = val.getType().dyn_cast();
+    auto type = dyn_cast(val.getType());
     if (!type)
       return {};
     SmallVector types(type.getRank(),
@@ -117,7 +117,7 @@ struct ElementwiseShardingInterface
   SmallVector getIndexingMaps(Operation *op) const {
     MLIRContext *ctx = op->getContext();
     Value val = op->getOperand(0);
-    auto type = val.getType().dyn_cast();
+    auto type = dyn_cast(val.getType());
     if (!type)
       return {};
     int64_t rank = type.getRank();
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index a9bc3351f4cff..ec3c2cb011c35 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -60,11 +60,11 @@ class MulOperandsAndResultElementType
     if (llvm::isa(resElemType))
       return impl::verifySameOperandsAndResultElementType(op);
-    if (auto resIntType = resElemType.dyn_cast()) {
+    if (auto resIntType = dyn_cast(resElemType)) {
       IntegerType lhsIntType =
-          getElementTypeOrSelf(op->getOperand(0)).cast();
+          cast(getElementTypeOrSelf(op->getOperand(0)));
       IntegerType rhsIntType =
-          getElementTypeOrSelf(op->getOperand(1)).cast();
+          cast(getElementTypeOrSelf(op->getOperand(1)));
       if (lhsIntType != rhsIntType)
         return op->emitOpError(
             "requires the same element type for all operands");
diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
index d4268e804f4f7..aa8314f38cdfa 100644
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -154,7 +154,7 @@ class FusedLocWith : public FusedLoc {
   /// Support llvm style casting.
   static bool classof(Attribute attr) {
     auto fusedLoc = llvm::dyn_cast(attr);
-    return fusedLoc && fusedLoc.getMetadata().isa_and_nonnull();
+    return fusedLoc && mlir::isa_and_nonnull(fusedLoc.getMetadata());
   }
 };
diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp
index 4669c40f843d9..21c66f38a8af0 100644
--- a/mlir/lib/CAPI/Dialect/LLVM.cpp
+++ b/mlir/lib/CAPI/Dialect/LLVM.cpp
@@ -135,7 +135,7 @@ MlirAttribute mlirLLVMDIExpressionAttrGet(MlirContext ctx, intptr_t nOperations,
       unwrap(ctx), llvm::map_to_vector(
           unwrapList(nOperations, operations, attrStorage),
-          [](Attribute a) { return a.cast(); })));
+          [](Attribute a) { return cast(a); })));
 }
 
 MlirAttribute mlirLLVMDINullTypeAttrGet(MlirContext ctx) {
@@ -165,7 +165,7 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet(
       cast(unwrap(scope)), cast(unwrap(baseType)),
       DIFlags(flags), sizeInBits, alignInBits,
       llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage),
-                          [](Attribute a) { return a.cast(); })));
+                          [](Attribute a) { return cast(a); })));
 }
 
 MlirAttribute
@@ -259,7 +259,7 @@ MlirAttribute mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx,
   return wrap(DISubroutineTypeAttr::get(
       unwrap(ctx), callingConvention,
       llvm::map_to_vector(unwrapList(nTypes, types, attrStorage),
-                          [](Attribute a) { return a.cast(); })));
+                          [](Attribute a) { return cast(a); })));
 }
 
 MlirAttribute mlirLLVMDISubprogramAttrGet(
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index e1a5d82587cf9..c94c070144a7e 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -311,11 +311,11 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
 }
 
 bool mlirVectorTypeIsScalable(MlirType type) {
-  return unwrap(type).cast().isScalable();
+  return cast(unwrap(type)).isScalable();
 }
 
 bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
-  return unwrap(type).cast().getScalableDims()[dim];
+  return cast(unwrap(type)).getScalableDims()[dim];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 7e073bae75c0c..033e66c6118f3 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -371,7 +371,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                  bool isUnsigned, Value llvmInput,
                                  SmallVector &operands) {
   Type inputType = llvmInput.getType();
-  auto vectorType = inputType.dyn_cast();
+  auto vectorType = dyn_cast(inputType);
   Type elemType = vectorType.getElementType();
 
   if (elemType.isBF16())
@@ -414,7 +414,7 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                   Value output, int32_t subwordOffset,
                                   bool clamp, SmallVector &operands) {
   Type inputType = output.getType();
-  auto vectorType = inputType.dyn_cast();
+  auto vectorType = dyn_cast(inputType);
   Type elemType = vectorType.getElementType();
   if (elemType.isBF16())
     output = rewriter.create(
@@ -569,9 +569,8 @@ static std::optional mfmaOpToIntrinsic(MFMAOp mfma,
 /// on the architecture you are compiling for.
static std::optional wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset) { - - auto sourceVectorType = wmma.getSourceA().getType().dyn_cast(); - auto destVectorType = wmma.getDestC().getType().dyn_cast(); + auto sourceVectorType = dyn_cast(wmma.getSourceA().getType()); + auto destVectorType = dyn_cast(wmma.getDestC().getType()); auto elemSourceType = sourceVectorType.getElementType(); auto elemDestType = destVectorType.getElementType(); @@ -727,7 +726,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( Type f32 = getTypeConverter()->convertType(op.getResult().getType()); Value source = adaptor.getSource(); - auto sourceVecType = op.getSource().getType().dyn_cast(); + auto sourceVecType = dyn_cast(op.getSource().getType()); Type sourceElemType = getElementTypeOrSelf(op.getSource()); // Extend to a v4i8 if (!sourceVecType || sourceVecType.getNumElements() < 4) { diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index 0113a3df0b8e3..3d3ff001c541b 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -65,7 +65,7 @@ static Value castF32To(Type elementType, Value f32, Location loc, LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const { Type inType = op.getIn().getType(); - if (auto inVecType = inType.dyn_cast()) { + if (auto inVecType = dyn_cast(inType)) { if (inVecType.isScalable()) return failure(); if (inVecType.getShape().size() > 1) @@ -81,13 +81,13 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op, Location loc = op.getLoc(); Value in = op.getIn(); Type outElemType = getElementTypeOrSelf(op.getOut().getType()); - if (!in.getType().isa()) { + if (!isa(in.getType())) { Value asFloat = rewriter.create( loc, rewriter.getF32Type(), in, 0); Value result = castF32To(outElemType, asFloat, loc, rewriter); return rewriter.replaceOp(op, result); } - VectorType inType = in.getType().cast(); + VectorType inType = cast(in.getType()); int64_t numElements = inType.getNumElements(); Value zero = rewriter.create( loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); @@ -179,7 +179,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const { if (op.getRoundingmodeAttr()) return failure(); Type outType = op.getOut().getType(); - if (auto outVecType = outType.dyn_cast()) { + if (auto outVecType = dyn_cast(outType)) { if (outVecType.isScalable()) return failure(); if (outVecType.getShape().size() > 1) @@ -202,7 +202,7 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op, if (saturateFP8) in = clampInput(rewriter, loc, outElemType, in); VectorType truncResType = VectorType::get(4, outElemType); - if (!in.getType().isa()) { + if (!isa(in.getType())) { Value asFloat = castToF32(in, loc, rewriter); Value asF8s = rewriter.create( loc, truncResType, asFloat, /*sourceB=*/nullptr, 0, @@ -210,7 +210,7 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op, Value result = rewriter.create(loc, asF8s, 0); return rewriter.replaceOp(op, result); } - VectorType outType = op.getOut().getType().cast(); + VectorType outType = cast(op.getOut().getType()); int64_t numElements = outType.getNumElements(); Value zero = rewriter.create( loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index 993c09b03c0fd..36e10372e4bc5 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ 
b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -214,7 +214,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, llvm::enumerate(gpuFuncOp.getArgumentTypes())) { auto remapping = signatureConversion.getInputMapping(idx); NamedAttrList argAttr = - argAttrs ? argAttrs[idx].cast() : NamedAttrList(); + argAttrs ? cast(argAttrs[idx]) : NamedAttrList(); auto copyAttribute = [&](StringRef attrName) { Attribute attr = argAttr.erase(attrName); if (!attr) @@ -234,9 +234,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, return; } for (size_t i = 0, e = remapping->size; i < e; ++i) { - if (llvmFuncOp.getArgument(remapping->inputNo + i) - .getType() - .isa()) { + if (isa( + llvmFuncOp.getArgument(remapping->inputNo + i).getType())) { llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr); } } diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 78d4e80624687..3a4fc7d8063f4 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -668,7 +668,7 @@ static int32_t getCuSparseLtDataTypeFrom(Type type) { static int32_t getCuSparseDataTypeFrom(Type type) { if (llvm::isa(type)) { // get the element type - auto elementType = type.cast().getElementType(); + auto elementType = cast(type).getElementType(); if (elementType.isBF16()) return 15; // CUDA_C_16BF if (elementType.isF16()) diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 9b5d19ebd783a..11d29754aa760 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -1579,7 +1579,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering if (offset) ti = makeAdd(ti, makeConst(offset)); - auto structType = matrixD.getType().cast(); + auto structType = cast(matrixD.getType()); // Number of 32-bit registers owns per thread constexpr unsigned numAdjacentRegisters = 2; @@ -1606,9 +1606,9 @@ struct NVGPUWarpgroupMmaStoreOpLowering int offset = 0; ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value matriDValue = adaptor.getMatrixD(); - auto stype = matriDValue.getType().cast(); + auto stype = cast(matriDValue.getType()); for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) { - auto structType = matrixD.cast(); + auto structType = cast(matrixD); Value innerStructValue = b.create(matriDValue, idx); storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset); offset += structType.getBody().size(); @@ -1626,13 +1626,9 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { ImplicitLocOpBuilder b(op->getLoc(), rewriter); - LLVM::LLVMStructType packStructType = - getTypeConverter() - ->convertType(op.getMatrixC().getType()) - .cast(); - Type elemType = packStructType.getBody() - .front() - .cast() + LLVM::LLVMStructType packStructType = cast( + getTypeConverter()->convertType(op.getMatrixC().getType())); + Type elemType = cast(packStructType.getBody().front()) .getBody() .front(); Value zero = b.create(elemType, b.getZeroAttr(elemType)); @@ -1640,7 +1636,7 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering SmallVector innerStructs; // Unpack the structs and set all values to zero for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) { - auto structType = s.cast(); + auto structType = 
cast(s); Value structValue = b.create(packStruct, idx); for (unsigned i = 0; i < structType.getBody().size(); ++i) { structValue = b.create( diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index ef8d59c9b2608..b6b85cab5a382 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -618,7 +618,7 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor, static SmallVector expandInputRanks(PatternRewriter &rewriter, Location loc, Operation *operation) { auto rank = - operation->getResultTypes().front().cast().getRank(); + cast(operation->getResultTypes().front()).getRank(); return llvm::map_to_vector(operation->getOperands(), [&](Value operand) { return expandRank(rewriter, loc, operand, rank); }); @@ -680,7 +680,7 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, // dimension, that is the target size. An occurrence of an additional static // dimension greater than 1 with a different value is undefined behavior. for (auto operand : operands) { - auto size = operand.getType().cast().getDimSize(dim); + auto size = cast(operand.getType()).getDimSize(dim); if (!ShapedType::isDynamic(size) && size > 1) return {rewriter.getIndexAttr(size), operand}; } @@ -688,7 +688,7 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, // Filter operands with dynamic dimension auto operandsWithDynamicDim = llvm::to_vector(llvm::make_filter_range(operands, [&](Value operand) { - return operand.getType().cast().isDynamicDim(dim); + return cast(operand.getType()).isDynamicDim(dim); })); // If no operand has a dynamic dimension, it means all sizes were 1 @@ -718,7 +718,7 @@ static std::pair, SmallVector> computeTargetShape(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands) { assert(!operands.empty()); - auto rank = operands.front().getType().cast().getRank(); + auto rank = cast(operands.front().getType()).getRank(); SmallVector targetShape; SmallVector masterOperands; for (auto dim : llvm::seq(0, rank)) { @@ -735,7 +735,7 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, int64_t dim, OpFoldResult targetSize, Value masterOperand) { // Nothing to do if this is a static dimension - auto rankedTensorType = operand.getType().cast(); + auto rankedTensorType = cast(operand.getType()); if (!rankedTensorType.isDynamicDim(dim)) return operand; @@ -817,7 +817,7 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, ArrayRef targetShape, ArrayRef masterOperands) { - int64_t rank = operand.getType().cast().getRank(); + int64_t rank = cast(operand.getType()).getRank(); assert((int64_t)targetShape.size() == rank); assert((int64_t)masterOperands.size() == rank); for (auto index : llvm::seq(0, rank)) @@ -848,8 +848,7 @@ emitElementwiseComputation(PatternRewriter &rewriter, Location loc, Operation *operation, ValueRange operands, ArrayRef targetShape) { // Generate output tensor - auto resultType = - operation->getResultTypes().front().cast(); + auto resultType = cast(operation->getResultTypes().front()); Value outputTensor = rewriter.create( loc, targetShape, resultType.getElementType()); @@ -2274,8 +2273,7 @@ struct RFFT2dConverter final : public OpRewritePattern { llvm::SmallVector staticSizes; dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes); - auto elementType = - 
input.getType().cast().getElementType(); + auto elementType = cast(input.getType()).getElementType(); return RankedTensorType::get(staticSizes, elementType); } @@ -2327,7 +2325,7 @@ struct RFFT2dConverter final : public OpRewritePattern { auto loc = rfft2d.getLoc(); auto input = rfft2d.getInput(); auto elementType = - input.getType().cast().getElementType().cast(); + cast(cast(input.getType()).getElementType()); // Compute the output type and set of dynamic sizes llvm::SmallVector dynamicSizes; diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index 399c0450824ee..3f92372d7cea9 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -1204,10 +1204,10 @@ convertElementwiseOp(RewriterBase &rewriter, Operation *op, return rewriter.notifyMatchFailure(op, "no mapping"); matrixOperands.push_back(it->second); } - auto resultType = matrixOperands[0].getType().cast(); + auto resultType = cast(matrixOperands[0].getType()); if (opType == gpu::MMAElementwiseOp::EXTF) { // The floating point extension case has a different result type. - auto vectorType = op->getResultTypes()[0].cast(); + auto vectorType = cast(op->getResultTypes()[0]); resultType = gpu::MMAMatrixType::get(resultType.getShape(), vectorType.getElementType(), resultType.getOperand()); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 85d10f326e260..1b9975237c699 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -631,8 +631,7 @@ static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter, Type vectorType) { const auto &floatSemantics = cast(llvmType).getFloatSemantics(); auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics); - auto denseValue = - DenseElementsAttr::get(vectorType.cast(), value); + auto denseValue = DenseElementsAttr::get(cast(vectorType), value); return rewriter.create(loc, vectorType, denseValue); } diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 2575ad4984814..e3beceaa3bbb5 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -227,8 +227,8 @@ LogicalResult WMMAOp::verify() { Type sourceAType = getSourceA().getType(); Type destType = getDestC().getType(); - VectorType sourceVectorAType = sourceAType.dyn_cast(); - VectorType destVectorType = destType.dyn_cast(); + VectorType sourceVectorAType = dyn_cast(sourceAType); + VectorType destVectorType = dyn_cast(destType); Type sourceAElemType = sourceVectorAType.getElementType(); Type destElemType = destVectorType.getElementType(); diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp index d7492c9e25db3..5e69a98db8f1e 100644 --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -26,7 +26,7 @@ struct ConstantOpInterface LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto constantOp = cast(op); - auto type = constantOp.getType().dyn_cast(); + auto type = dyn_cast(constantOp.getType()); // Only ranked tensors are supported. 
if (!type) diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index b9ab95b92496e..4a50da3513f99 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -106,7 +106,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions( targetType](Type type) -> std::optional { if (llvm::is_contained(sourceTypes, type)) return targetType; - if (auto shaped = type.dyn_cast()) + if (auto shaped = dyn_cast(type)) if (llvm::is_contained(sourceTypes, shaped.getElementType())) return shaped.clone(targetType); // All other types legal diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp index 3ae894692089b..7e390aa551972 100644 --- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp +++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp @@ -99,7 +99,7 @@ class LowerContractionToSMMLAPattern Value extsiLhs; Value extsiRhs; if (auto lhsExtInType = - origLhsExtOp.getIn().getType().dyn_cast()) { + dyn_cast(origLhsExtOp.getIn().getType())) { if (lhsExtInType.getElementTypeBitWidth() <= 8) { Type targetLhsExtTy = matchContainerType(rewriter.getI8Type(), lhsExtInType); @@ -108,7 +108,7 @@ class LowerContractionToSMMLAPattern } } if (auto rhsExtInType = - origRhsExtOp.getIn().getType().dyn_cast()) { + dyn_cast(origRhsExtOp.getIn().getType())) { if (rhsExtInType.getElementTypeBitWidth() <= 8) { Type targetRhsExtTy = matchContainerType(rewriter.getI8Type(), rhsExtInType); @@ -161,9 +161,9 @@ class LowerContractionToSMMLAPattern extractOperand(op.getAcc(), accPermutationMap, accOffsets); auto inputElementType = - tiledLhs.getType().cast().getElementType(); + cast(tiledLhs.getType()).getElementType(); auto accElementType = - tiledAcc.getType().cast().getElementType(); + cast(tiledAcc.getType()).getElementType(); auto inputExpandedType = VectorType::get({2, 8}, inputElementType); auto outputExpandedType = VectorType::get({2, 2}, accElementType); @@ -175,9 +175,9 @@ class LowerContractionToSMMLAPattern auto emptyOperand = rewriter.create( loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType)); SmallVector offsets( - emptyOperand.getType().cast().getRank(), 0); + cast(emptyOperand.getType()).getRank(), 0); SmallVector strides( - tiledOperand.getType().cast().getRank(), 1); + cast(tiledOperand.getType()).getRank(), 1); return rewriter.createOrFold( loc, tiledOperand, emptyOperand, offsets, strides); }; @@ -214,7 +214,7 @@ class LowerContractionToSMMLAPattern // Insert the tiled result back into the non tiled result of the // contract op. 
SmallVector strides( - tiledRes.getType().cast().getRank(), 1); + cast(tiledRes.getType()).getRank(), 1); result = rewriter.createOrFold( loc, tiledRes, result, accOffsets, strides); } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp index a5ea42b7d701d..b197786c32054 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp @@ -39,7 +39,7 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) { return builder.create(loc, builder.getBoolAttr(value)); } -static bool isMemref(Value v) { return v.getType().isa(); } +static bool isMemref(Value v) { return isa(v.getType()); } //===----------------------------------------------------------------------===// // Ownership @@ -222,8 +222,8 @@ bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const { return false; // Block arguments are less than results. - bool lhsIsBBArg = lhs.isa(); - if (lhsIsBBArg != rhs.isa()) { + bool lhsIsBBArg = isa(lhs); + if (lhsIsBBArg != isa(rhs)) { return lhsIsBBArg; } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index c2b2b99fc0083..d51d63f243ea0 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -684,7 +684,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options, // Op is not bufferizable. auto memSpace = - options.defaultMemorySpaceFn(value.getType().cast()); + options.defaultMemorySpaceFn(cast(value.getType())); if (!memSpace.has_value()) return op->emitError("could not infer memory space"); @@ -939,7 +939,7 @@ FailureOr bufferization::detail::defaultGetBufferType( // If we do not know the memory space and there is no default memory space, // report a failure. auto memSpace = - options.defaultMemorySpaceFn(value.getType().cast()); + options.defaultMemorySpaceFn(cast(value.getType())); if (!memSpace.has_value()) return op->emitError("could not infer memory space"); @@ -987,7 +987,7 @@ bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) { for (Region ®ion : opOperand.getOwner()->getRegions()) if (!region.getBlocks().empty()) for (BlockArgument bbArg : region.getBlocks().front().getArguments()) - if (bbArg.getType().isa()) + if (isa(bbArg.getType())) r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false}); return r; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp index c9fd110d48d9a..a8ec111f8c304 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp @@ -46,7 +46,7 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) { return builder.create(loc, builder.getBoolAttr(value)); } -static bool isMemref(Value v) { return v.getType().isa(); } +static bool isMemref(Value v) { return isa(v.getType()); } /// Return "true" if the given op is guaranteed to have neither "Allocate" nor /// "Free" side effects. 
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp index 1c81433bc3e94..fb97045687d65 100644 --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -378,7 +378,7 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) { if (!rhs) return {}; - ArrayAttr arrayAttr = rhs.dyn_cast(); + ArrayAttr arrayAttr = dyn_cast(rhs); if (!arrayAttr || arrayAttr.size() != 2) return {}; diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp index 0dc357c2298fa..89546da428fa2 100644 --- a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp @@ -17,7 +17,7 @@ using namespace mlir; using namespace mlir::bufferization; -static bool isMemref(Value v) { return v.getType().isa(); } +static bool isMemref(Value v) { return isa(v.getType()); } namespace { /// While CondBranchOp also implement the BranchOpInterface, we add a diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index b037ef3c0b415..66a71df29a9bb 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -160,13 +160,13 @@ LogicalResult AddOp::verify() { Type lhsType = getLhs().getType(); Type rhsType = getRhs().getType(); - if (lhsType.isa() && rhsType.isa()) + if (isa(lhsType) && isa(rhsType)) return emitOpError("requires that at most one operand is a pointer"); - if ((lhsType.isa() && - !rhsType.isa()) || - (rhsType.isa() && - !lhsType.isa())) + if ((isa(lhsType) && + !isa(rhsType)) || + (isa(rhsType) && + !isa(lhsType))) return emitOpError("requires that one operand is an integer or of opaque " "type if the other is a pointer"); @@ -778,16 +778,16 @@ LogicalResult SubOp::verify() { Type rhsType = getRhs().getType(); Type resultType = getResult().getType(); - if (rhsType.isa() && !lhsType.isa()) + if (isa(rhsType) && !isa(lhsType)) return emitOpError("rhs can only be a pointer if lhs is a pointer"); - if (lhsType.isa() && - !rhsType.isa()) + if (isa(lhsType) && + !isa(rhsType)) return emitOpError("requires that rhs is an integer, pointer or of opaque " "type if lhs is a pointer"); - if (lhsType.isa() && rhsType.isa() && - !resultType.isa()) + if (isa(lhsType) && isa(rhsType) && + !isa(resultType)) return emitOpError("requires that the result is an integer or of opaque " "type if lhs and rhs are pointers"); return success(); diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index b584f63f16e0a..3661c5dea4525 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -196,7 +196,7 @@ getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) { auto extract = dyn_cast(users); if (!extract) return std::nullopt; - auto vecType = extract.getResult().getType().cast(); + auto vecType = cast(extract.getResult().getType()); if (sliceType && sliceType != vecType) return std::nullopt; sliceType = vecType; @@ -204,7 +204,7 @@ getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) { return llvm::to_vector(sliceType.getShape()); } if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) { - if (auto vecType = op->getResultTypes()[0].dyn_cast()) { + if (auto vecType = 
dyn_cast(op->getResultTypes()[0])) { // TODO: The condition for unrolling elementwise should be restricted // only to operations that need unrolling (connected to the contract). if (vecType.getRank() < 2) @@ -219,7 +219,7 @@ getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) { auto extract = dyn_cast(users); if (!extract) return std::nullopt; - auto vecType = extract.getResult().getType().cast(); + auto vecType = cast(extract.getResult().getType()); if (sliceType && sliceType != vecType) return std::nullopt; sliceType = vecType; diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp index 9ab7ae2a90820..cfc8d092c8178 100644 --- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp +++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp @@ -354,7 +354,7 @@ static WalkResult loadOperation( // Gather the variadicities of each result for (Attribute attr : resultsOp->getVariadicity()) - resultVariadicity.push_back(attr.cast().getValue()); + resultVariadicity.push_back(cast(attr).getValue()); } // Gather which constraint slots correspond to attributes constraints @@ -367,7 +367,7 @@ static WalkResult loadOperation( for (const auto &[name, value] : llvm::zip(names, values)) { for (auto [i, constr] : enumerate(constrToValue)) { if (constr == value) { - attributesContraints[name.cast()] = i; + attributesContraints[cast(name)] = i; break; } } diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp index f3b674fdb5050..f7f1e944d637d 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp @@ -42,7 +42,7 @@ static char getRegisterType(Type type) { return 'f'; if (type.isF64()) return 'd'; - if (auto ptr = type.dyn_cast()) { + if (auto ptr = dyn_cast(type)) { // Shared address spaces is addressed with 32-bit pointers. if (ptr.getAddressSpace() == kSharedMemorySpace) { return 'r'; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index f90240a67dcc5..1db506e286b3c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -559,7 +559,7 @@ static void destructureIndices(Type currType, ArrayRef indices, // we don't do anything here. The verifier will catch it and emit a proper // error. All other canonicalization is done in the fold method. bool requiresConst = !rawConstantIndices.empty() && - currType.isa_and_nonnull(); + isa_and_nonnull(currType); if (Value val = llvm::dyn_cast_if_present(iter)) { APInt intC; if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) && @@ -2564,14 +2564,14 @@ LogicalResult LLVM::ConstantOp::verify() { } // See the comment for getLLVMConstant for more details about why 8-bit // floats can be represented by integers. 
- if (getType().isa() && !getType().isInteger(floatWidth)) { + if (isa(getType()) && !getType().isInteger(floatWidth)) { return emitOpError() << "expected integer type of width " << floatWidth; } } if (auto splatAttr = dyn_cast(getValue())) { - if (!getType().isa() && !getType().isa() && - !getType().isa() && - !getType().isa()) + if (!isa(getType()) && !isa(getType()) && + !isa(getType()) && + !isa(getType())) return emitOpError() << "expected vector or array type"; } return success(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index 93901477b5820..f2ab3eae2c343 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -319,7 +319,7 @@ LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses( static Type getTypeAtIndex(const DestructurableMemorySlot &slot, Attribute index) { auto subelementIndexMap = - slot.elemType.cast().getSubelementIndexMap(); + cast(slot.elemType).getSubelementIndexMap(); if (!subelementIndexMap) return {}; assert(!subelementIndexMap->empty()); @@ -913,8 +913,7 @@ bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot, if (getIsVolatile()) return false; - if (!slot.elemType.cast() - .getSubelementIndexMap()) + if (!cast(slot.elemType).getSubelementIndexMap()) return false; if (!areAllIndicesI32(slot)) @@ -928,7 +927,7 @@ DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot, RewriterBase &rewriter, const DataLayout &dataLayout) { std::optional> types = - slot.elemType.cast().getSubelementIndexMap(); + cast(slot.elemType).getSubelementIndexMap(); IntegerAttr memsetLenAttr; bool successfulMatch = @@ -1047,8 +1046,7 @@ static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot, if (op.getIsVolatile()) return false; - if (!slot.elemType.cast() - .getSubelementIndexMap()) + if (!cast(slot.elemType).getSubelementIndexMap()) return false; if (!areAllIndicesI32(slot)) diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp index b264e9ff9283d..0a372ad0c52fc 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp @@ -475,7 +475,7 @@ LogicalResult SplitStores::matchAndRewrite(StoreOp store, } } - auto destructurableType = typeHint.dyn_cast(); + auto destructurableType = dyn_cast(typeHint); if (!destructurableType) return failure(); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp index 3e85559e1ec0c..768df0953fc5c 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp @@ -202,9 +202,9 @@ DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation( body, [&](Operation *elem, Operation *red) { return elem->getName().getStringRef() == - (*contractionOps)[0].cast().getValue() && + cast((*contractionOps)[0]).getValue() && red->getName().getStringRef() == - (*contractionOps)[1].cast().getValue(); + cast((*contractionOps)[1]).getValue(); }, os); if (result) @@ -259,11 +259,11 @@ transform::MatchStructuredClassifyContractionDimsOp::matchOperation( return builder.getI64IntegerAttr(value); })); }; - results.setParams(getBatch().cast(), + results.setParams(cast(getBatch()), makeI64Attrs(contractionDims->batch)); - results.setParams(getM().cast(), makeI64Attrs(contractionDims->m)); - results.setParams(getN().cast(), 
makeI64Attrs(contractionDims->n)); - results.setParams(getK().cast(), makeI64Attrs(contractionDims->k)); + results.setParams(cast(getM()), makeI64Attrs(contractionDims->m)); + results.setParams(cast(getN()), makeI64Attrs(contractionDims->n)); + results.setParams(cast(getK()), makeI64Attrs(contractionDims->k)); return DiagnosedSilenceableFailure::success(); } @@ -288,17 +288,17 @@ transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation( return builder.getI64IntegerAttr(value); })); }; - results.setParams(getBatch().cast(), + results.setParams(cast(getBatch()), makeI64Attrs(convolutionDims->batch)); - results.setParams(getOutputImage().cast(), + results.setParams(cast(getOutputImage()), makeI64Attrs(convolutionDims->outputImage)); - results.setParams(getOutputChannel().cast(), + results.setParams(cast(getOutputChannel()), makeI64Attrs(convolutionDims->outputChannel)); - results.setParams(getFilterLoop().cast(), + results.setParams(cast(getFilterLoop()), makeI64Attrs(convolutionDims->filterLoop)); - results.setParams(getInputChannel().cast(), + results.setParams(cast(getInputChannel()), makeI64Attrs(convolutionDims->inputChannel)); - results.setParams(getDepth().cast(), + results.setParams(cast(getDepth()), makeI64Attrs(convolutionDims->depth)); auto makeI64AttrsFromI64 = [&](ArrayRef values) { @@ -307,9 +307,9 @@ transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation( return builder.getI64IntegerAttr(value); })); }; - results.setParams(getStrides().cast(), + results.setParams(cast(getStrides()), makeI64AttrsFromI64(convolutionDims->strides)); - results.setParams(getDilations().cast(), + results.setParams(cast(getDilations()), makeI64AttrsFromI64(convolutionDims->dilations)); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 7e7cf1d024461..3c3d968fbb865 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1219,7 +1219,7 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter, // All the operands must must be equal to the specified type auto typeattr = dyn_cast(getFilterOperandTypes().value()[0]); - Type t = typeattr.getValue().cast<::mlir::Type>(); + Type t = cast<::mlir::Type>(typeattr.getValue()); if (!llvm::all_of(op->getOperandTypes(), [&](Type operandType) { return operandType == t; })) return; @@ -1234,7 +1234,7 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter, for (auto [attr, operandType] : llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) { auto typeattr = cast(attr); - Type type = typeattr.getValue().cast<::mlir::Type>(); + Type type = cast<::mlir::Type>(typeattr.getValue()); if (type != operandType) return; @@ -2665,7 +2665,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter, if (auto attr = llvm::dyn_cast_if_present(ofr)) { if (scalableSizes[ofrIdx]) { auto val = b.create( - getLoc(), attr.cast().getInt()); + getLoc(), cast(attr).getInt()); Value vscale = b.create(getLoc(), b.getIndexType()); sizes.push_back( diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp index b95677b7457e6..59c189fa1fbad 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -60,7 +60,7 @@ static void 
createMemcpy(OpBuilder &b, Location loc, Value tensorSource, const linalg::BufferizeToAllocationOptions &options) { auto tensorType = dyn_cast(tensorSource.getType()); assert(tensorType && "expected ranked tensor"); - assert(memrefDest.getType().isa() && "expected ranked memref"); + assert(isa(memrefDest.getType()) && "expected ranked memref"); switch (options.memcpyOp) { case linalg::BufferizeToAllocationOptions::MemcpyOp:: @@ -496,10 +496,10 @@ Value linalg::bufferizeToAllocation( if (op == nestedOp) return; if (llvm::any_of(nestedOp->getOperands(), - [](Value v) { return v.getType().isa(); })) + [](Value v) { return isa(v.getType()); })) llvm_unreachable("ops with nested tensor ops are not supported yet"); if (llvm::any_of(nestedOp->getResults(), - [](Value v) { return v.getType().isa(); })) + [](Value v) { return isa(v.getType()); })) llvm_unreachable("ops with nested tensor ops are not supported yet"); }); } @@ -508,7 +508,7 @@ Value linalg::bufferizeToAllocation( // Gather tensor results. SmallVector tensorResults; for (OpResult result : op->getResults()) { - if (!result.getType().isa()) + if (!isa(result.getType())) continue; // Unranked tensors are not supported if (!isa(result.getType())) diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp index 81669a1807796..4776883ed95c5 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp @@ -49,7 +49,7 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep( for (OpOperand *in : op.getDpsInputOperands()) { // Skip non-tensor operands. - if (!in->get().getType().isa()) + if (!isa(in->get().getType())) continue; // Find tensor.empty ops on the reverse SSA use-def chain. Only follow diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 462f692615faa..df4089d61bfd7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -405,7 +405,7 @@ static FailureOr tileToForallOpImpl( for (OpOperand &outOperand : destinationStyleOp.getDpsInitsMutable()) { // Swap tensor inits with the corresponding block argument of the // scf.forall op. Memref inits remain as is. 
- if (outOperand.get().getType().isa()) { + if (isa(outOperand.get().getType())) { auto *it = llvm::find(dest, outOperand.get()); assert(it != dest.end() && "could not find destination tensor"); unsigned destNum = std::distance(dest.begin(), it); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index a17bc8e4cd318..2297bf5e35512 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -557,7 +557,7 @@ FailureOr linalg::pack(RewriterBase &rewriter, Value dest = tensor::PackOp::createDestinationTensor( rewriter, loc, operand, innerPackSizes, innerPos, /*outerDimsPerm=*/{}); - ShapedType operandType = operand.getType().cast(); + ShapedType operandType = cast(operand.getType()); bool areConstantTiles = llvm::all_of(innerPackSizes, [](OpFoldResult tile) { return getConstantIntValue(tile).has_value(); @@ -565,7 +565,7 @@ FailureOr linalg::pack(RewriterBase &rewriter, if (areConstantTiles && operandType.hasStaticShape() && !tensor::PackOp::requirePaddingValue( operandType.getShape(), innerPos, - dest.getType().cast().getShape(), {}, + cast(dest.getType()).getShape(), {}, innerPackSizes)) { packOps.push_back(rewriter.create( loc, operand, dest, innerPos, innerPackSizes)); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index df61381432921..fbff2088637f4 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -3410,8 +3410,8 @@ struct Conv1DGenerator // * shape_cast(broadcast(filter)) // * broadcast(shuffle(filter)) // Opt for the option without shape_cast to simplify the codegen. - auto rhsSize = rhs.getType().cast().getShape()[0]; - auto resSize = res.getType().cast().getShape()[1]; + auto rhsSize = cast(rhs.getType()).getShape()[0]; + auto resSize = cast(res.getType()).getShape()[1]; SmallVector indicies; for (int i = 0; i < resSize / rhsSize; ++i) { diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp index b3481ce1c56bb..3c9475c2d143a 100644 --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -173,8 +173,8 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter, } // Assemble results. 
- results.set(getGlobal().cast(), globalOps); - results.set(getGetGlobal().cast(), getGlobalOps); + results.set(cast(getGlobal()), globalOps); + results.set(cast(getGetGlobal()), getGlobalOps); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp index 8236a4c475f17..4449733f0daf0 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -254,7 +254,7 @@ struct ConvertMemRefLoad final : OpConversionPattern { LogicalResult matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto convertedType = adaptor.getMemref().getType().cast(); + auto convertedType = cast(adaptor.getMemref().getType()); auto convertedElementType = convertedType.getElementType(); auto oldElementType = op.getMemRefType().getElementType(); int srcBits = oldElementType.getIntOrFloatBitWidth(); @@ -351,7 +351,7 @@ struct ConvertMemrefStore final : OpConversionPattern { LogicalResult matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto convertedType = adaptor.getMemref().getType().cast(); + auto convertedType = cast(adaptor.getMemref().getType()); int srcBits = op.getMemRefType().getElementTypeBitWidth(); int dstBits = convertedType.getElementTypeBitWidth(); auto dstIntegerType = rewriter.getIntegerType(dstBits); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp index 62a8f7e43c867..dcc5eac916d03 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp @@ -68,7 +68,7 @@ struct ExpandReallocOpPattern : public OpRewritePattern { // Get the size of the original buffer. int64_t inputSize = - op.getSource().getType().cast().getDimSize(0); + cast(op.getSource().getType()).getDimSize(0); OpFoldResult currSize = rewriter.getIndexAttr(inputSize); if (ShapedType::isDynamic(inputSize)) { Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc, @@ -79,7 +79,7 @@ struct ExpandReallocOpPattern : public OpRewritePattern { // Get the requested size that the new buffer should have. int64_t outputSize = - op.getResult().getType().cast().getDimSize(0); + cast(op.getResult().getType()).getDimSize(0); OpFoldResult targetSize = ShapedType::isDynamic(outputSize) ? OpFoldResult{op.getDynamicResultSize()} : rewriter.getIndexAttr(outputSize); @@ -127,7 +127,7 @@ struct ExpandReallocOpPattern : public OpRewritePattern { // is already bigger than the requested size, the cast represents a // subview operation. 
Value casted = builder.create( - loc, op.getResult().getType().cast(), op.getSource(), + loc, cast(op.getResult().getType()), op.getSource(), rewriter.getIndexAttr(0), ArrayRef{targetSize}, ArrayRef{rewriter.getIndexAttr(1)}); builder.create(loc, casted); diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index 03f11ad1f9496..d4329b401df19 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -169,7 +169,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh, } Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) { - RankedTensorType rankedTensorType = type.dyn_cast(); + RankedTensorType rankedTensorType = dyn_cast(type); if (rankedTensorType) { return shardShapedType(rankedTensorType, mesh, sharding); } @@ -281,7 +281,8 @@ MeshShardingAttr::verify(function_ref emitError, } bool MeshShardingAttr::operator==(Attribute rhs) const { - MeshShardingAttr rhsAsMeshShardingAttr = rhs.dyn_cast(); + MeshShardingAttr rhsAsMeshShardingAttr = + mlir::dyn_cast(rhs); return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr; } @@ -484,15 +485,15 @@ static LogicalResult verifyDimensionCompatibility(Location loc, static LogicalResult verifyGatherOperandAndResultShape( Value operand, Value result, int64_t gatherAxis, ArrayRef meshAxes, ArrayRef meshShape) { - auto resultRank = result.getType().template cast().getRank(); + auto resultRank = cast(result.getType()).getRank(); if (gatherAxis < 0 || gatherAxis >= resultRank) { return emitError(result.getLoc()) << "Gather axis " << gatherAxis << " is out of bounds [0, " << resultRank << ")."; } - ShapedType operandType = operand.getType().cast(); - ShapedType resultType = result.getType().cast(); + ShapedType operandType = cast(operand.getType()); + ShapedType resultType = cast(result.getType()); auto deviceGroupSize = DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape)); for (int64_t axis = 0; axis < operandType.getRank(); ++axis) { @@ -511,8 +512,8 @@ static LogicalResult verifyGatherOperandAndResultShape( static LogicalResult verifyAllToAllOperandAndResultShape( Value operand, Value result, int64_t splitAxis, int64_t concatAxis, ArrayRef meshAxes, ArrayRef meshShape) { - ShapedType operandType = operand.getType().cast(); - ShapedType resultType = result.getType().cast(); + ShapedType operandType = cast(operand.getType()); + ShapedType resultType = cast(result.getType()); for (int64_t axis = 0; axis < operandType.getRank(); ++axis) { if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) { if (failed(verifyDimensionCompatibility( @@ -556,8 +557,8 @@ static LogicalResult verifyAllToAllOperandAndResultShape( static LogicalResult verifyScatterOrSliceOperandAndResultShape( Value operand, Value result, int64_t tensorAxis, ArrayRef meshAxes, ArrayRef meshShape) { - ShapedType operandType = operand.getType().cast(); - ShapedType resultType = result.getType().cast(); + ShapedType operandType = cast(operand.getType()); + ShapedType resultType = cast(result.getType()); for (int64_t axis = 0; axis < operandType.getRank(); ++axis) { if (axis != tensorAxis) { if (failed(verifyDimensionCompatibility( diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp index 9acee5aa8d860..dbb9e667d4709 100644 --- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp +++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp @@ -97,7 +97,7 @@ 
checkOperandAffineExpr(AffineExpr expr, unsigned numDims) { FailureOr> mesh::getMeshShardingAttr(OpResult result) { - Value val = result.cast(); + Value val = cast(result); bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) { auto shardOp = llvm::dyn_cast(user); if (!shardOp) @@ -178,7 +178,7 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() { return failure(); for (OpResult result : op->getResults()) { - auto resultType = result.getType().dyn_cast(); + auto resultType = dyn_cast(result.getType()); if (!resultType) return failure(); AffineMap map = maps[numOperands + result.getResultNumber()]; @@ -404,7 +404,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result, if (succeeded(maybeSharding) && !maybeSharding->first) return success(); - auto resultType = result.getType().cast(); + auto resultType = cast(result.getType()); SmallVector> splitAxes(resultType.getRank()); SmallVector partialAxes; @@ -457,7 +457,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand, if (succeeded(maybeShardingAttr) && maybeShardingAttr->first) return success(); Value operand = opOperand.get(); - auto operandType = operand.getType().cast(); + auto operandType = cast(operand.getType()); SmallVector> splitAxes(operandType.getRank()); unsigned numDims = map.getNumDims(); for (auto it : llvm::enumerate(map.getResults())) { @@ -526,7 +526,7 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations( static bool isValueCompatibleWithFullReplicationSharding(Value value, MeshShardingAttr sharding) { - if (value.getType().isa()) { + if (isa(value.getType())) { return sharding && isFullReplication(sharding); } diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index e4868435135ed..6b1326d76bc4a 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -86,14 +86,13 @@ handlePartialAxesDuringResharding(OpBuilder &builder, } builder.setInsertionPointAfterValue(sourceShard); - TypedValue resultValue = + TypedValue resultValue = cast>( builder .create(sourceShard.getLoc(), sourceShard.getType(), sourceSharding.getMesh().getLeafReference(), allReduceMeshAxes, sourceShard, sourceSharding.getPartialType()) - .getResult() - .cast>(); + .getResult()); llvm::SmallVector remainingPartialAxes; llvm::copy_if(sourceShardingPartialAxesSet, @@ -135,13 +134,12 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, TypedValue sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis) { - TypedValue targetShard = + TypedValue targetShard = cast>( builder .create(sourceShard, mesh, ArrayRef(splitMeshAxis), splitTensorAxis) - .getResult() - .cast>(); + .getResult()); MeshShardingAttr targetSharding = targetShardingInSplitLastAxis( builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis); return {targetShard, targetSharding}; @@ -278,10 +276,8 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, APInt(64, splitTensorAxis)); ShapedType targetShape = shardShapedType(sourceUnshardedShape, mesh, targetSharding); - TypedValue targetShard = - builder.create(targetShape, allGatherResult) - .getResult() - .cast>(); + TypedValue targetShard = cast>( + builder.create(targetShape, allGatherResult).getResult()); return {targetShard, targetSharding}; } @@ -413,10 +409,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, APInt(64, targetTensorAxis), APInt(64, 
sourceTensorAxis)); ShapedType targetShape = shardShapedType(sourceUnshardedShape, mesh, targetSharding); - TypedValue targetShard = - builder.create(targetShape, allToAllResult) - .getResult() - .cast>(); + TypedValue targetShard = cast>( + builder.create(targetShape, allToAllResult).getResult()); return {targetShard, targetSharding}; } @@ -505,7 +499,7 @@ TypedValue reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder); return reshard( implicitLocOpBuilder, mesh, source.getShard(), target.getShard(), - source.getSrc().cast>(), sourceShardValue); + cast>(source.getSrc()), sourceShardValue); } TypedValue reshard(OpBuilder &builder, ShardOp source, @@ -533,23 +527,22 @@ SmallVector shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection) { SmallVector res; - llvm::transform(block.getArguments(), std::back_inserter(res), - [&symbolTableCollection](BlockArgument arg) { - auto rankedTensorArg = - arg.dyn_cast>(); - if (!rankedTensorArg) { - return arg.getType(); - } - - assert(rankedTensorArg.hasOneUse()); - Operation *useOp = *rankedTensorArg.getUsers().begin(); - ShardOp shardOp = llvm::dyn_cast(useOp); - assert(shardOp); - MeshOp mesh = getMesh(shardOp, symbolTableCollection); - return shardShapedType(rankedTensorArg.getType(), mesh, - shardOp.getShardAttr()) - .cast(); - }); + llvm::transform( + block.getArguments(), std::back_inserter(res), + [&symbolTableCollection](BlockArgument arg) { + auto rankedTensorArg = dyn_cast>(arg); + if (!rankedTensorArg) { + return arg.getType(); + } + + assert(rankedTensorArg.hasOneUse()); + Operation *useOp = *rankedTensorArg.getUsers().begin(); + ShardOp shardOp = llvm::dyn_cast(useOp); + assert(shardOp); + MeshOp mesh = getMesh(shardOp, symbolTableCollection); + return cast(shardShapedType(rankedTensorArg.getType(), mesh, + shardOp.getShardAttr())); + }); return res; } @@ -587,7 +580,7 @@ static SmallVector getOperandShardings(Operation &op) { res.reserve(op.getNumOperands()); llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) { TypedValue rankedTensor = - operand.dyn_cast>(); + dyn_cast>(operand); if (!rankedTensor) { return MeshShardingAttr(); } @@ -608,7 +601,7 @@ static SmallVector getResultShardings(Operation &op) { llvm::transform(op.getResults(), std::back_inserter(res), [](OpResult result) { TypedValue rankedTensor = - result.dyn_cast>(); + dyn_cast>(result); if (!rankedTensor) { return MeshShardingAttr(); } @@ -636,9 +629,8 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap, } else { // Insert resharding. 
assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers()); - TypedValue srcSpmdValue = - spmdizationMap.lookup(srcShardOp.getOperand()) - .cast>(); + TypedValue srcSpmdValue = cast>( + spmdizationMap.lookup(srcShardOp.getOperand())); targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue, symbolTableCollection); } diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp index cb13ee404751c..60c4e07a118cb 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp @@ -133,7 +133,7 @@ struct AllSliceOpLowering // insert tensor.extract_slice RankedTensorType operandType = - op.getOperand().getType().cast(); + cast(op.getOperand().getType()); SmallVector sizes; for (int64_t i = 0; i < operandType.getRank(); ++i) { if (i == sliceAxis) { @@ -202,10 +202,9 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef axes, ImplicitLocOpBuilder &builder) { Operation::result_range meshShape = builder.create(mesh, axes).getResults(); - return arith::createProduct(builder, builder.getLoc(), - llvm::to_vector_of(meshShape), - builder.getIndexType()) - .cast>(); + return cast>(arith::createProduct( + builder, builder.getLoc(), llvm::to_vector_of(meshShape), + builder.getIndexType())); } TypedValue createProcessLinearIndex(StringRef mesh, diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index 1635297a5447d..4e256aea0be37 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -651,7 +651,7 @@ struct MmaSyncBuilder { template static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn, ReduceFn reduceFn) { - VectorType vectorType = vector.getType().cast(); + VectorType vectorType = cast(vector.getType()); auto vectorShape = vectorType.getShape(); auto strides = computeStrides(vectorShape); for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) { @@ -779,11 +779,11 @@ FailureOr MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) { Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get(); Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get(); Value resMemRef = linalgOp.getDpsInitOperand(0)->get(); - assert(lhsMemRef.getType().cast().getRank() == 2 && + assert(cast(lhsMemRef.getType()).getRank() == 2 && "expected lhs to be a 2D memref"); - assert(rhsMemRef.getType().cast().getRank() == 2 && + assert(cast(rhsMemRef.getType()).getRank() == 2 && "expected rhs to be a 2D memref"); - assert(resMemRef.getType().cast().getRank() == 2 && + assert(cast(resMemRef.getType()).getRank() == 2 && "expected res to be a 2D memref"); int64_t m = cast(lhsMemRef.getType()).getShape()[0]; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 1e480d6471cbc..f380926c4bce3 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1318,7 +1318,7 @@ static LogicalResult verifyPrivateVarList(OpType &op) { for (auto privateVarInfo : llvm::zip_equal(privateVars, privatizers)) { Type varType = std::get<0>(privateVarInfo).getType(); SymbolRefAttr privatizerSym = - std::get<1>(privateVarInfo).template cast(); + cast(std::get<1>(privateVarInfo)); PrivateClauseOp privatizerOp = SymbolTable::lookupNearestSymbolFrom(op, privatizerSym); diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp 
b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp index 481275f052a3c..187b3b71e3458 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp @@ -145,9 +145,9 @@ void VarInfo::setNum(Var::Num n) { /// mismatches. LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) { - const auto loc1 = parser.getEncodedSourceLoc(sm1).dyn_cast(); + const auto loc1 = dyn_cast(parser.getEncodedSourceLoc(sm1)); assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`"); - const auto loc2 = parser.getEncodedSourceLoc(sm2).dyn_cast(); + const auto loc2 = dyn_cast(parser.getEncodedSourceLoc(sm2)); assert(loc2 && "Could not get `FileLineColLoc` for second `SMLoc`"); if (loc1.getFilename() != loc2.getFilename()) return SMLoc(); diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 516b0943bdcfa..b1d44559fa5ab 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -2078,7 +2078,7 @@ struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; AliasResult getAlias(Attribute attr, raw_ostream &os) const override { - if (attr.isa()) { + if (isa(attr)) { os << "sparse"; return AliasResult::OverridableAlias; } diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp index 4f9988d48d771..9c84f4c25866f 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp @@ -29,7 +29,7 @@ LogicalResult sparse_tensor::detail::stageWithSortImpl( Location loc = op.getLoc(); Type finalTp = op->getOpResult(0).getType(); - SparseTensorType dstStt(finalTp.cast()); + SparseTensorType dstStt(cast(finalTp)); Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false); // Clones the original operation but changing the output to an unordered COO. 
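Reviewer note: every hunk in this patch applies the same mechanical rewrite, replacing the deprecated member-function casts on Type, Attribute, and Value with the free-function spellings. A minimal sketch of the before/after shape, assuming the usual MLIR headers; the helper name and the RankedTensorType example are illustrative, not taken from this patch:

#include "mlir/IR/BuiltinTypes.h" // RankedTensorType
#include "mlir/IR/Value.h"        // Value

using namespace mlir;

// Returns the rank of `v` when it is a ranked tensor, or -1 otherwise.
static int64_t rankOrMinusOne(Value v) {
  // Deprecated member form: v.getType().dyn_cast<RankedTensorType>()
  auto tensorTy = dyn_cast<RankedTensorType>(v.getType());
  return tensorTy ? tensorTy.getRank() : -1;
}

The checked and null-tolerant variants follow the same pattern: isa<T>(x), cast<T>(x), dyn_cast_or_null<T>(x), and isa_and_nonnull<T>(x) replace x.isa<T>(), x.cast<T>(), x.dyn_cast_or_null<T>(), and x.isa_and_nonnull<T>().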
diff --git a/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp b/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp index 5b7ea9360e221..ca19259ebffa6 100644 --- a/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp +++ b/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp @@ -25,7 +25,7 @@ DiagnosedSilenceableFailure transform::MatchSparseInOut::matchOperation( return emitSilenceableFailure(current->getLoc(), "operation has no sparse input or output"); } - results.set(getResult().cast(), state.getPayloadOps(getTarget())); + results.set(cast(getResult()), state.getPayloadOps(getTarget())); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp index eafbe95b7aebe..a53bce16dad86 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp @@ -42,7 +42,7 @@ static void convTypes(TypeRange types, SmallVectorImpl &convTypes, if (kind == SparseTensorFieldKind::PosMemRef || kind == SparseTensorFieldKind::CrdMemRef || kind == SparseTensorFieldKind::ValMemRef) { - auto rtp = t.cast(); + auto rtp = cast(t); if (!directOut) { rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); if (extraTypes) @@ -97,7 +97,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types, mem = builder.create(loc, inputs[0]); toVals.push_back(mem); } else { - ShapedType rtp = t.cast(); + ShapedType rtp = cast(t); rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); inputs.push_back(extraVals[extra++]); retTypes.push_back(rtp); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp index 9c0fc60877d8a..36ecf692b02c5 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp @@ -502,7 +502,7 @@ struct GenericOpScheduler : public OpRewritePattern { for (const AffineExpr l : order.getResults()) { unsigned loopId = llvm::cast(l).getPosition(); auto itTp = - linalgOp.getIteratorTypes()[loopId].cast(); + cast(linalgOp.getIteratorTypes()[loopId]); if (linalg::isReductionIterator(itTp.getValue())) break; // terminate at first reduction nest++; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index b117c1694e45b..02375f54d7152 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -476,8 +476,8 @@ struct GenSemiRingSelect : public OpRewritePattern { if (!sel) return std::nullopt; - auto tVal = sel.getTrueValue().dyn_cast(); - auto fVal = sel.getFalseValue().dyn_cast(); + auto tVal = dyn_cast(sel.getTrueValue()); + auto fVal = dyn_cast(sel.getFalseValue()); // TODO: For simplicity, we only handle cases where both true/false value // are directly loaded the input tensor. We can probably admit more cases // in theory. @@ -487,7 +487,7 @@ struct GenSemiRingSelect : public OpRewritePattern { // Helper lambda to determine whether the value is loaded from a dense input // or is a loop invariant. 
auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool { - if (auto bArg = v.dyn_cast(); + if (auto bArg = dyn_cast(v); bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber()))) return true; // If the value is defined outside the loop, it is a loop invariant. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp index 89af75dea2a0f..de553a5f9bf08 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp @@ -165,7 +165,7 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value, Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc, Value elem, Type dstTp) { - if (auto rtp = dstTp.dyn_cast()) { + if (auto rtp = dyn_cast(dstTp)) { // Scalars can only be converted to 0-ranked tensors. assert(rtp.getRank() == 0); elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType()); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp index 66f96ba08c0ed..8981de58306da 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp @@ -157,8 +157,7 @@ IterationGraphSorter::IterationGraphSorter( // The number of results of the map should match the rank of the tensor. assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) { auto [m, v] = mvPair; - return m.getNumResults() == - v.getType().template cast().getRank(); + return m.getNumResults() == cast(v.getType()).getRank(); })); itGraph.resize(getNumLoops(), std::vector(getNumLoops(), false)); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 0ce40e8137120..80bc04d62bbe8 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -820,7 +820,7 @@ struct DimOfDestStyleOp : public OpRewritePattern { if (!destOp) return failure(); - auto resultIndex = source.cast().getResultNumber(); + auto resultIndex = cast(source).getResultNumber(); auto *initOperand = destOp.getDpsInitOperand(resultIndex); rewriter.modifyOpInPlace( @@ -3475,7 +3475,7 @@ SplatOp::reifyResultShapes(OpBuilder &builder, OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { auto constOperand = adaptor.getInput(); - if (!constOperand.isa_and_nonnull()) + if (!isa_and_nonnull(constOperand)) return {}; // Do not fold if the splat is not statically shaped @@ -4307,7 +4307,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp, /// unpack(destinationStyleOp(x)) -> unpack(x) if (auto dstStyleOp = unPackOp.getDest().getDefiningOp()) { - auto destValue = unPackOp.getDest().cast(); + auto destValue = cast(unPackOp.getDest()); Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()]; rewriter.modifyOpInPlace(unPackOp, [&]() { unPackOp.setDpsInitOperand(0, newDest); }); diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp index 06a441dbeaf15..137156fe1a73e 100644 --- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp @@ -32,7 +32,7 @@ namespace { struct MatMulOpSharding : public ShardingInterface::ExternalModel { SmallVector getLoopIteratorTypes(Operation *op) const { - auto tensorType = 
op->getResult(0).getType().dyn_cast(); + auto tensorType = dyn_cast(op->getResult(0).getType()); if (!tensorType) return {}; @@ -48,7 +48,7 @@ struct MatMulOpSharding } SmallVector getIndexingMaps(Operation *op) const { - auto tensorType = op->getResult(0).getType().dyn_cast(); + auto tensorType = dyn_cast(op->getResult(0).getType()); if (!tensorType) return {}; MLIRContext *ctx = op->getContext(); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index c8bf4c526b239..c139d5f60024c 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -285,7 +285,7 @@ struct ClampIsNoOp : public OpRewritePattern { return failure(); } - if (inputElementType.isa()) { + if (isa(inputElementType)) { // Unlike integer types, floating point types can represent infinity. auto minClamp = op.getMinFp(); auto maxClamp = op.getMaxFp(); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index e06ac9a27ae4c..10e6016a1ed43 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -168,7 +168,7 @@ ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr, return parser.emitError(parser.getCurrentLocation()) << "expected attribute"; } - if (auto typedAttr = attr.dyn_cast()) { + if (auto typedAttr = dyn_cast(attr)) { typeAttr = TypeAttr::get(typedAttr.getType()); } return success(); @@ -186,7 +186,7 @@ ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr, void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, Attribute attr) { bool needsSpace = false; - auto typedAttr = attr.dyn_cast_or_null(); + auto typedAttr = dyn_cast_or_null(attr); if (!typedAttr || typedAttr.getType() != type.getValue()) { p << ": "; p.printAttribute(type); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp index 6575b39fd45a1..6eef2c5018d6d 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp @@ -371,7 +371,7 @@ struct ReduceConstantOptimization : public OpRewritePattern { auto reductionAxis = op.getAxis(); const auto denseElementsAttr = constOp.getValue(); const auto shapedOldElementsValues = - denseElementsAttr.getType().cast(); + cast(denseElementsAttr.getType()); if (!llvm::isa(shapedOldElementsValues.getElementType())) return rewriter.notifyMatchFailure( diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 74ef6381f3d70..c99f62d5ae112 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -357,7 +357,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { bool levelCheckTransposeConv2d(Operation *op) { if (auto transpose = dyn_cast(op)) { if (ShapedType filterType = - transpose.getFilter().getType().dyn_cast()) { + dyn_cast(transpose.getFilter().getType())) { auto shape = filterType.getShape(); assert(shape.size() == 4); // level check kernel sizes for kH and KW diff --git a/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp b/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp index 4b3e28e4313c6..94d4a96a07ad6 100644 --- a/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp @@ 
-21,16 +21,15 @@ DiagnosedSilenceableFailure transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { - if (getAt().getType().isa()) { + if (isa(getAt().getType())) { auto payload = state.getPayloadOps(getAt()); for (Operation *op : payload) op->emitRemark() << getMessage(); return DiagnosedSilenceableFailure::success(); } - assert( - getAt().getType().isa() && - "unhandled kind of transform type"); + assert(isa(getAt().getType()) && + "unhandled kind of transform type"); auto describeValue = [](Diagnostic &os, Value value) { os << "value handle points to "; diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 53f958caa0bdb..7a5a697470058 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -1615,7 +1615,7 @@ transform::GetTypeOp::apply(transform::TransformRewriter &rewriter, } params.push_back(TypeAttr::get(type)); } - results.setParams(getResult().cast(), params); + results.setParams(cast(getResult()), params); return DiagnosedSilenceableFailure::success(); } @@ -2217,14 +2217,14 @@ transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter, llvm_unreachable("unknown kind of transform dialect type"); return 0; }); - results.setParams(getNum().cast(), + results.setParams(cast(getNum()), rewriter.getI64IntegerAttr(numAssociations)); return DiagnosedSilenceableFailure::success(); } LogicalResult transform::NumAssociationsOp::verify() { // Verify that the result type accepts an i64 attribute as payload. - auto resultType = getNum().getType().cast(); + auto resultType = cast(getNum().getType()); return resultType .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)}) .checkAndReport(); diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp index 8d9f105d1c5db..9a24c2baebabb 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp @@ -44,7 +44,7 @@ DiagnosedSilenceableFailure transform::AffineMapParamType::checkPayload(Location loc, ArrayRef payload) const { for (Attribute attr : payload) { - if (!attr.isa()) { + if (!mlir::isa(attr)) { return emitSilenceableError(loc) << "expected affine map attribute, got " << attr; } @@ -144,7 +144,7 @@ DiagnosedSilenceableFailure transform::TypeParamType::checkPayload(Location loc, ArrayRef payload) const { for (Attribute attr : payload) { - if (!attr.isa()) { + if (!mlir::isa(attr)) { return emitSilenceableError(loc) << "expected type attribute, got " << attr; } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 3e6425879cc67..d10a31941db4f 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6169,7 +6169,7 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns( OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { auto constOperand = adaptor.getInput(); - if (!constOperand.isa_and_nonnull()) + if (!isa_and_nonnull(constOperand)) return {}; // SplatElementsAttr::get treats single value for second arg as being a splat. 
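Reviewer note: the two SplatOp::fold hunks above touch a nullable attribute, since the fold adaptor returns a null Attribute when the operand is not constant; that is why they use the isa_and_nonnull spelling rather than plain isa. A small sketch under the same assumptions (the helper is hypothetical, not part of this patch):

#include "mlir/IR/BuiltinAttributes.h" // Attribute, SplatElementsAttr

using namespace mlir;

// True when `constOperand` is a non-null splat constant. Safe to call with a
// null Attribute, which is what makes isa_and_nonnull the right choice here.
static bool isSplatConstant(Attribute constOperand) {
  // Deprecated member form: constOperand.isa_and_nonnull<SplatElementsAttr>()
  return isa_and_nonnull<SplatElementsAttr>(constOperand);
}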
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 0693aa596cb28..b30b43d70bf0f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -57,7 +57,7 @@ static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec, Value broadcasted = extendVectorRank(builder, loc, vec, addedRank); SmallVector permutation; for (int64_t i = addedRank, - e = broadcasted.getType().cast().getRank(); + e = cast(broadcasted.getType()).getRank(); i < e; ++i) permutation.push_back(i); for (int64_t i = 0; i < addedRank; ++i) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index 8d733c5a8849b..7ed3dea42b771 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -403,7 +403,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, // Such transposes do not materially effect the underlying vector and can // be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32> bool transposeNonOuterUnitDims = false; - auto operandShape = operands[it.index()].getType().cast(); + auto operandShape = cast(operands[it.index()].getType()); for (auto [index, dim] : llvm::enumerate(ArrayRef(perm).drop_back(1))) { if (dim != static_cast(index) && diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index dc6f126aae4c8..d24721f3defa6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -63,7 +63,7 @@ static FailureOr getCompressedMaskOp(OpBuilder &rewriter, // new mask index) only happens on the last dimension of the vectors. Operation *newMask = nullptr; SmallVector shape( - maskOp->getResultTypes()[0].cast().getShape()); + cast(maskOp->getResultTypes()[0]).getShape()); shape.back() = numElements; auto newMaskType = VectorType::get(shape, rewriter.getI1Type()); if (createMaskOp) { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp index b844c2bfa837c..ee622e886f618 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -171,7 +171,7 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) { /// is first inserted, followed by a `memref.cast`. static Value castToCompatibleMemRefType(OpBuilder &b, Value memref, MemRefType compatibleMemRefType) { - MemRefType sourceType = memref.getType().cast(); + MemRefType sourceType = cast(memref.getType()); Value res = memref; if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) { sourceType = MemRefType::get( diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 530c50ef74f7a..23c5749c2309d 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -127,7 +127,7 @@ LogicalResult CreateNdDescOp::verify() { // check source type matches the rank if it is a memref. // It also should have the same ElementType as TensorDesc. 
- auto memrefTy = getSourceType().dyn_cast(); + auto memrefTy = dyn_cast(getSourceType()); if (memrefTy) { invalidRank |= (memrefTy.getRank() != rank); invalidElemTy |= memrefTy.getElementType() != getElementType(); diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index 6cdc2682753fc..411ac656e4afb 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -711,7 +711,7 @@ AffineMap mlir::foldAttributesIntoMap(Builder &b, AffineMap map, for (int64_t i = 0; i < map.getNumDims(); ++i) { if (auto attr = operands[i].dyn_cast()) { dimReplacements.push_back( - b.getAffineConstantExpr(attr.cast().getInt())); + b.getAffineConstantExpr(cast(attr).getInt())); } else { dimReplacements.push_back(b.getAffineDimExpr(numDims++)); remainingValues.push_back(operands[i].get()); @@ -721,7 +721,7 @@ AffineMap mlir::foldAttributesIntoMap(Builder &b, AffineMap map, for (int64_t i = 0; i < map.getNumSymbols(); ++i) { if (auto attr = operands[i + map.getNumDims()].dyn_cast()) { symReplacements.push_back( - b.getAffineConstantExpr(attr.cast().getInt())); + b.getAffineConstantExpr(cast(attr).getInt())); } else { symReplacements.push_back(b.getAffineSymbolExpr(numSymbols++)); remainingValues.push_back(operands[i + map.getNumDims()].get()); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index db903d540761b..0feb078db297d 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1154,7 +1154,7 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultRank(Operation *op) { // delegate function that returns rank of shaped type with known rank auto getRank = [](const Type type) { - return type.cast().getRank(); + return cast(type).getRank(); }; auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin()) diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index e89ff9209b034..ebcdbc02aadd0 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -2489,7 +2489,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder, auto addDevInfos = [&, fail](auto devOperands, auto devOpType) -> void { for (const auto &devOp : devOperands) { // TODO: Only LLVMPointerTypes are handled. 
- if (!devOp.getType().template isa()) + if (!isa(devOp.getType())) return fail(); llvm::Value *mapOpValue = moduleTranslation.lookupValue(devOp); @@ -3083,10 +3083,9 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, std::vector generatedRefs; std::vector targetTriple; - auto targetTripleAttr = - op->getParentOfType() - ->getAttr(LLVM::LLVMDialect::getTargetTripleAttrName()) - .dyn_cast_or_null(); + auto targetTripleAttr = dyn_cast_or_null( + op->getParentOfType()->getAttr( + LLVM::LLVMDialect::getTargetTripleAttrName())); if (targetTripleAttr) targetTriple.emplace_back(targetTripleAttr.data()); @@ -3328,7 +3327,7 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation( attribute.getName()) .Case("omp.is_target_device", [&](Attribute attr) { - if (auto deviceAttr = attr.dyn_cast()) { + if (auto deviceAttr = dyn_cast(attr)) { llvm::OpenMPIRBuilderConfig &config = moduleTranslation.getOpenMPBuilder()->Config; config.setIsTargetDevice(deviceAttr.getValue()); @@ -3338,7 +3337,7 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation( }) .Case("omp.is_gpu", [&](Attribute attr) { - if (auto gpuAttr = attr.dyn_cast()) { + if (auto gpuAttr = dyn_cast(attr)) { llvm::OpenMPIRBuilderConfig &config = moduleTranslation.getOpenMPBuilder()->Config; config.setIsGPU(gpuAttr.getValue()); @@ -3348,7 +3347,7 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation( }) .Case("omp.host_ir_filepath", [&](Attribute attr) { - if (auto filepathAttr = attr.dyn_cast()) { + if (auto filepathAttr = dyn_cast(attr)) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue()); @@ -3358,13 +3357,13 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation( }) .Case("omp.flags", [&](Attribute attr) { - if (auto rtlAttr = attr.dyn_cast()) + if (auto rtlAttr = dyn_cast(attr)) return convertFlagsAttr(op, rtlAttr, moduleTranslation); return failure(); }) .Case("omp.version", [&](Attribute attr) { - if (auto versionAttr = attr.dyn_cast()) { + if (auto versionAttr = dyn_cast(attr)) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp", @@ -3376,15 +3375,14 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation( .Case("omp.declare_target", [&](Attribute attr) { if (auto declareTargetAttr = - attr.dyn_cast()) + dyn_cast(attr)) return convertDeclareTargetAttr(op, declareTargetAttr, moduleTranslation); return failure(); }) .Case("omp.requires", [&](Attribute attr) { - if (auto requiresAttr = - attr.dyn_cast()) { + if (auto requiresAttr = dyn_cast(attr)) { using Requires = omp::ClauseRequires; Requires flags = requiresAttr.getValue(); llvm::OpenMPIRBuilderConfig &config = diff --git a/mlir/lib/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.cpp index 8212725b5a58b..b78b002d32292 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.cpp @@ -29,8 +29,8 @@ using mlir::LLVM::detail::createIntrinsicCall; /// option around. 
static llvm::Type *getXlenType(Attribute opcodeAttr, LLVM::ModuleTranslation &moduleTranslation) { - auto intAttr = opcodeAttr.cast(); - unsigned xlenWidth = intAttr.getType().cast().getWidth(); + auto intAttr = cast(opcodeAttr); + unsigned xlenWidth = cast(intAttr.getType()).getWidth(); return llvm::Type::getIntNTy(moduleTranslation.getLLVMContext(), xlenWidth); } diff --git a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp index c8bee817213d8..e17fe12b9088b 100644 --- a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp +++ b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp @@ -25,7 +25,7 @@ namespace { /// according to LLVM's encoding: /// https://lists.llvm.org/pipermail/llvm-dev/2020-October/145850.html static std::pair legalizeVectorType(const Type &type) { - VectorType vt = type.cast(); + VectorType vt = cast(type); // To simplify test pass, avoid multi-dimensional vectors. if (!vt || vt.getRank() != 1) return {0, nullptr}; @@ -39,7 +39,7 @@ static std::pair legalizeVectorType(const Type &type) { sew = 32; else if (eltTy.isF64()) sew = 64; - else if (auto intTy = eltTy.dyn_cast()) + else if (auto intTy = dyn_cast(eltTy)) sew = intTy.getWidth(); else return {0, nullptr}; diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp index 9b3082a819224..5e3918f79d184 100644 --- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp +++ b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp @@ -67,12 +67,11 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern { ImplicitLocOpBuilder builder(op->getLoc(), rewriter); ShapedType sourceShardShape = shardShapedType(op.getResult().getType(), mesh, op.getShard()); - TypedValue sourceShard = + TypedValue sourceShard = cast>( builder .create(sourceShardShape, op.getOperand()) - ->getResult(0) - .cast>(); + ->getResult(0)); TypedValue targetShard = reshard(builder, mesh, op, targetShardOp, sourceShard); Value newTargetUnsharded = diff --git a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp index 2dd99c67c1439..fa093cafcb0dc 100644 --- a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp +++ b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp @@ -61,7 +61,7 @@ LogicalResult TestDialectLLVMIRTranslationInterface::amendOperation( } bool createSymbol = false; - if (auto boolAttr = attr.dyn_cast()) + if (auto boolAttr = dyn_cast(attr)) createSymbol = boolAttr.getValue(); if (createSymbol) { diff --git a/mlir/test/lib/IR/TestAffineWalk.cpp b/mlir/test/lib/IR/TestAffineWalk.cpp index 8361b48ce4285..e8b836888b459 100644 --- a/mlir/test/lib/IR/TestAffineWalk.cpp +++ b/mlir/test/lib/IR/TestAffineWalk.cpp @@ -44,7 +44,7 @@ void TestAffineWalk::runOnOperation() { // Test whether the walk is being correctly interrupted. 
m.walk([](Operation *op) { for (NamedAttribute attr : op->getAttrs()) { - auto mapAttr = attr.getValue().dyn_cast(); + auto mapAttr = dyn_cast(attr.getValue()); if (!mapAttr) return; checkMod(mapAttr.getAffineMap(), op->getLoc()); diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp index 71ed30bfbe34c..93c4bcfe1424e 100644 --- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp +++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp @@ -51,7 +51,7 @@ struct TestElementsAttrInterface InFlightDiagnostic diag = op->emitError() << "Test iterating `" << type << "`: "; - if (!attr.getElementType().isa()) { + if (!isa(attr.getElementType())) { diag << "expected element type to be an integer type"; return; } diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp index 56af3c15b905f..77aa30f847dcd 100644 --- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp +++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp @@ -61,7 +61,7 @@ static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter, PDLResultList &results, ArrayRef args) { auto *op = args[0].cast(); - int numTypes = args[1].cast().cast().getInt(); + int numTypes = cast(args[1].cast()).getInt(); if (op->getName().getStringRef() == "test.success_op") { SmallVector types;
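Reviewer note: several of the later hunks combine the free-function cast with a lookup that may return null, e.g. the dyn_cast_or_null uses in the OpenMP translation and TOSA printer changes above. A short sketch of that shape; the attribute name "example.count" and the helper are illustrative only:

#include "mlir/IR/BuiltinAttributes.h" // IntegerAttr
#include "mlir/IR/Operation.h"

using namespace mlir;

// Reads an optional integer attribute from `op`, defaulting to 0 when the
// attribute is absent or has a different type.
static int64_t countOrZero(Operation *op) {
  // Deprecated member form:
  //   op->getAttr("example.count").dyn_cast_or_null<IntegerAttr>()
  if (auto countAttr =
          dyn_cast_or_null<IntegerAttr>(op->getAttr("example.count")))
    return countAttr.getInt();
  return 0;
}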