diff --git a/src/Conversion/KrnlToAffine/KrnlToAffine.cpp b/src/Conversion/KrnlToAffine/KrnlToAffine.cpp index ee42626b9e0..dcdf189399c 100644 --- a/src/Conversion/KrnlToAffine/KrnlToAffine.cpp +++ b/src/Conversion/KrnlToAffine/KrnlToAffine.cpp @@ -432,13 +432,13 @@ void lowerGetInductionVariableValueOp(KrnlGetInductionVariableValueOp &getIVOp, /// At this stage the dialect will contain standard operations as well like /// add and multiply, this pass will leave these operations intact. struct ConvertKrnlToAffinePass - : public PassWrapper { + : public PassWrapper> { StringRef getArgument() const override { return "convert-krnl-to-affine"; } StringRef getDescription() const override { return "Lower Krnl dialect."; } - void runOnFunction() final; + void runOnOperation() final; }; LogicalResult interpretOperation(Operation *op, OpBuilder &builder, @@ -1443,9 +1443,15 @@ void markLoopBodyAsMovable( } } -void ConvertKrnlToAffinePass::runOnFunction() { +void ConvertKrnlToAffinePass::runOnOperation() { OpBuilder builder(&getContext()); - FuncOp funcOp = getFunction(); + FuncOp funcOp = getOperation(); + + // external function: nothing to do + if (funcOp.body().empty()) { + return; + } + // Move invariant instructions outside of the loops as many as possible. This // helps make loops perfectly nested, which facilitates transformations. funcOp.walk([&](KrnlIterateOp loopOp) { @@ -1530,7 +1536,7 @@ void ConvertKrnlToAffinePass::runOnFunction() { DenseSet unconverted; if (failed(applyPartialConversion( - getFunction(), target, std::move(patterns), &unconverted))) { + getOperation(), target, std::move(patterns), &unconverted))) { { const std::lock_guard lock(unrollAndJamMutex); unrollAndJamMap.erase(currFuncOp); diff --git a/src/Transform/BundleMemoryPools.cpp b/src/Transform/BundleMemoryPools.cpp index d26ad2fe240..be7be4d7202 100644 --- a/src/Transform/BundleMemoryPools.cpp +++ b/src/Transform/BundleMemoryPools.cpp @@ -506,7 +506,7 @@ class KrnlMoveConstantsUp : public OpRewritePattern { */ class KrnlBundleMemoryPoolsPass - : public PassWrapper { + : public PassWrapper> { BlockToMemPool blockToStaticPool; BlockToMemPool blockToDynamicPool; @@ -518,8 +518,8 @@ class KrnlBundleMemoryPoolsPass return "Bundle memory pools of internal MemRefs into a single memory pool."; } - void runOnFunction() override { - auto function = getFunction(); + void runOnOperation() override { + auto function = getOperation(); ConversionTarget target(getContext()); RewritePatternSet patterns(&getContext()); diff --git a/src/Transform/DisconnectKrnlDimFromAlloc.cpp b/src/Transform/DisconnectKrnlDimFromAlloc.cpp index 0290fb5aa85..f48d2f0aa85 100644 --- a/src/Transform/DisconnectKrnlDimFromAlloc.cpp +++ b/src/Transform/DisconnectKrnlDimFromAlloc.cpp @@ -129,7 +129,8 @@ class DisconnectKrnlDimFromAlloc : public OpRewritePattern { * Function pass that disconnects krnl.dim emission from its MemRef alloc. */ class DisconnectKrnlDimFromAllocPass - : public PassWrapper { + : public PassWrapper> { public: StringRef getArgument() const override { return "lower-krnl-shape-to-std"; } @@ -137,8 +138,8 @@ class DisconnectKrnlDimFromAllocPass return "Lowers krnl shape-related operations."; } - void runOnFunction() override { - auto function = getFunction(); + void runOnOperation() override { + auto function = getOperation(); ConversionTarget target(getContext()); RewritePatternSet patterns(&getContext()); diff --git a/src/Transform/ElideKrnlGlobalConstants.cpp b/src/Transform/ElideKrnlGlobalConstants.cpp index 0361a092125..9d06ce6cdc8 100644 --- a/src/Transform/ElideKrnlGlobalConstants.cpp +++ b/src/Transform/ElideKrnlGlobalConstants.cpp @@ -80,7 +80,7 @@ namespace { * Function pass that performs constant value elision of Krnl globals. */ class ElideConstGlobalValuePass - : public PassWrapper { + : public PassWrapper> { public: StringRef getArgument() const override { return "elide-krnl-constants"; } @@ -88,8 +88,8 @@ class ElideConstGlobalValuePass return "Elide the constant values of the Global Krnl operations."; } - void runOnFunction() override { - auto function = getFunction(); + void runOnOperation() override { + auto function = getOperation(); ConversionTarget target(getContext()); RewritePatternSet patterns(&getContext()); diff --git a/src/Transform/EnableMemoryPool.cpp b/src/Transform/EnableMemoryPool.cpp index 04f68ad4a0f..6ee57721bcf 100644 --- a/src/Transform/EnableMemoryPool.cpp +++ b/src/Transform/EnableMemoryPool.cpp @@ -185,7 +185,7 @@ class KrnlEliminateOldDealloc : public OpRewritePattern { * Function pass that enables memory pooling for MemRefs. */ class KrnlEnableMemoryPoolPass - : public PassWrapper { + : public PassWrapper> { public: StringRef getArgument() const override { return "enable-memory-pool"; } @@ -193,8 +193,8 @@ class KrnlEnableMemoryPoolPass return "Enable a memory pool for allocating internal MemRefs."; } - void runOnFunction() override { - auto function = getFunction(); + void runOnOperation() override { + auto function = getOperation(); ConversionTarget target(getContext()); RewritePatternSet patterns(&getContext()); diff --git a/src/Transform/LowerKrnlShape.cpp b/src/Transform/LowerKrnlShape.cpp index 9ed61b8523e..781198acf85 100644 --- a/src/Transform/LowerKrnlShape.cpp +++ b/src/Transform/LowerKrnlShape.cpp @@ -78,7 +78,7 @@ class LowerKrnlShape : public OpRewritePattern { * Function pass that emits the shape of a MemRef. */ class LowerKrnlShapePass - : public PassWrapper { + : public PassWrapper> { public: StringRef getArgument() const override { return "lower-krnl-shape"; } @@ -86,8 +86,8 @@ class LowerKrnlShapePass return "Lower krnl.shape operation to use Shape dialect operations."; } - void runOnFunction() override { - auto function = getFunction(); + void runOnOperation() override { + auto function = getOperation(); ConversionTarget target(getContext()); RewritePatternSet patterns(&getContext()); diff --git a/src/Transform/ONNX/ConstProp.cpp b/src/Transform/ONNX/ConstProp.cpp index 8cca4883e85..f7d7546ab34 100644 --- a/src/Transform/ONNX/ConstProp.cpp +++ b/src/Transform/ONNX/ConstProp.cpp @@ -745,7 +745,7 @@ class ConstPropScatterNDPattern : public OpRewritePattern { //===----------------------------------------------------------------------===// struct ConstPropONNXToONNXPass - : public PassWrapper { + : public PassWrapper> { StringRef getArgument() const override { return "constprop-onnx"; } @@ -754,12 +754,12 @@ struct ConstPropONNXToONNXPass "other ONNX operations."; } - void runOnFunction() final; + void runOnOperation() final; }; } // end anonymous namespace. -void ConstPropONNXToONNXPass::runOnFunction() { - auto function = getFunction(); +void ConstPropONNXToONNXPass::runOnOperation() { + auto function = getOperation(); MLIRContext *context = &getContext(); ConversionTarget target(getContext()); diff --git a/src/Transform/ONNX/Decompose.cpp b/src/Transform/ONNX/Decompose.cpp index 95ddebdb854..24509e7688d 100644 --- a/src/Transform/ONNX/Decompose.cpp +++ b/src/Transform/ONNX/Decompose.cpp @@ -123,7 +123,7 @@ Value createSequenceConstructOp( namespace { struct DecomposeONNXToONNXPass - : public PassWrapper { + : public PassWrapper> { StringRef getArgument() const override { return "decompose-onnx"; } @@ -132,12 +132,12 @@ struct DecomposeONNXToONNXPass "operations."; } - void runOnFunction() final; + void runOnOperation() final; }; } // end anonymous namespace. -void DecomposeONNXToONNXPass::runOnFunction() { - auto function = getFunction(); +void DecomposeONNXToONNXPass::runOnOperation() { + auto function = getOperation(); MLIRContext *context = &getContext(); ConversionTarget target(getContext()); diff --git a/src/Transform/ONNX/ElideConstants.cpp b/src/Transform/ONNX/ElideConstants.cpp index 015914fa113..dc1a458311c 100644 --- a/src/Transform/ONNX/ElideConstants.cpp +++ b/src/Transform/ONNX/ElideConstants.cpp @@ -64,7 +64,7 @@ class ConstantValueElision : public OpRewritePattern { * Function pass that performs constant value elision. */ class ElideConstantValuePass - : public PassWrapper { + : public PassWrapper> { public: StringRef getArgument() const override { return "elide-constants"; } @@ -72,8 +72,8 @@ class ElideConstantValuePass return "Elide values of constant operations."; } - void runOnFunction() override { - auto function = getFunction(); + void runOnOperation() override { + auto function = getOperation(); ConversionTarget target(getContext()); RewritePatternSet patterns(&getContext()); diff --git a/src/Transform/ONNX/InstrumentONNXPass.cpp b/src/Transform/ONNX/InstrumentONNXPass.cpp index b13310772e5..62f41c29f87 100644 --- a/src/Transform/ONNX/InstrumentONNXPass.cpp +++ b/src/Transform/ONNX/InstrumentONNXPass.cpp @@ -57,7 +57,7 @@ llvm::cl::bits InstrumentControlBits( llvm::cl::cat(OMPassOptions)); class InstrumentONNXPass - : public mlir::PassWrapper { + : public mlir::PassWrapper> { private: bool allOpsAllowed; @@ -84,13 +84,13 @@ class InstrumentONNXPass runtimeActions = InstrumentControlBits.getBits(); }; - void runOnFunction() override { + void runOnOperation() override { if (instrumentONNXOps == "" || instrumentONNXOps == "NONE") return; init(instrumentONNXOps); // Iterate on the operations nested in this function - getFunction().walk([&](mlir::Operation *op) { + getOperation().walk([&](mlir::Operation *op) { if (isa(op->getDialect())) { // Skip the prefix "onnx." of onnx op name const char *opName = op->getName().getStringRef().data() + 5; diff --git a/src/Transform/ONNX/ONNXPreKrnlVerifyPass.cpp b/src/Transform/ONNX/ONNXPreKrnlVerifyPass.cpp index ade2ebb4ebf..5b5937b5273 100644 --- a/src/Transform/ONNX/ONNXPreKrnlVerifyPass.cpp +++ b/src/Transform/ONNX/ONNXPreKrnlVerifyPass.cpp @@ -38,15 +38,15 @@ namespace { */ class ONNXPreKrnlVerifyPass - : public mlir::PassWrapper { + : public mlir::PassWrapper> { public: StringRef getArgument() const override { return "onnx-pre-krnl-verify"; } StringRef getDescription() const override { return "Verify onnx ops."; } - void runOnFunction() override { - auto function = getFunction(); + void runOnOperation() override { + auto function = getOperation(); auto &funcBody = function.getBody(); // Iterate on the operations diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Transform/ONNX/ShapeInferencePass.cpp index 79b9762019f..4f3485e2389 100644 --- a/src/Transform/ONNX/ShapeInferencePass.cpp +++ b/src/Transform/ONNX/ShapeInferencePass.cpp @@ -41,7 +41,7 @@ static SmallVector lookUpFuncsMatching( } /*! - * FunctionPass that performs shape inference by iterating over a list of + * Function pass that performs shape inference by iterating over a list of * candidate operations and propagating the shape information until the list * of operations is empty [credit MLIR authors]. * diff --git a/src/Transform/OptimizeMemoryPools.cpp b/src/Transform/OptimizeMemoryPools.cpp index 5ac5c178b36..557e9db8ece 100644 --- a/src/Transform/OptimizeMemoryPools.cpp +++ b/src/Transform/OptimizeMemoryPools.cpp @@ -868,7 +868,7 @@ class KrnlCompactStaticMemoryPools : public OpRewritePattern { * Function pass that optimizes memory pools. */ class KrnlOptimizeMemoryPoolsPass - : public PassWrapper { + : public PassWrapper> { BlockToCompactedAlignments blockToStaticPoolAlignments; BlockToDiscardedGetRefs blockToDiscardedGetRefs; @@ -879,8 +879,8 @@ class KrnlOptimizeMemoryPoolsPass return "Optimize the static and dynamic memory pools."; } - void runOnFunction() override { - auto function = getFunction(); + void runOnOperation() override { + auto function = getOperation(); ConversionTarget target(getContext()); RewritePatternSet patterns(&getContext());