From acb69f3b7c83f411c08b77d75f2e812faf3cb83f Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Fri, 27 Nov 2020 21:09:13 +0100 Subject: [PATCH] [mlir] Change ConvertOpToLLVMPattern::matchAndRewrite argument to concrete operand type. Reviewed By: herhut, ftynse Differential Revision: https://reviews.llvm.org/D92111 --- .../StandardToLLVM/ConvertStandardToLLVM.h | 41 ++- .../ConvertLaunchFuncToRuntimeCalls.cpp | 67 ++-- .../ConvertLaunchFuncToLLVMCalls.cpp | 4 +- .../StandardToLLVM/StandardToLLVM.cpp | 310 ++++++++---------- .../test/lib/Transforms/TestConvertCallOp.cpp | 2 +- 5 files changed, 218 insertions(+), 206 deletions(-) diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h index 919a93ac84a274..70db4c1510bf80 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -564,14 +564,47 @@ class ConvertToLLVMPattern : public ConversionPattern { /// Utility class for operation conversions targeting the LLVM dialect that /// match exactly one source operation. -template +template class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { public: ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1) - : ConvertToLLVMPattern(OpTy::getOperationName(), + : ConvertToLLVMPattern(SourceOp::getOperationName(), &typeConverter.getContext(), typeConverter, benefit) {} + + /// Wrappers around the RewritePattern methods that pass the derived op type. + void rewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + rewrite(cast(op), operands, rewriter); + } + LogicalResult match(Operation *op) const final { + return match(cast(op)); + } + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + return matchAndRewrite(cast(op), operands, rewriter); + } + + /// Rewrite and Match methods that operate on the SourceOp type. These must be + /// overridden by the derived pattern class. + virtual void rewrite(SourceOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + llvm_unreachable("must override rewrite or matchAndRewrite"); + } + virtual LogicalResult match(SourceOp op) const { + llvm_unreachable("must override match or matchAndRewrite"); + } + virtual LogicalResult + matchAndRewrite(SourceOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + if (succeeded(match(op))) { + rewrite(op, operands, rewriter); + return success(); + } + return failure(); + } }; namespace LLVM { @@ -604,7 +637,7 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern { /// Converts the type of the result to an LLVM type, pass operands as is, /// preserve attributes. LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(), operands, this->typeConverter, @@ -621,7 +654,7 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { using Super = VectorConvertToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { static_assert( std::is_base_of, SourceOp>::value, diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp index d625db95e976af..cb7644cb7202d8 100644 --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -163,7 +163,7 @@ class ConvertHostRegisterOpToGpuRuntimeCallPattern private: LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -205,7 +205,7 @@ class ConvertWaitOpToGpuRuntimeCallPattern private: LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(gpu::WaitOp waitOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -219,7 +219,7 @@ class ConvertWaitAsyncOpToGpuRuntimeCallPattern private: LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(gpu::WaitOp waitOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -251,7 +251,7 @@ class ConvertLaunchFuncOpToGpuRuntimeCallPattern Location loc, OpBuilder &builder) const; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; llvm::SmallString<32> gpuBinaryAnnotation; @@ -321,14 +321,15 @@ isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, } LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite( - Operation *op, ArrayRef operands, + gpu::HostRegisterOp hostRegisterOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { + auto *op = hostRegisterOp.getOperation(); if (failed(areAllLLVMTypes(op, operands, rewriter))) return failure(); Location loc = op->getLoc(); - auto memRefType = cast(op).value().getType(); + auto memRefType = hostRegisterOp.value().getType(); auto elementType = memRefType.cast().getElementType(); auto elementSize = getSizeInBytes(loc, elementType, rewriter); @@ -412,19 +413,19 @@ LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite( // afterwards. In case this isn't correct, we will get a runtime error. // Eventually, we will have a pass that guarantees this property. LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite( - Operation *op, ArrayRef operands, + gpu::WaitOp waitOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - if (cast(op).asyncToken()) - return rewriter.notifyMatchFailure(op, "Cannot convert async op."); + if (waitOp.asyncToken()) + return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op."); - Location loc = op->getLoc(); + Location loc = waitOp.getLoc(); for (auto asyncDependency : operands) streamSynchronizeCallBuilder.create(loc, rewriter, {asyncDependency}); for (auto asyncDependency : operands) streamDestroyCallBuilder.create(loc, rewriter, {asyncDependency}); - rewriter.eraseOp(op); + rewriter.eraseOp(waitOp); return success(); } @@ -435,23 +436,23 @@ LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite( // assumes that there is no other use between the definition and this op, and // the plan is to have a pass that guarantees this property. LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite( - Operation *op, ArrayRef operands, + gpu::WaitOp waitOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - if (!cast(op).asyncToken()) - return rewriter.notifyMatchFailure(op, "Can only convert async op."); + if (!waitOp.asyncToken()) + return rewriter.notifyMatchFailure(waitOp, "Can only convert async op."); - Location loc = op->getLoc(); + Location loc = waitOp.getLoc(); auto insertionPoint = rewriter.saveInsertionPoint(); SmallVector events; - for (auto pair : llvm::zip(op->getOperands(), operands)) { + for (auto pair : llvm::zip(waitOp.asyncDependencies(), operands)) { auto token = std::get<0>(pair); if (auto *defOp = token.getDefiningOp()) { rewriter.setInsertionPointAfter(defOp); } else { // If we can't find the defining op, we record the event at block start, // which is late and therefore misses parallelism, but still valid. - rewriter.setInsertionPointToStart(op->getBlock()); + rewriter.setInsertionPointToStart(waitOp.getOperation()->getBlock()); } auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0); auto stream = std::get<1>(pair); @@ -464,7 +465,7 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite( streamWaitEventCallBuilder.create(loc, rewriter, {stream, event}); for (auto event : events) eventDestroyCallBuilder.create(loc, rewriter, {event}); - rewriter.replaceOp(op, {stream}); + rewriter.replaceOp(waitOp, {stream}); return success(); } @@ -564,23 +565,21 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant( // If the op is async, the stream corresponds to the (single) async dependency // as well as the async token the op produces. LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( - Operation *op, ArrayRef operands, + gpu::LaunchFuncOp launchOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - if (failed(areAllLLVMTypes(op, operands, rewriter))) + if (failed(areAllLLVMTypes(launchOp, operands, rewriter))) return failure(); - auto launchOp = cast(op); - if (launchOp.asyncDependencies().size() > 1) return rewriter.notifyMatchFailure( - op, "Cannot convert with more than one async dependency."); + launchOp, "Cannot convert with more than one async dependency."); // Fail when the synchronous version of the op has async dependencies. The // lowering destroys the stream, and we do not want to check that there is no // use of the stream after this op. if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty()) return rewriter.notifyMatchFailure( - op, "Cannot convert non-async op with async dependencies."); + launchOp, "Cannot convert non-async op with async dependencies."); Location loc = launchOp.getLoc(); @@ -612,7 +611,8 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( loc, rewriter, {module.getResult(0), kernelName}); auto zero = rewriter.create(loc, llvmInt32Type, rewriter.getI32IntegerAttr(0)); - auto adaptor = gpu::LaunchFuncOpAdaptor(operands, op->getAttrDictionary()); + auto adaptor = gpu::LaunchFuncOpAdaptor( + operands, launchOp.getOperation()->getAttrDictionary()); Value stream = adaptor.asyncDependencies().empty() ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0) @@ -620,23 +620,24 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( // Create array of pointers to kernel arguments. auto kernelParams = generateParamsArray(launchOp, operands, rewriter); auto nullpointer = rewriter.create(loc, llvmPointerPointerType); - launchKernelCallBuilder.create( - loc, rewriter, - {function.getResult(0), launchOp.gridSizeX(), launchOp.gridSizeY(), - launchOp.gridSizeZ(), launchOp.blockSizeX(), launchOp.blockSizeY(), - launchOp.blockSizeZ(), /*sharedMemBytes=*/zero, stream, kernelParams, - /*extra=*/nullpointer}); + launchKernelCallBuilder.create(loc, rewriter, + {function.getResult(0), launchOp.gridSizeX(), + launchOp.gridSizeY(), launchOp.gridSizeZ(), + launchOp.blockSizeX(), launchOp.blockSizeY(), + launchOp.blockSizeZ(), + /*sharedMemBytes=*/zero, stream, kernelParams, + /*extra=*/nullpointer}); if (launchOp.asyncToken()) { // Async launch: make dependent ops use the same stream. - rewriter.replaceOp(op, {stream}); + rewriter.replaceOp(launchOp, {stream}); } else { // Synchronize with host and destroy stream. This must be the stream created // above (with no other uses) because we check that the synchronous version // does not have any async dependencies. streamSynchronizeCallBuilder.create(loc, rewriter, stream); streamDestroyCallBuilder.create(loc, rewriter, stream); - rewriter.eraseOp(op); + rewriter.eraseOp(launchOp); } moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0)); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp index c34198e48d6f19..525a5be2448577 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -151,9 +151,9 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - gpu::LaunchFuncOp launchOp = cast(op); + auto *op = launchOp.getOperation(); MLIRContext *context = rewriter.getContext(); auto module = launchOp.getParentOfType(); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 49942995fc78cd..c19f53c4e99965 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1396,10 +1396,8 @@ struct FuncOpConversion : public FuncOpConversionBase { : FuncOpConversionBase(converter) {} LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto funcOp = cast(op); - auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); if (!newFuncOp) return failure(); @@ -1407,14 +1405,14 @@ struct FuncOpConversion : public FuncOpConversionBase { if (typeConverter.getOptions().emitCWrappers || funcOp.getAttrOfType(kEmitIfaceAttrName)) { if (newFuncOp.isExternal()) - wrapExternalFunction(rewriter, op->getLoc(), typeConverter, funcOp, + wrapExternalFunction(rewriter, funcOp.getLoc(), typeConverter, funcOp, newFuncOp); else - wrapForExternalCallers(rewriter, op->getLoc(), typeConverter, funcOp, + wrapForExternalCallers(rewriter, funcOp.getLoc(), typeConverter, funcOp, newFuncOp); } - rewriter.eraseOp(op); + rewriter.eraseOp(funcOp); return success(); } }; @@ -1425,10 +1423,8 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase { using FuncOpConversionBase::FuncOpConversionBase; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto funcOp = cast(op); - // Store the type of memref-typed arguments before the conversion so that we // can promote them to MemRef descriptor at the beginning of the function. SmallVector oldArgTypes = @@ -1438,7 +1434,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase { if (!newFuncOp) return failure(); if (newFuncOp.getBody().empty()) { - rewriter.eraseOp(op); + rewriter.eraseOp(funcOp); return success(); } @@ -1471,7 +1467,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase { // TODO: The placeholder is needed to avoid replacing barePtr uses in the // MemRef descriptor instructions. We may want to have a utility in the // rewriter to properly handle this use case. - Location loc = op->getLoc(); + Location loc = funcOp.getLoc(); auto placeholder = rewriter.create(loc, memrefTy); rewriter.replaceUsesOfBlockArgument(arg, placeholder); @@ -1480,7 +1476,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase { rewriter.replaceOp(placeholder, {desc}); } - rewriter.eraseOp(op); + rewriter.eraseOp(funcOp); return success(); } }; @@ -1711,13 +1707,13 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(AssertOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); + auto loc = op.getLoc(); AssertOp::Adaptor transformed(operands); // Insert the `abort` declaration if necessary. - auto module = op->getParentOfType(); + auto module = op.getParentOfType(); auto abortFunc = module.lookupSymbol("abort"); if (!abortFunc) { OpBuilder::InsertionGuard guard(rewriter); @@ -1754,13 +1750,13 @@ struct CreateComplexOpLowering using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(CreateComplexOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto complexOp = cast(op); CreateComplexOp::Adaptor transformed(operands); // Pack real and imaginary part in a complex number struct. - auto loc = op->getLoc(); + auto loc = op.getLoc(); auto structType = typeConverter.convertType(complexOp.getType()); auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); complexStruct.setReal(rewriter, loc, transformed.real()); @@ -1775,13 +1771,13 @@ struct ReOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(ReOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { ReOp::Adaptor transformed(operands); // Extract real part from the complex number struct. ComplexStructBuilder complexStruct(transformed.complex()); - Value real = complexStruct.real(rewriter, op->getLoc()); + Value real = complexStruct.real(rewriter, op.getLoc()); rewriter.replaceOp(op, real); return success(); @@ -1792,13 +1788,13 @@ struct ImOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(ImOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { ImOp::Adaptor transformed(operands); // Extract imaginary part from the complex number struct. ComplexStructBuilder complexStruct(transformed.complex()); - Value imaginary = complexStruct.imaginary(rewriter, op->getLoc()); + Value imaginary = complexStruct.imaginary(rewriter, op.getLoc()); rewriter.replaceOp(op, imaginary); return success(); @@ -1833,9 +1829,8 @@ struct AddCFOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *operation, ArrayRef operands, + matchAndRewrite(AddCFOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto op = cast(operation); auto loc = op.getLoc(); BinaryComplexOperands arg = unpackBinaryComplexOperands(op, operands, rewriter); @@ -1861,9 +1856,8 @@ struct SubCFOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *operation, ArrayRef operands, + matchAndRewrite(SubCFOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto op = cast(operation); auto loc = op.getLoc(); BinaryComplexOperands arg = unpackBinaryComplexOperands(op, operands, rewriter); @@ -1889,9 +1883,8 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *operation, ArrayRef operands, + matchAndRewrite(ConstantOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto op = cast(operation); // If constant refers to a function, convert it to "addressof". if (auto symbolRef = op.getValue().dyn_cast()) { auto type = typeConverter.convertType(op.getResult().getType()) @@ -2284,10 +2277,9 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { using Base = ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(CallOpType callOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { typename CallOpType::Adaptor transformed(operands); - auto callOp = cast(op); // Pack the result types into a struct. Type packedResult = nullptr; @@ -2301,10 +2293,11 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { } auto promoted = this->typeConverter.promoteOperands( - op->getLoc(), /*opOperands=*/op->getOperands(), operands, rewriter); + callOp.getLoc(), /*opOperands=*/callOp.getOperation()->getOperands(), + operands, rewriter); auto newOp = rewriter.create( - op->getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), - promoted, op->getAttrs()); + callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), + promoted, callOp.getAttrs()); SmallVector results; if (numResults < 2) { @@ -2315,9 +2308,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { // Extract individual results from the structure and return them as list. results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { - auto type = this->typeConverter.convertType(op->getResult(i).getType()); + auto type = + this->typeConverter.convertType(callOp.getResult(i).getType()); results.push_back(rewriter.create( - op->getLoc(), type, newOp.getOperation()->getResult(0), + callOp.getLoc(), type, newOp.getOperation()->getResult(0), rewriter.getI64ArrayAttr(i))); } } @@ -2327,16 +2321,16 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { // descriptors. assert(results.size() == resultTypes.size() && "The number of arguments and types doesn't match"); - this->typeConverter.promoteBarePtrsToDescriptors(rewriter, op->getLoc(), - resultTypes, results); - } else if (failed(copyUnrankedDescriptors(rewriter, op->getLoc(), + this->typeConverter.promoteBarePtrsToDescriptors( + rewriter, callOp.getLoc(), resultTypes, results); + } else if (failed(copyUnrankedDescriptors(rewriter, callOp.getLoc(), this->typeConverter, resultTypes, results, /*toDynamic=*/false))) { return failure(); } - rewriter.replaceOp(op, results); + rewriter.replaceOp(callOp, results); return success(); } }; @@ -2359,18 +2353,18 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern { : ConvertOpToLLVMPattern(converter) {} LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(DeallocOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { assert(operands.size() == 1 && "dealloc takes one operand"); DeallocOp::Adaptor transformed(operands); // Insert the `free` declaration if it is not already present. auto freeFunc = - op->getParentOfType().lookupSymbol("free"); + op.getParentOfType().lookupSymbol("free"); if (!freeFunc) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart( - op->getParentOfType().getBody()); + op.getParentOfType().getBody()); freeFunc = rewriter.create( rewriter.getUnknownLoc(), "free", LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(), @@ -2379,8 +2373,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern { MemRefDescriptor memref(transformed.memref()); Value casted = rewriter.create( - op->getLoc(), getVoidPtrType(), - memref.allocatedPtr(rewriter, op->getLoc())); + op.getLoc(), getVoidPtrType(), + memref.allocatedPtr(rewriter, op.getLoc())); rewriter.replaceOpWithNewOp( op, TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted); return success(); @@ -2410,9 +2404,8 @@ struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(GlobalMemrefOp global, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto global = cast(op); MemRefType type = global.type().cast(); if (!isSupportedMemRefType(type)) return failure(); @@ -2434,7 +2427,7 @@ struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern { } rewriter.replaceOpWithNewOp( - op, arrayTy, global.constant(), linkage, global.sym_name(), + global, arrayTy, global.constant(), linkage, global.sym_name(), initialValue, type.getMemorySpace()); return success(); } @@ -2491,7 +2484,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(RsqrtOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { RsqrtOp::Adaptor transformed(operands); auto operandType = @@ -2500,8 +2493,8 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern { if (!operandType) return failure(); - auto loc = op->getLoc(); - auto resultType = *op->result_type_begin(); + auto loc = op.getLoc(); + auto resultType = op.getResult().getType(); auto floatType = getElementTypeOrSelf(resultType).cast(); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); @@ -2524,7 +2517,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern { return failure(); return handleMultidimensionalVectors( - op, operands, typeConverter, + op.getOperation(), operands, typeConverter, [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( mlir::VectorType::get({llvmVectorTy.getVectorNumElements()}, @@ -2543,8 +2536,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern { struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - LogicalResult match(Operation *op) const override { - auto memRefCastOp = cast(op); + LogicalResult match(MemRefCastOp memRefCastOp) const override { Type srcType = memRefCastOp.getOperand().getType(); Type dstType = memRefCastOp.getType(); @@ -2568,19 +2560,18 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { : failure(); } - void rewrite(Operation *op, ArrayRef operands, + void rewrite(MemRefCastOp memRefCastOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto memRefCastOp = cast(op); MemRefCastOp::Adaptor transformed(operands); auto srcType = memRefCastOp.getOperand().getType(); auto dstType = memRefCastOp.getType(); auto targetStructType = typeConverter.convertType(memRefCastOp.getType()); - auto loc = op->getLoc(); + auto loc = memRefCastOp.getLoc(); // For ranked/ranked case, just keep the original descriptor. if (srcType.isa() && dstType.isa()) - return rewriter.replaceOp(op, {transformed.source()}); + return rewriter.replaceOp(memRefCastOp, {transformed.source()}); if (srcType.isa() && dstType.isa()) { // Casting ranked to unranked memref type @@ -2607,7 +2598,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { memRefDesc.setRank(rewriter, loc, rankVal); // d2 = InsertValueOp d1, voidptr, 1 memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr); - rewriter.replaceOp(op, (Value)memRefDesc); + rewriter.replaceOp(memRefCastOp, (Value)memRefDesc); } else if (srcType.isa() && dstType.isa()) { // Casting from unranked type to ranked. @@ -2625,7 +2616,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { .getResult(); // struct = LoadOp castPtr auto loadOp = rewriter.create(loc, castPtr); - rewriter.replaceOp(op, loadOp.getResult()); + rewriter.replaceOp(memRefCastOp, loadOp.getResult()); } else { llvm_unreachable("Unsupported unranked memref to unranked memref cast"); } @@ -2680,17 +2671,17 @@ struct MemRefReinterpretCastOpLowering using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(MemRefReinterpretCastOp castOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto castOp = cast(op); - MemRefReinterpretCastOp::Adaptor adaptor(operands, op->getAttrDictionary()); + MemRefReinterpretCastOp::Adaptor adaptor( + operands, castOp.getOperation()->getAttrDictionary()); Type srcType = castOp.source().getType(); Value descriptor; if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp, adaptor, &descriptor))) return failure(); - rewriter.replaceOp(op, {descriptor}); + rewriter.replaceOp(castOp, {descriptor}); return success(); } @@ -2748,10 +2739,9 @@ struct MemRefReshapeOpLowering using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(MemRefReshapeOp reshapeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto reshapeOp = cast(op); - + auto *op = reshapeOp.getOperation(); MemRefReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary()); Type srcType = reshapeOp.source().getType(); @@ -2898,15 +2888,14 @@ struct DialectCastOpLowering using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(LLVM::DialectCastOp castOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto castOp = cast(op); LLVM::DialectCastOp::Adaptor transformed(operands); if (transformed.in().getType() != typeConverter.convertType(castOp.getType())) { return failure(); } - rewriter.replaceOp(op, transformed.in()); + rewriter.replaceOp(castOp, transformed.in()); return success(); } }; @@ -2917,19 +2906,18 @@ struct DimOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(DimOp dimOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto dimOp = cast(op); Type operandType = dimOp.memrefOrTensor().getType(); if (operandType.isa()) { - rewriter.replaceOp(op, {extractSizeOfUnrankedMemRef(operandType, dimOp, - operands, rewriter)}); + rewriter.replaceOp(dimOp, {extractSizeOfUnrankedMemRef( + operandType, dimOp, operands, rewriter)}); return success(); } if (operandType.isa()) { - rewriter.replaceOp(op, {extractSizeOfRankedMemRef(operandType, dimOp, - operands, rewriter)}); + rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef( + operandType, dimOp, operands, rewriter)}); return success(); } return failure(); @@ -3006,10 +2994,10 @@ struct RankOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(RankOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - Type operandType = cast(op).memrefOrTensor().getType(); + Location loc = op.getLoc(); + Type operandType = op.memrefOrTensor().getType(); if (auto unrankedMemRefType = operandType.dyn_cast()) { UnrankedMemRefDescriptor desc(RankOp::Adaptor(operands).memrefOrTensor()); rewriter.replaceOp(op, {desc.rank(rewriter, loc)}); @@ -3033,8 +3021,8 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::isSupportedMemRefType; using Base = LoadStoreOpLowering; - LogicalResult match(Operation *op) const override { - MemRefType type = cast(op).getMemRefType(); + LogicalResult match(Derived op) const override { + MemRefType type = op.getMemRefType(); return isSupportedMemRefType(type) ? success() : failure(); } }; @@ -3045,16 +3033,15 @@ struct LoadOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto loadOp = cast(op); LoadOp::Adaptor transformed(operands); auto type = loadOp.getMemRefType(); Value dataPtr = - getStridedElementPtr(op->getLoc(), type, transformed.memref(), + getStridedElementPtr(loadOp.getLoc(), type, transformed.memref(), transformed.indices(), rewriter); - rewriter.replaceOpWithNewOp(op, dataPtr); + rewriter.replaceOpWithNewOp(loadOp, dataPtr); return success(); } }; @@ -3065,13 +3052,13 @@ struct StoreOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(StoreOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto type = cast(op).getMemRefType(); + auto type = op.getMemRefType(); StoreOp::Adaptor transformed(operands); Value dataPtr = - getStridedElementPtr(op->getLoc(), type, transformed.memref(), + getStridedElementPtr(op.getLoc(), type, transformed.memref(), transformed.indices(), rewriter); rewriter.replaceOpWithNewOp(op, transformed.value(), dataPtr); @@ -3085,29 +3072,26 @@ struct PrefetchOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(PrefetchOp prefetchOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto prefetchOp = cast(op); PrefetchOp::Adaptor transformed(operands); auto type = prefetchOp.getMemRefType(); + auto loc = prefetchOp.getLoc(); - Value dataPtr = - getStridedElementPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter); + Value dataPtr = getStridedElementPtr(loc, type, transformed.memref(), + transformed.indices(), rewriter); // Replace with llvm.prefetch. auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); auto isWrite = rewriter.create( - op->getLoc(), llvmI32Type, - rewriter.getI32IntegerAttr(prefetchOp.isWrite())); + loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite())); auto localityHint = rewriter.create( - op->getLoc(), llvmI32Type, + loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.localityHint())); auto isData = rewriter.create( - op->getLoc(), llvmI32Type, - rewriter.getI32IntegerAttr(prefetchOp.isDataCache())); + loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache())); - rewriter.replaceOpWithNewOp(op, dataPtr, isWrite, + rewriter.replaceOpWithNewOp(prefetchOp, dataPtr, isWrite, localityHint, isData); return success(); } @@ -3121,10 +3105,9 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(IndexCastOp indexCastOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { IndexCastOpAdaptor transformed(operands); - auto indexCastOp = cast(op); auto targetType = this->typeConverter.convertType(indexCastOp.getResult().getType()) @@ -3134,12 +3117,12 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern { unsigned sourceBits = sourceType.getIntegerBitWidth(); if (targetBits == sourceBits) - rewriter.replaceOp(op, transformed.in()); + rewriter.replaceOp(indexCastOp, transformed.in()); else if (targetBits < sourceBits) - rewriter.replaceOpWithNewOp(op, targetType, + rewriter.replaceOpWithNewOp(indexCastOp, targetType, transformed.in()); else - rewriter.replaceOpWithNewOp(op, targetType, + rewriter.replaceOpWithNewOp(indexCastOp, targetType, transformed.in()); return success(); } @@ -3156,13 +3139,12 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(CmpIOp cmpiOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto cmpiOp = cast(op); CmpIOpAdaptor transformed(operands); rewriter.replaceOpWithNewOp( - op, typeConverter.convertType(cmpiOp.getResult().getType()), + cmpiOp, typeConverter.convertType(cmpiOp.getResult().getType()), rewriter.getI64IntegerAttr(static_cast( convertCmpPredicate(cmpiOp.getPredicate()))), transformed.lhs(), transformed.rhs()); @@ -3175,13 +3157,12 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(CmpFOp cmpfOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto cmpfOp = cast(op); CmpFOpAdaptor transformed(operands); rewriter.replaceOpWithNewOp( - op, typeConverter.convertType(cmpfOp.getResult().getType()), + cmpfOp, typeConverter.convertType(cmpfOp.getResult().getType()), rewriter.getI64IntegerAttr(static_cast( convertCmpPredicate(cmpfOp.getPredicate()))), transformed.lhs(), transformed.rhs()); @@ -3243,10 +3224,10 @@ struct OneToOneLLVMTerminatorLowering using Super = OneToOneLLVMTerminatorLowering; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, operands, op->getSuccessors(), - op->getAttrs()); + rewriter.replaceOpWithNewOp( + op, operands, op.getOperation()->getSuccessors(), op.getAttrs()); return success(); } }; @@ -3261,16 +3242,16 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(ReturnOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - unsigned numArguments = op->getNumOperands(); + Location loc = op.getLoc(); + unsigned numArguments = op.getNumOperands(); SmallVector updatedOperands; if (typeConverter.getOptions().useBarePtrCallConv) { // For the bare-ptr calling convention, extract the aligned pointer to // be returned from the memref descriptor. - for (auto it : llvm::zip(op->getOperands(), operands)) { + for (auto it : llvm::zip(op.getOperation()->getOperands(), operands)) { Type oldTy = std::get<0>(it).getType(); Value newOperand = std::get<1>(it); if (oldTy.isa()) { @@ -3286,26 +3267,26 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern { } else { updatedOperands = llvm::to_vector<4>(operands); copyUnrankedDescriptors(rewriter, loc, typeConverter, - op->getOperands().getTypes(), updatedOperands, + op.getOperands().getTypes(), updatedOperands, /*toDynamic=*/true); } // If ReturnOp has 0 or 1 operand, create it and return immediately. if (numArguments == 0) { rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), - op->getAttrs()); + op.getAttrs()); return success(); } if (numArguments == 1) { rewriter.replaceOpWithNewOp( - op, TypeRange(), updatedOperands, op->getAttrs()); + op, TypeRange(), updatedOperands, op.getAttrs()); return success(); } // Otherwise, we need to pack the arguments into an LLVM struct type before // returning. auto packedType = typeConverter.packFunctionResults( - llvm::to_vector<4>(op->getOperandTypes())); + llvm::to_vector<4>(op.getOperandTypes())); Value packed = rewriter.create(loc, packedType); for (unsigned i = 0; i < numArguments; ++i) { @@ -3314,7 +3295,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern { rewriter.getI64ArrayAttr(i)); } rewriter.replaceOpWithNewOp(op, TypeRange(), packed, - op->getAttrs()); + op.getAttrs()); return success(); } }; @@ -3335,29 +3316,30 @@ struct SplatOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(SplatOp splatOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto splatOp = cast(op); VectorType resultType = splatOp.getType().dyn_cast(); if (!resultType || resultType.getRank() != 1) return failure(); // First insert it into an undef vector so we can shuffle it. auto vectorType = typeConverter.convertType(splatOp.getType()); - Value undef = rewriter.create(op->getLoc(), vectorType); + Value undef = rewriter.create(splatOp.getLoc(), vectorType); auto zero = rewriter.create( - op->getLoc(), typeConverter.convertType(rewriter.getIntegerType(32)), + splatOp.getLoc(), + typeConverter.convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); auto v = rewriter.create( - op->getLoc(), vectorType, undef, splatOp.getOperand(), zero); + splatOp.getLoc(), vectorType, undef, splatOp.getOperand(), zero); int64_t width = splatOp.getType().cast().getDimSize(0); SmallVector zeroValues(width, 0); // Shuffle the value across the desired number of elements. ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues); - rewriter.replaceOpWithNewOp(op, v, undef, zeroAttrs); + rewriter.replaceOpWithNewOp(splatOp, v, undef, + zeroAttrs); return success(); } }; @@ -3369,16 +3351,15 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(SplatOp splatOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto splatOp = cast(op); SplatOp::Adaptor adaptor(operands); VectorType resultType = splatOp.getType().dyn_cast(); if (!resultType || resultType.getRank() == 1) return failure(); // First insert it into an undef vector so we can shuffle it. - auto loc = op->getLoc(); + auto loc = splatOp.getLoc(); auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, typeConverter); auto llvmArrayTy = vectorTypeInfo.llvmArrayTy; auto llvmVectorTy = vectorTypeInfo.llvmVectorTy; @@ -3409,7 +3390,7 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern { desc = rewriter.create(loc, llvmArrayTy, desc, v, position); }); - rewriter.replaceOp(op, desc); + rewriter.replaceOp(splatOp, desc); return success(); } }; @@ -3431,10 +3412,9 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(SubViewOp subViewOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto subViewOp = cast(op); + auto loc = subViewOp.getLoc(); auto sourceMemRefType = subViewOp.source().getType().cast(); auto sourceElementTy = @@ -3545,7 +3525,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern { j--; } - rewriter.replaceOp(op, {targetMemRef}); + rewriter.replaceOp(subViewOp, {targetMemRef}); return success(); } }; @@ -3562,16 +3542,15 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(TransposeOp transposeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); + auto loc = transposeOp.getLoc(); TransposeOpAdaptor adaptor(operands); MemRefDescriptor viewMemRef(adaptor.in()); - auto transposeOp = cast(op); // No permutation, early exit. if (transposeOp.permutation().isIdentity()) - return rewriter.replaceOp(op, {viewMemRef}), success(); + return rewriter.replaceOp(transposeOp, {viewMemRef}), success(); auto targetMemRef = MemRefDescriptor::undef( rewriter, loc, typeConverter.convertType(transposeOp.getShapedType())); @@ -3596,7 +3575,7 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern { viewMemRef.stride(rewriter, loc, sourcePos)); } - rewriter.replaceOp(op, {targetMemRef}); + rewriter.replaceOp(transposeOp, {targetMemRef}); return success(); } }; @@ -3643,10 +3622,9 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern { } LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(ViewOp viewOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto viewOp = cast(op); + auto loc = viewOp.getLoc(); ViewOpAdaptor adaptor(operands); auto viewMemRefType = viewOp.getType(); @@ -3656,14 +3634,14 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern { auto targetDescTy = typeConverter.convertType(viewMemRefType).dyn_cast(); if (!targetDescTy) - return op->emitWarning("Target descriptor type not converted to LLVM"), + return viewOp.emitWarning("Target descriptor type not converted to LLVM"), failure(); int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); if (failed(successStrides)) - return op->emitWarning("cannot cast to non-strided shape"), failure(); + return viewOp.emitWarning("cannot cast to non-strided shape"), failure(); assert(offset == 0 && "expected offset to be 0"); // Create the descriptor. @@ -3695,11 +3673,12 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern { // Early exit for 0-D corner case. if (viewMemRefType.getRank() == 0) - return rewriter.replaceOp(op, {targetMemRef}), success(); + return rewriter.replaceOp(viewOp, {targetMemRef}), success(); // Fields 4 and 5: Update sizes and strides. if (strides.back() != 1) - return op->emitWarning("cannot cast to non-contiguous shape"), failure(); + return viewOp.emitWarning("cannot cast to non-contiguous shape"), + failure(); Value stride = nullptr, nextSize = nullptr; for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { // Update size. @@ -3712,7 +3691,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern { nextSize = size; } - rewriter.replaceOp(op, {targetMemRef}); + rewriter.replaceOp(viewOp, {targetMemRef}); return success(); } }; @@ -3722,11 +3701,12 @@ struct AssumeAlignmentOpLowering using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(AssumeAlignmentOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { AssumeAlignmentOp::Adaptor transformed(operands); Value memref = transformed.memref(); - unsigned alignment = cast(op).alignment(); + unsigned alignment = op.alignment(); + auto loc = op.getLoc(); MemRefDescriptor memRefDescriptor(memref); Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc()); @@ -3741,16 +3721,14 @@ struct AssumeAlignmentOpLowering // pointer SSA value. auto intPtrType = getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace()); - Value zero = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0); - Value mask = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, - alignment - 1); - Value ptrValue = - rewriter.create(op->getLoc(), intPtrType, ptr); + Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0); + Value mask = + createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1); + Value ptrValue = rewriter.create(loc, intPtrType, ptr); rewriter.create( - op->getLoc(), - rewriter.create( - op->getLoc(), LLVM::ICmpPredicate::eq, - rewriter.create(op->getLoc(), ptrValue, mask), zero)); + loc, rewriter.create( + loc, LLVM::ICmpPredicate::eq, + rewriter.create(loc, ptrValue, mask), zero)); rewriter.eraseOp(op); return success(); @@ -3789,9 +3767,10 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(AtomicRMWOp atomicOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto atomicOp = cast(op); + if (failed(match(atomicOp))) + return failure(); auto maybeKind = matchSimpleAtomicOp(atomicOp); if (!maybeKind) return failure(); @@ -3799,10 +3778,10 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering { auto resultType = adaptor.value().getType(); auto memRefType = atomicOp.getMemRefType(); auto dataPtr = - getStridedElementPtr(op->getLoc(), memRefType, adaptor.memref(), + getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(), adaptor.indices(), rewriter); rewriter.replaceOpWithNewOp( - op, resultType, *maybeKind, dataPtr, adaptor.value(), + atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(), LLVM::AtomicOrdering::acq_rel); return success(); } @@ -3840,11 +3819,10 @@ struct GenericAtomicRMWOpLowering using Base::Base; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(GenericAtomicRMWOp atomicOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto atomicOp = cast(op); - auto loc = op->getLoc(); + auto loc = atomicOp.getLoc(); GenericAtomicRMWOp::Adaptor adaptor(operands); LLVM::LLVMType valueType = typeConverter.convertType(atomicOp.getResult().getType()) @@ -3908,7 +3886,7 @@ struct GenericAtomicRMWOpLowering std::next(opsToMoveEnd), rewriter); // The 'result' of the atomic_rmw op is the newly loaded value. - rewriter.replaceOp(op, {newLoaded}); + rewriter.replaceOp(atomicOp, {newLoaded}); return success(); } diff --git a/mlir/test/lib/Transforms/TestConvertCallOp.cpp b/mlir/test/lib/Transforms/TestConvertCallOp.cpp index a612738c5dccef..61062c7938fe26 100644 --- a/mlir/test/lib/Transforms/TestConvertCallOp.cpp +++ b/mlir/test/lib/Transforms/TestConvertCallOp.cpp @@ -25,7 +25,7 @@ class TestTypeProducerOpConverter test::TestTypeProducerOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(test::TestTypeProducerOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, getVoidPtrType()); return success();