[mlir][NFC] Simplify type checks with isa predicates #87183

Merged 1 commit on Apr 1, 2024
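This PR is a mechanical cleanup: call sites that wrapped `isa<...>` in a single-use lambda now pass `llvm::IsaPred<...>` (from `llvm/Support/Casting.h`) directly to STL-extra algorithms such as `llvm::any_of`, `llvm::all_of`, `llvm::none_of`, and `llvm::count_if`. `IsaPred` is a constexpr function object whose call operator forwards to `isa`, so it accepts multiple candidate types just like `isa` itself. A minimal sketch of the idea, with `IsaPredSketch` as a hypothetical stand-in for the real definition in Casting.h:

    // Illustration only; the actual LLVM definition may differ in detail.
    template <typename... Types> struct IsaPredSketch {
      template <typename T> bool operator()(const T &val) const {
        return llvm::isa<Types...>(val); // true if val is any of Types...
      }
    };

    // Before: a one-off lambda at every call site.
    //   llvm::any_of(types, [](Type t) { return isa<VectorType>(t); });
    // After: the shared, variadic predicate object.
    //   llvm::any_of(types, llvm::IsaPred<VectorType>);
    //   llvm::all_of(users, llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);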
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -545,8 +545,7 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
                                       ConversionPatternRewriter &rewriter,
                                       const LLVMTypeConverter &converter) {
   TypeRange operandTypes(operands);
-  if (llvm::none_of(operandTypes,
-                    [](Type type) { return isa<VectorType>(type); })) {
+  if (llvm::none_of(operandTypes, llvm::IsaPred<VectorType>)) {
     return rewriter.notifyMatchFailure(op, "expected vector operand");
   }
   if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
10 changes: 3 additions & 7 deletions mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -202,9 +202,7 @@ template <typename ExtOpTy>
 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
   if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
     return false;
-  return llvm::all_of(extOp->getUsers(), [](Operation *user) {
-    return isa<vector::ContractionOp>(user);
-  });
+  return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
 }

 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
@@ -345,15 +343,13 @@ getSliceContract(Operation *op,
 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
                                              bool useNvGpu) {
   auto hasVectorDest = [](Operation *op) {
-    return llvm::any_of(op->getResultTypes(),
-                        [](Type t) { return isa<VectorType>(t); });
+    return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
   };
   BackwardSliceOptions backwardSliceOptions;
   backwardSliceOptions.filter = hasVectorDest;

   auto hasVectorSrc = [](Operation *op) {
-    return llvm::any_of(op->getOperandTypes(),
-                        [](Type t) { return isa<VectorType>(t); });
+    return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
   };
   ForwardSliceOptions forwardSliceOptions;
   forwardSliceOptions.filter = hasVectorSrc;
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -136,8 +136,7 @@ static bool isLocallyDefined(Value v, Operation *enclosingOp) {

 bool mlir::affine::isLoopMemoryParallel(AffineForOp forOp) {
   // Any memref-typed iteration arguments are treated as serializing.
-  if (llvm::any_of(forOp.getResultTypes(),
-                   [](Type type) { return isa<BaseMemRefType>(type); }))
+  if (llvm::any_of(forOp.getResultTypes(), llvm::IsaPred<BaseMemRefType>))
     return false;

   // Collect all load and store ops in loop nest rooted at 'forOp'.
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -609,9 +609,8 @@ makePattern(const DenseSet<Operation *> &parallelLoops, int vectorRank,
 }

 static NestedPattern &vectorTransferPattern() {
-  static auto pattern = affine::matcher::Op([](Operation &op) {
-    return isa<vector::TransferReadOp, vector::TransferWriteOp>(op);
-  });
+  static auto pattern = affine::matcher::Op(
+      llvm::IsaPred<vector::TransferReadOp, vector::TransferWriteOp>);
   return pattern;
 }

3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -211,8 +211,7 @@ static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
   unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);

   // Return common loop depth for loads if there are no store ops.
-  if (all_of(targetDstOps,
-             [&](Operation *op) { return isa<AffineReadOpInterface>(op); }))
+  if (all_of(targetDstOps, llvm::IsaPred<AffineReadOpInterface>))
     return loopDepth;

   // Check dependences on all pairs of ops in 'targetDstOps' and store the
@@ -326,7 +326,7 @@ struct FuncOpInterface
   static bool supportsUnstructuredControlFlow() { return true; }

   bool hasTensorSemantics(Operation *op) const {
-    auto isaTensor = [](Type type) { return isa<TensorType>(type); };
+    auto isaTensor = llvm::IsaPred<TensorType>;

     // A function has tensor semantics if it has tensor arguments/results.
     auto funcOp = cast<FuncOp>(op);
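Note that `llvm::IsaPred` is a value, not a function template, so it can be stored in a local such as `isaTensor` above and reused across calls. A small sketch of the pattern, assuming a `func::FuncOp funcOp` in scope:

    // The predicate is an ordinary callable value: name it once, use it twice.
    auto isaTensor = llvm::IsaPred<TensorType>;
    bool hasTensorArg = llvm::any_of(funcOp.getArgumentTypes(), isaTensor);
    bool hasTensorRes = llvm::any_of(funcOp.getResultTypes(), isaTensor);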
@@ -67,6 +67,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"

 using namespace mlir;
@@ -277,9 +278,10 @@ static void equivalenceAnalysis(func::FuncOp funcOp,

 /// Return "true" if the given function signature has tensor semantics.
 static bool hasTensorSignature(func::FuncOp funcOp) {
-  auto isaTensor = [](Type t) { return isa<TensorType>(t); };
-  return llvm::any_of(funcOp.getFunctionType().getInputs(), isaTensor) ||
-         llvm::any_of(funcOp.getFunctionType().getResults(), isaTensor);
+  return llvm::any_of(funcOp.getFunctionType().getInputs(),
+                      llvm::IsaPred<TensorType>) ||
+         llvm::any_of(funcOp.getFunctionType().getResults(),
+                      llvm::IsaPred<TensorType>);
 }

 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -224,8 +224,7 @@ LogicalResult emitc::CallOpaqueOp::verify() {
     }
   }

-  if (llvm::any_of(getResultTypes(),
-                   [](Type type) { return isa<ArrayType>(type); })) {
+  if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
     return emitOpError() << "cannot return array type";
   }

24 changes: 8 additions & 16 deletions mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -296,22 +296,14 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                                       "scf.forall op requires a mapping attribute");
   }

-  bool hasBlockMapping =
-      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
-        return isa<GPUBlockMappingAttr>(attr);
-      });
-  bool hasWarpgroupMapping =
-      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
-        return isa<GPUWarpgroupMappingAttr>(attr);
-      });
-  bool hasWarpMapping =
-      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
-        return isa<GPUWarpMappingAttr>(attr);
-      });
-  bool hasThreadMapping =
-      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
-        return isa<GPUThreadMappingAttr>(attr);
-      });
+  bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
+                                      llvm::IsaPred<GPUBlockMappingAttr>);
+  bool hasWarpgroupMapping = llvm::any_of(
+      forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
+  bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
+                                     llvm::IsaPred<GPUWarpMappingAttr>);
+  bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
+                                       llvm::IsaPred<GPUThreadMappingAttr>);
   int64_t countMappingTypes = 0;
   countMappingTypes += hasBlockMapping ? 1 : 0;
   countMappingTypes += hasWarpgroupMapping ? 1 : 0;
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -232,9 +232,8 @@ struct GpuAsyncRegionPass::DeferWaitCallback {
   // control flow code.
   static bool areAllUsersExecuteOrAwait(Value token) {
     return !token.use_empty() &&
-           llvm::all_of(token.getUsers(), [](Operation *user) {
-             return isa<async::ExecuteOp, async::AwaitOp>(user);
-           });
+           llvm::all_of(token.getUsers(),
+                        llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);
   }

   // Add the `asyncToken` as dependency as needed after `op`.
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2786,10 +2786,8 @@ LogicalResult LLVM::BitcastOp::verify() {
   if (!resultType)
     return success();

-  auto isVector = [](Type type) {
-    return llvm::isa<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>(
-        type);
-  };
+  auto isVector =
+      llvm::IsaPred<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>;

   // Due to bitcast requiring both operands to be of the same size, it is not
   // possible for only one of the two to be a pointer of vectors.
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -28,6 +28,7 @@
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
@@ -119,8 +120,7 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                    TypeRange inputTypes, TypeRange outputTypes,
                                    ArrayRef<NamedAttribute> attrs,
                                    RegionBuilderFn regionBuilder) {
-  assert(llvm::all_of(outputTypes,
-                      [](Type t) { return llvm::isa<ShapedType>(t); }));
+  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));

   SmallVector<Type, 8> argTypes;
   SmallVector<Location, 8> argLocs;
@@ -162,7 +162,7 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
       resultTensorTypes.value_or(TypeRange());
   if (!resultTensorTypes)
     copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
-            [](Type type) { return llvm::isa<RankedTensorType>(type); });
+            llvm::IsaPred<RankedTensorType>);

   state.addOperands(inputs);
   state.addOperands(outputs);
@@ -27,8 +27,7 @@ static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {

   // TODO: The conversion pattern can be made to work for `any_of` here, but
   // it's more complex as it requires tracking which operands are scalars.
-  return llvm::all_of(op->getOperandTypes(),
-                      [](Type type) { return isa<RankedTensorType>(type); });
+  return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
 }

 /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
9 changes: 4 additions & 5 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -3537,15 +3537,14 @@ struct Conv1DGenerator
   // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
   // must be block arguments or extension of block arguments.
   bool setOperKind(Operation *reduceOp) {
-    int numBlockArguments = llvm::count_if(
-        reduceOp->getOperands(), [](Value v) { return isa<BlockArgument>(v); });
+    int numBlockArguments =
+        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
     switch (numBlockArguments) {
     case 1: {
       // Will be convolution if feeder is a MulOp.
       // Otherwise, if it can be pooling.
-      auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) {
-        return !isa<BlockArgument>(v);
-      });
+      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
+                                         llvm::IsaPred<BlockArgument>);
       Operation *feedOp = (*feedValIt).getDefiningOp();
       if (isCastOfBlockArgument(feedOp)) {
         oper = Pool;
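Beyond the `IsaPred` substitution, this hunk also swaps a negated lambda inside `llvm::find_if` for `llvm::find_if_not`, which returns an iterator to the first element for which the predicate is false. A self-contained sketch of the semantics:

    // find_if_not(R, P) behaves like find_if(R, std::not_fn(P)).
    SmallVector<int> vals = {2, 4, 5, 6};
    auto isEven = [](int v) { return v % 2 == 0; };
    auto it = llvm::find_if_not(vals, isEven); // *it == 5, the first odd value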
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -457,7 +457,7 @@ static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
 }

 static bool isComputeOperation(Operation *op) {
-  return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
+  return isa<acc::ParallelOp, acc::LoopOp>(op);
 }

 namespace {
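No predicate object is needed in the OpenACC change: `isa` itself is variadic, so a chain of `||` checks folds into a single call.

    // isa<A, B>(op) is equivalent to isa<A>(op) || isa<B>(op).
    bool isCompute = isa<acc::ParallelOp, acc::LoopOp>(op);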
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -125,8 +125,7 @@ LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
   if (getMatrixOperands()) {
     Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(),
                            typeC.getElementType()};
-    if (!llvm::all_of(elementTypes,
-                      [](Type ty) { return isa<IntegerType>(ty); })) {
+    if (!llvm::all_of(elementTypes, llvm::IsaPred<IntegerType>)) {
       return emitOpError("Matrix Operands require all matrix element types to "
                          "be Integer Types");
     }
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -65,9 +65,8 @@ LogicalResult shape::getShapeVec(Value input,
 }

 static bool isErrorPropagationPossible(TypeRange operandTypes) {
-  return llvm::any_of(operandTypes, [](Type ty) {
-    return llvm::isa<SizeType, ShapeType, ValueShapeType>(ty);
-  });
+  return llvm::any_of(operandTypes,
+                      llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
 }

 static LogicalResult verifySizeOrIndexOp(Operation *op) {
15 changes: 7 additions & 8 deletions mlir/lib/Dialect/Traits.cpp
@@ -188,9 +188,8 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
 /// Returns a tuple corresponding to whether range has tensor or vector type.
 template <typename iterator_range>
 static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
-  return std::make_tuple(
-      llvm::any_of(types, [](Type t) { return isa<TensorType>(t); }),
-      llvm::any_of(types, [](Type t) { return isa<VectorType>(t); }));
+  return {llvm::any_of(types, llvm::IsaPred<TensorType>),
+          llvm::any_of(types, llvm::IsaPred<VectorType>)};
 }

 static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
@@ -202,7 +201,7 @@ static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
   };
   if (inferred.size() != existing.size())
     return false;
-  for (auto [inferredDim, existingDim] : llvm::zip(inferred, existing))
+  for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing))
     if (!isCompatible(inferredDim, existingDim))
       return false;
   return true;
@@ -238,8 +237,8 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
                  std::get<1>(resultsHasTensorVectorType)))
     return op->emitError("cannot broadcast vector with tensor");

-  auto rankedOperands = make_filter_range(
-      op->getOperandTypes(), [](Type t) { return isa<RankedTensorType>(t); });
+  auto rankedOperands =
+      make_filter_range(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);

   // If all operands are unranked, then all result shapes are possible.
   if (rankedOperands.empty())
@@ -257,8 +256,8 @@
     return op->emitOpError("operands don't have broadcast-compatible shapes");
   }

-  auto rankedResults = make_filter_range(
-      op->getResultTypes(), [](Type t) { return isa<RankedTensorType>(t); });
+  auto rankedResults =
+      make_filter_range(op->getResultTypes(), llvm::IsaPred<RankedTensorType>);

   // If all of the results are unranked then no further verification.
   if (rankedResults.empty())
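A drive-by fix in this file: `llvm::zip` stops at the end of the shorter range, while `llvm::zip_equal` asserts that all ranges have the same length. Since the sizes are already checked just above, `zip_equal` documents and enforces that invariant. A sketch:

    // zip_equal asserts equal lengths (in builds with assertions enabled);
    // zip would silently truncate to the shortest range.
    SmallVector<int64_t> inferred = {4, -1};
    SmallVector<int64_t> existing = {4, 8};
    for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing)) {
      // ... compare inferredDim against existingDim ...
    }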
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -819,7 +819,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   assert(outputs.size() == 1 && "expected one output");
   return llvm::all_of(
       std::initializer_list<Type>{inputs.front(), outputs.front()},
-      [](Type ty) { return isa<transform::TransformHandleTypeInterface>(ty); });
+      llvm::IsaPred<transform::TransformHandleTypeInterface>);
 }

 //===----------------------------------------------------------------------===//
9 changes: 4 additions & 5 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -898,13 +898,12 @@ static LogicalResult verifyOutputShape(

   AffineMap resMap = op.getIndexingMapsArray()[2];
   auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
-                                   /*symCount=*/0, extents, ctx);
+                                   /*symbolCount=*/0, extents, ctx);
   // Compose the resMap with the extentsMap, which is a constant map.
   AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
-  assert(
-      llvm::all_of(expectedMap.getResults(),
-                   [](AffineExpr e) { return isa<AffineConstantExpr>(e); }) &&
-      "expected constant extent along all dimensions.");
+  assert(llvm::all_of(expectedMap.getResults(),
+                      llvm::IsaPred<AffineConstantExpr>) &&
+         "expected constant extent along all dimensions.");
   // Extract the expected shape and build the type.
   auto expectedShape = llvm::to_vector<4>(
       llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {