From db8f16b8ba11a2432a3cef8ea9237821019a5514 Mon Sep 17 00:00:00 2001 From: mingzheTerapines Date: Fri, 2 Aug 2024 13:31:28 +0800 Subject: [PATCH] Remove typeswitch Add func.func Dialect Disable caonicalize function for local variable --- include/circt/Dialect/Moore/MooreDialect.td | 2 +- include/circt/Dialect/Moore/MooreOps.h | 1 + lib/Dialect/Moore/CMakeLists.txt | 1 + lib/Dialect/Moore/MooreDialect.cpp | 1 + lib/Dialect/Moore/MooreOps.cpp | 107 +++++++++++--------- 5 files changed, 63 insertions(+), 49 deletions(-) diff --git a/include/circt/Dialect/Moore/MooreDialect.td b/include/circt/Dialect/Moore/MooreDialect.td index 84f57a5285f4..ffec56edd9c1 100644 --- a/include/circt/Dialect/Moore/MooreDialect.td +++ b/include/circt/Dialect/Moore/MooreDialect.td @@ -35,7 +35,7 @@ def MooreDialect : Dialect { void printType(Type, DialectAsmPrinter &) const override; }]; let useDefaultTypePrinterParser = 0; - let dependentDialects = ["hw::HWDialect"]; + let dependentDialects = ["hw::HWDialect", "mlir::func::FuncDialect"]; } #endif // CIRCT_DIALECT_MOORE_MOOREDIALECT diff --git a/include/circt/Dialect/Moore/MooreOps.h b/include/circt/Dialect/Moore/MooreOps.h index 5a27457159d8..626174184dbd 100644 --- a/include/circt/Dialect/Moore/MooreOps.h +++ b/include/circt/Dialect/Moore/MooreOps.h @@ -16,6 +16,7 @@ #include "circt/Dialect/HW/HWTypes.h" #include "circt/Dialect/Moore/MooreDialect.h" #include "circt/Dialect/Moore/MooreTypes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/RegionKindInterface.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferTypeOpInterface.h" diff --git a/lib/Dialect/Moore/CMakeLists.txt b/lib/Dialect/Moore/CMakeLists.txt index bda0ea7713cc..4cb49c6d18a2 100644 --- a/lib/Dialect/Moore/CMakeLists.txt +++ b/lib/Dialect/Moore/CMakeLists.txt @@ -18,6 +18,7 @@ add_circt_dialect_library(CIRCTMoore CIRCTHW CIRCTSupport MLIRIR + MLIRFuncDialect MLIRInferTypeOpInterface MLIRMemorySlotInterfaces ) diff --git a/lib/Dialect/Moore/MooreDialect.cpp b/lib/Dialect/Moore/MooreDialect.cpp index 902fbb27e730..3aa3a6e37bfd 100644 --- a/lib/Dialect/Moore/MooreDialect.cpp +++ b/lib/Dialect/Moore/MooreDialect.cpp @@ -12,6 +12,7 @@ #include "circt/Dialect/HW/HWDialect.h" #include "circt/Dialect/Moore/MooreOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" using namespace circt; using namespace circt::moore; diff --git a/lib/Dialect/Moore/MooreOps.cpp b/lib/Dialect/Moore/MooreOps.cpp index 7b32f031e850..5a638f012581 100644 --- a/lib/Dialect/Moore/MooreOps.cpp +++ b/lib/Dialect/Moore/MooreOps.cpp @@ -22,6 +22,8 @@ using namespace circt; using namespace circt::moore; using namespace mlir; +static ArrayRef getStructMembers(Type type); + //===----------------------------------------------------------------------===// // SVModuleOp //===----------------------------------------------------------------------===// @@ -289,54 +291,55 @@ VariableOp::handlePromotionComplete(const MemorySlot &slot, Value defaultValue, LogicalResult VariableOp::canonicalize(VariableOp op, ::mlir::PatternRewriter &rewriter) { + if (!(isa(op->getParentOp()) || + isa(op->getParentOp()))) + return failure(); + auto members = getStructMembers(op.getType().getNestedType()); + if (!members.empty()) { + SmallVector createFields; + if (auto initial = op.getInitial()) { + auto addressOp = rewriter.create( + op.getLoc(), RefType::get(cast(initial.getType())), + initial); + for (const auto &member : members) { + auto field = rewriter.create( + op->getLoc(), cast(member.type), member.name, + addressOp); + createFields.push_back(field); + } + } else { + for (const auto &member : members) { + // todo: support 4-domain value + auto field = rewriter.create(op->getLoc(), + cast(member.type), 0); + createFields.push_back(field); + } + } + auto value = rewriter.create( + op->getLoc(), op.getType().getNestedType(), createFields); + rewriter.replaceOpWithNewOp(op, RefType::get(value.getType()), + value); + return success(); + } - return TypeSwitch(op.getType().getNestedType()) - .Case([&op, &rewriter](auto &type) { - SmallVector createFields; - if (auto initial = op.getInitial()) { - auto addressOp = rewriter.create( - op.getLoc(), RefType::get(cast(initial.getType())), - initial); - for (const auto &member : type.getMembers()) { - auto field = rewriter.create( - op->getLoc(), cast(member.type), member.name, - addressOp); - createFields.push_back(field); - } - } else { - for (const auto &member : type.getMembers()) { - // todo: support 4-domain value - auto field = rewriter.create( - op->getLoc(), cast(member.type), 0); - createFields.push_back(field); - } - } - auto value = rewriter.create( - op->getLoc(), op.getType().getNestedType(), createFields); - rewriter.replaceOpWithNewOp( - op, RefType::get(value.getType()), value); - return success(); - }) - .Default([&op, &rewriter](auto &) { - Value initial; - for (auto *user : op->getUsers()) - if (isa(user) && - (user->getOperand(0) == op.getResult())) { - // Don't canonicalize the multiple continuous assignment to the same - // variable. - if (initial) - return failure(); - initial = user->getOperand(1); - } - - if (initial) { - rewriter.replaceOpWithNewOp(op, op.getType(), - op.getNameAttr(), initial); - return success(); - } - + Value initial; + for (auto *user : op->getUsers()) + if (isa(user) && + (user->getOperand(0) == op.getResult())) { + // Don't canonicalize the multiple continuous assignment to the same + // variable. + if (initial) return failure(); - }); + initial = user->getOperand(1); + } + + if (initial) { + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getNameAttr(), initial); + return success(); + } + + return failure(); } SmallVector VariableOp::getDestructurableSlots() { @@ -560,7 +563,6 @@ static ArrayRef getStructMembers(Type type) { return structType.getMembers(); if (auto structType = dyn_cast(type)) return structType.getMembers(); - assert(0 && "expected StructType or UnpackedStructType"); return {}; } @@ -708,6 +710,9 @@ OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) { LogicalResult StructInjectOp::canonicalize(StructInjectOp op, PatternRewriter &rewriter) { + if (!(isa(op->getParentOp()) || + isa(op->getParentOp()))) + return failure(); auto members = getStructMembers(op.getType()); // Chase a chain of `struct_inject` ops, with an optional final @@ -906,9 +911,12 @@ DeletionKind BlockingAssignOp::removeBlockingUses( LogicalResult BlockingAssignOp::canonicalize(BlockingAssignOp op, PatternRewriter &rewriter) { + if (!(isa(op->getParentOp()) || + isa(op->getParentOp()))) + return failure(); if (auto refOp = op.getDst().getDefiningOp()) { auto input = refOp.getInput(); - if (isa(input.getDefiningOp()->getParentOp())) { + if (isa(input.getDefiningOp()->getParentOp())) { auto value = rewriter.create( op->getLoc(), cast(input.getType()).getNestedType(), input); auto newOp = rewriter.create( @@ -967,6 +975,9 @@ ReadOp::removeBlockingUses(const MemorySlot &slot, } LogicalResult ReadOp::canonicalize(ReadOp op, PatternRewriter &rewriter) { + if (!(isa(op->getParentOp()) || + isa(op->getParentOp()))) + return failure(); if (auto addr = op.getInput().getDefiningOp()) { auto value = addr.getInput(); op.replaceAllUsesWith(value);