Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][NFC] Simplify type checks with isa predicates #87183

Merged
merged 1 commit into from
Apr 1, 2024

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Mar 31, 2024

For more context on isa predicates, see: #83753.

For more context on isa predicates, see: llvm#83753.
@llvmbot
Copy link
Member

llvmbot commented Mar 31, 2024

@llvm/pr-subscribers-mlir-openacc
@llvm/pr-subscribers-mlir-affine
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-bufferization
@llvm/pr-subscribers-mlir-shape
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-llvm

Author: Jakub Kuderski (kuhar)

Changes

For more context on isa predicates, see: #83753.


Patch is 31.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87183.diff

28 Files Affected:

  • (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp (+1-2)
  • (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+3-7)
  • (modified) mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp (+1-2)
  • (modified) mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp (+2-3)
  • (modified) mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp (+1-2)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+5-3)
  • (modified) mlir/lib/Dialect/EmitC/IR/EmitC.cpp (+1-2)
  • (modified) mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp (+8-16)
  • (modified) mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp (+2-3)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+2-4)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp (+1-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+4-5)
  • (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp (+1-2)
  • (modified) mlir/lib/Dialect/Shape/IR/Shape.cpp (+2-3)
  • (modified) mlir/lib/Dialect/Traits.cpp (+7-8)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+4-5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+19-22)
  • (modified) mlir/lib/IR/AffineMap.cpp (+1-3)
  • (modified) mlir/lib/IR/Operation.cpp (+1-3)
  • (modified) mlir/lib/TableGen/Class.cpp (+1-3)
  • (modified) mlir/lib/Target/Cpp/TranslateToCpp.cpp (+2-3)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+4-4)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+3-3)
  • (modified) mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp (+1-2)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 73d418cb841327..993c09b03c0fde 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -545,8 +545,7 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
                                       ConversionPatternRewriter &rewriter,
                                       const LLVMTypeConverter &converter) {
   TypeRange operandTypes(operands);
-  if (llvm::none_of(operandTypes,
-                    [](Type type) { return isa<VectorType>(type); })) {
+  if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
     return rewriter.notifyMatchFailure(op, "expected vector operand");
   }
   if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 85fb8a539912f7..399c0450824ee5 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -202,9 +202,7 @@ template <typename ExtOpTy>
 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
   if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
     return false;
-  return llvm::all_of(extOp->getUsers(), [](Operation *user) {
-    return isa<vector::ContractionOp>(user);
-  });
+  return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
 }
 
 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
@@ -345,15 +343,13 @@ getSliceContract(Operation *op,
 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
                                              bool useNvGpu) {
   auto hasVectorDest = [](Operation *op) {
-    return llvm::any_of(op->getResultTypes(),
-                        [](Type t) { return isa<VectorType>(t); });
+    return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
   };
   BackwardSliceOptions backwardSliceOptions;
   backwardSliceOptions.filter = hasVectorDest;
 
   auto hasVectorSrc = [](Operation *op) {
-    return llvm::any_of(op->getOperandTypes(),
-                        [](Type t) { return isa<VectorType>(t); });
+    return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
   };
   ForwardSliceOptions forwardSliceOptions;
   forwardSliceOptions.filter = hasVectorSrc;
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index 61244921bc38ac..69b3d41e17c2d4 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -136,8 +136,7 @@ static bool isLocallyDefined(Value v, Operation *enclosingOp) {
 
 bool mlir::affine::isLoopMemoryParallel(AffineForOp forOp) {
   // Any memref-typed iteration arguments are treated as serializing.
-  if (llvm::any_of(forOp.getResultTypes(),
-                   [](Type type) { return isa<BaseMemRefType>(type); }))
+  if (llvm::any_of(forOp.getResultTypes(), llvm::IsaPred<BaseMemRefType>))
     return false;
 
   // Collect all load and store ops in loop nest rooted at 'forOp'.
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index 46c7871f40232f..71e9648a5e00fa 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -609,9 +609,8 @@ makePattern(const DenseSet<Operation *> &parallelLoops, int vectorRank,
 }
 
 static NestedPattern &vectorTransferPattern() {
-  static auto pattern = affine::matcher::Op([](Operation &op) {
-    return isa<vector::TransferReadOp, vector::TransferWriteOp>(op);
-  });
+  static auto pattern = affine::matcher::Op(
+      llvm::IsaPred<vector::TransferReadOp, vector::TransferWriteOp>);
   return pattern;
 }
 
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index fb45528ad5e7d1..84ae4b52dcf4e8 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -211,8 +211,7 @@ static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
   unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);
 
   // Return common loop depth for loads if there are no store ops.
-  if (all_of(targetDstOps,
-             [&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
+  if (all_of(targetDstOps, llvm::IsaPred<AffineReadOpInterface>))
     return loopDepth;
 
   // Check dependences on all pairs of ops in 'targetDstOps' and store the
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 4cdbbf35dc876b..053ea7935260a2 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -326,7 +326,7 @@ struct FuncOpInterface
   static bool supportsUnstructuredControlFlow() { return true; }
 
   bool hasTensorSemantics(Operation *op) const {
-    auto isaTensor = [](Type type) { return isa<TensorType>(type); };
+    auto isaTensor = llvm::IsaPred<TensorType>;
 
     // A function has tensor semantics if it has tensor arguments/results.
     auto funcOp = cast<FuncOp>(op);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 33feea0b956ca0..0a4072605c265f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -67,6 +67,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
 
 using namespace mlir;
@@ -277,9 +278,10 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
 
 /// Return "true" if the given function signature has tensor semantics.
 static bool hasTensorSignature(func::FuncOp funcOp) {
-  auto isaTensor = [](Type t) { return isa<TensorType>(t); };
-  return llvm::any_of(funcOp.getFunctionType().getInputs(), isaTensor) ||
-         llvm::any_of(funcOp.getFunctionType().getResults(), isaTensor);
+  return llvm::any_of(funcOp.getFunctionType().getInputs(),
+                      llvm::IsaPred<TensorType>) ||
+         llvm::any_of(funcOp.getFunctionType().getResults(),
+                      llvm::IsaPred<TensorType>);
 }
 
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index ab5c418e844fbf..f4a9dc3ca509c8 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -224,8 +224,7 @@ LogicalResult emitc::CallOpaqueOp::verify() {
     }
   }
 
-  if (llvm::any_of(getResultTypes(),
-                   [](Type type) { return isa<ArrayType>(type); })) {
+  if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
     return emitOpError() << "cannot return array type";
   }
 
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index fc3a4375694588..b584f63f16e0aa 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -296,22 +296,14 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                                  "scf.forall op requires a mapping attribute");
   }
 
-  bool hasBlockMapping =
-      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
-        return isa<GPUBlockMappingAttr>(attr);
-      });
-  bool hasWarpgroupMapping =
-      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
-        return isa<GPUWarpgroupMappingAttr>(attr);
-      });
-  bool hasWarpMapping =
-      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
-        return isa<GPUWarpMappingAttr>(attr);
-      });
-  bool hasThreadMapping =
-      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
-        return isa<GPUThreadMappingAttr>(attr);
-      });
+  bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
+                                      llvm::IsaPred<GPUBlockMappingAttr>);
+  bool hasWarpgroupMapping = llvm::any_of(
+      forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
+  bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
+                                     llvm::IsaPred<GPUWarpMappingAttr>);
+  bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
+                                       llvm::IsaPred<GPUThreadMappingAttr>);
   int64_t countMappingTypes = 0;
   countMappingTypes += hasBlockMapping ? 1 : 0;
   countMappingTypes += hasWarpgroupMapping ? 1 : 0;
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
index 40903f199afddd..b2fa3a99c53fc3 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -232,9 +232,8 @@ struct GpuAsyncRegionPass::DeferWaitCallback {
   // control flow code.
   static bool areAllUsersExecuteOrAwait(Value token) {
     return !token.use_empty() &&
-           llvm::all_of(token.getUsers(), [](Operation *user) {
-             return isa<async::ExecuteOp, async::AwaitOp>(user);
-           });
+           llvm::all_of(token.getUsers(),
+                        llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);
   }
 
   // Add the `asyncToken` as dependency as needed after `op`.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 3ba6ac6ccc8142..e5c19a916392e1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2786,10 +2786,8 @@ LogicalResult LLVM::BitcastOp::verify() {
   if (!resultType)
     return success();
 
-  auto isVector = [](Type type) {
-    return llvm::isa<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>(
-        type);
-  };
+  auto isVector =
+      llvm::IsaPred<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>;
 
   // Due to bitcast requiring both operands to be of the same size, it is not
   // possible for only one of the two to be a pointer of vectors.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6954eee93efd14..2d7219fef87c64 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -28,6 +28,7 @@
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
@@ -119,8 +120,7 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                    TypeRange inputTypes, TypeRange outputTypes,
                                    ArrayRef<NamedAttribute> attrs,
                                    RegionBuilderFn regionBuilder) {
-  assert(llvm::all_of(outputTypes,
-                      [](Type t) { return llvm::isa<ShapedType>(t); }));
+  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
 
   SmallVector<Type, 8> argTypes;
   SmallVector<Location, 8> argLocs;
@@ -162,7 +162,7 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
       resultTensorTypes.value_or(TypeRange());
   if (!resultTensorTypes)
     copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
-            [](Type type) { return llvm::isa<RankedTensorType>(type); });
+            llvm::IsaPred<RankedTensorType>);
 
   state.addOperands(inputs);
   state.addOperands(outputs);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index 5508aaf9d87537..28d6752fc2d388 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -27,8 +27,7 @@ static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
 
   // TODO: The conversion pattern can be made to work for `any_of` here, but
   // it's more complex as it requires tracking which operands are scalars.
-  return llvm::all_of(op->getOperandTypes(),
-                      [](Type type) { return isa<RankedTensorType>(type); });
+  return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
 }
 
 /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c74ab1e6448bec..25785653a71675 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -3537,15 +3537,14 @@ struct Conv1DGenerator
   // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
   // must be block arguments or extension of block arguments.
   bool setOperKind(Operation *reduceOp) {
-    int numBlockArguments = llvm::count_if(
-        reduceOp->getOperands(), [](Value v) { return isa<BlockArgument>(v); });
+    int numBlockArguments =
+        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
     switch (numBlockArguments) {
     case 1: {
       // Will be convolution if feeder is a MulOp.
       // Otherwise, if it can be pooling.
-      auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) {
-        return !isa<BlockArgument>(v);
-      });
+      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
+                                         llvm::IsaPred<BlockArgument>);
       Operation *feedOp = (*feedValIt).getDefiningOp();
       if (isCastOfBlockArgument(feedOp)) {
         oper = Pool;
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index c09a3403f9a3e3..9ba96e4be7d1fc 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -457,7 +457,7 @@ static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
 }
 
 static bool isComputeOperation(Operation *op) {
-  return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
+  return isa<acc::ParallelOp, acc::LoopOp>(op);
 }
 
 namespace {
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index d532d466334a56..2ff3efdc96a7f8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -125,8 +125,7 @@ LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
   if (getMatrixOperands()) {
     Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(),
                            typeC.getElementType()};
-    if (!llvm::all_of(elementTypes,
-                      [](Type ty) { return isa<IntegerType>(ty); })) {
+    if (!llvm::all_of(elementTypes, llvm::IsaPred<IntegerType>)) {
       return emitOpError("Matrix Operands require all matrix element types to "
                          "be Integer Types");
     }
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index f5a3717f815de5..58c3f4c334577c 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -65,9 +65,8 @@ LogicalResult shape::getShapeVec(Value input,
 }
 
 static bool isErrorPropagationPossible(TypeRange operandTypes) {
-  return llvm::any_of(operandTypes, [](Type ty) {
-    return llvm::isa<SizeType, ShapeType, ValueShapeType>(ty);
-  });
+  return llvm::any_of(operandTypes,
+                      llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
 }
 
 static LogicalResult verifySizeOrIndexOp(Operation *op) {
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index d4e0f8a3137053..2efc157ce79617 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -188,9 +188,8 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
 /// Returns a tuple corresponding to whether range has tensor or vector type.
 template <typename iterator_range>
 static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
-  return std::make_tuple(
-      llvm::any_of(types, [](Type t) { return isa<TensorType>(t); }),
-      llvm::any_of(types, [](Type t) { return isa<VectorType>(t); }));
+  return {llvm::any_of(types, llvm::IsaPred<TensorType>),
+          llvm::any_of(types, llvm::IsaPred<VectorType>)};
 }
 
 static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
@@ -202,7 +201,7 @@ static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
   };
   if (inferred.size() != existing.size())
     return false;
-  for (auto [inferredDim, existingDim] : llvm::zip(inferred, existing))
+  for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing))
     if (!isCompatible(inferredDim, existingDim))
       return false;
   return true;
@@ -238,8 +237,8 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
        std::get<1>(resultsHasTensorVectorType)))
     return op->emitError("cannot broadcast vector with tensor");
 
-  auto rankedOperands = make_filter_range(
-      op->getOperandTypes(), [](Type t) { return isa<RankedTensorType>(t); });
+  auto rankedOperands =
+      make_filter_range(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
 
   // If all operands are unranked, then all result shapes are possible.
   if (rankedOperands.empty())
@@ -257,8 +256,8 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
       return op->emitOpError("operands don't have broadcast-compatible shapes");
   }
 
-  auto rankedResults = make_filter_range(
-      op->getResultTypes(), [](Type t) { return isa<RankedTensorType>(t); });
+  auto rankedResults =
+      make_filter_range(op->getResultTypes(), llvm::IsaPred<RankedTensorType>);
 
   // If all of the results are unranked then no further verification.
   if (rankedResults.empty())
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 578b2492bbab46..c8d06ba157b904 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -819,7 +819,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   assert(outputs.size() == 1 && "expected one output");
   return llvm::all_of(
       std::initializer_list<Type>{inputs.front(), outputs.front()},
-      [](Type ty) { return isa<transform::TransformHandleTypeInterface>(ty); });
+      llvm::IsaPred<transform::TransformHandleTypeInterface>);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e566bfacf37984..3e6425879cc67f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -898,13 +898,12 @@ static LogicalResult verifyOutputShape(
 
     AffineMap resMap = op.getIndexingMapsArray()[2];
     auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
-                                     /*symCount=*/0, extents, ctx);
+                                     /*symbolCount=*/0, extents, ctx);
     // Compose the resMap with the extentsMap, which is a constant map.
     AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
-    assert(
-        llvm::all_of(expectedMap.getResults(),
-                     [](AffineExpr e) { return isa<AffineConstantExpr>(e); }) &&
-        "expected constant extent along all dimensions.");
+    assert(llvm::all_of(expectedMap.getResults(),
+                        llvm::IsaPred<AffineConstantExpr>) &&
+           "expected constant extent along all dimensions.");
     // Extract the expe...
[truncated]

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be great if there was a clang-tidy check for this!

@kuhar kuhar merged commit 971b852 into llvm:main Apr 1, 2024
19 checks passed
mgehre-amd pushed a commit to Xilinx/llvm-project that referenced this pull request Apr 26, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants