diff --git a/include/circt/Dialect/Moore/MooreDialect.td b/include/circt/Dialect/Moore/MooreDialect.td index 590e59da6068..a125b7e260bd 100644 --- a/include/circt/Dialect/Moore/MooreDialect.td +++ b/include/circt/Dialect/Moore/MooreDialect.td @@ -39,7 +39,7 @@ def MooreDialect : Dialect { let useDefaultAttributePrinterParser = 1; let useDefaultTypePrinterParser = 0; let hasConstantMaterializer = 1; - let dependentDialects = ["hw::HWDialect"]; + let dependentDialects = ["hw::HWDialect", "mlir::func::FuncDialect"]; } #endif // CIRCT_DIALECT_MOORE_MOOREDIALECT diff --git a/include/circt/Dialect/Moore/MooreOps.h b/include/circt/Dialect/Moore/MooreOps.h index e9f2773eab66..6fa19c29f39e 100644 --- a/include/circt/Dialect/Moore/MooreOps.h +++ b/include/circt/Dialect/Moore/MooreOps.h @@ -17,6 +17,7 @@ #include "circt/Dialect/Moore/MooreAttributes.h" #include "circt/Dialect/Moore/MooreDialect.h" #include "circt/Dialect/Moore/MooreTypes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/RegionKindInterface.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferTypeOpInterface.h" diff --git a/include/circt/Dialect/Moore/MooreOps.td b/include/circt/Dialect/Moore/MooreOps.td index 584369c67ade..ed1121c14306 100644 --- a/include/circt/Dialect/Moore/MooreOps.td +++ b/include/circt/Dialect/Moore/MooreOps.td @@ -339,6 +339,7 @@ def BlockingAssignOp : AssignOpBase<"blocking_assign", [ Arg:$dst, UnpackedType:$src ); + let hasCanonicalizeMethod = true; } def NonBlockingAssignOp : AssignOpBase<"nonblocking_assign"> { diff --git a/lib/Dialect/Moore/CMakeLists.txt b/lib/Dialect/Moore/CMakeLists.txt index f8559fa4f2ad..941d420ec110 100644 --- a/lib/Dialect/Moore/CMakeLists.txt +++ b/lib/Dialect/Moore/CMakeLists.txt @@ -19,6 +19,7 @@ add_circt_dialect_library(CIRCTMoore CIRCTHW CIRCTSupport MLIRIR + MLIRFuncDialect MLIRInferTypeOpInterface MLIRMemorySlotInterfaces ) diff --git a/lib/Dialect/Moore/MooreDialect.cpp b/lib/Dialect/Moore/MooreDialect.cpp index 64341d36e39a..72f42265929e 100644 --- a/lib/Dialect/Moore/MooreDialect.cpp +++ b/lib/Dialect/Moore/MooreDialect.cpp @@ -12,6 +12,7 @@ #include "circt/Dialect/HW/HWDialect.h" #include "circt/Dialect/Moore/MooreOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" using namespace circt; using namespace circt::moore; diff --git a/lib/Dialect/Moore/MooreOps.cpp b/lib/Dialect/Moore/MooreOps.cpp index 609b2bf14c65..90dcc10bb0e4 100644 --- a/lib/Dialect/Moore/MooreOps.cpp +++ b/lib/Dialect/Moore/MooreOps.cpp @@ -17,7 +17,6 @@ #include "circt/Support/CustomDirectiveImpl.h" #include "mlir/IR/Builders.h" #include "llvm/ADT/SmallString.h" -#include "llvm/ADT/TypeSwitch.h" using namespace circt; using namespace circt::moore; @@ -723,7 +722,11 @@ static std::optional getStructFieldIndex(Type type, StringAttr name) { return structType.getFieldIndex(name); if (auto structType = dyn_cast(type)) return structType.getFieldIndex(name); - assert(0 && "expected StructType or UnpackedStructType"); + if (auto unionType = dyn_cast(type)) + return unionType.getFieldIndex(name); + if (auto unionType = dyn_cast(type)) + return unionType.getFieldIndex(name); + assert(0 && "expected Struct-Like Type"); return {}; } @@ -732,7 +735,10 @@ static ArrayRef getStructMembers(Type type) { return structType.getMembers(); if (auto structType = dyn_cast(type)) return structType.getMembers(); - assert(0 && "expected StructType or UnpackedStructType"); + if (auto unionType = dyn_cast(type)) + return unionType.getMembers(); + if (auto unionType = dyn_cast(type)) + return unionType.getMembers(); return {}; } @@ -927,23 +933,18 @@ LogicalResult StructInjectOp::canonicalize(StructInjectOp op, //===----------------------------------------------------------------------===// LogicalResult UnionCreateOp::verify() { - /// checks if the types of the input is exactly equal to the union field + auto type = getStructFieldType(getType(), getFieldNameAttr()); + + /// checks if the type of the input is exactly equal to the union field /// type - return TypeSwitch(getType()) - .Case([this](auto &type) { - auto members = type.getMembers(); - auto resultType = getType(); - auto fieldName = getFieldName(); - for (const auto &member : members) - if (member.name == fieldName && member.type == resultType) - return success(); - emitOpError("input type must match the union field type"); - return failure(); - }) - .Default([this](auto &) { - emitOpError("input type must be UnionType or UnpackedUnionType"); - return failure(); - }); + + if (!type) + return emitOpError() << "union field " << getFieldNameAttr() + << " which does not exist in " << getInput().getType(); + if (type != getType()) + return emitOpError() << "result type " << getType() + << " must match union field type " << type; + return success(); } //===----------------------------------------------------------------------===// @@ -951,47 +952,35 @@ LogicalResult UnionCreateOp::verify() { //===----------------------------------------------------------------------===// LogicalResult UnionExtractOp::verify() { - /// checks if the types of the input is exactly equal to the one of the - /// types of the result union fields - return TypeSwitch(getInput().getType()) - .Case([this](auto &type) { - auto members = type.getMembers(); - auto fieldName = getFieldName(); - auto resultType = getType(); - for (const auto &member : members) - if (member.name == fieldName && member.type == resultType) - return success(); - emitOpError("result type must match the union field type"); - return failure(); - }) - .Default([this](auto &) { - emitOpError("input type must be UnionType or UnpackedUnionType"); - return failure(); - }); + auto type = getStructFieldType(getInput().getType(), getFieldNameAttr()); + + /// checks if the type of the input is exactly equal to the type of the result + /// union fields + if (!type) + return emitOpError() << "union field " << getFieldNameAttr() + << " which does not exist in " << getInput().getType(); + if (type != getType()) + return emitOpError() << "result type " << getType() + << " must match union field type " << type; + return success(); } //===----------------------------------------------------------------------===// -// UnionExtractOp +// UnionExtractRefOp //===----------------------------------------------------------------------===// LogicalResult UnionExtractRefOp::verify() { - /// checks if the types of the result is exactly equal to the type of the - /// refe union field - return TypeSwitch(getInput().getType().getNestedType()) - .Case([this](auto &type) { - auto members = type.getMembers(); - auto fieldName = getFieldName(); - auto resultType = getType().getNestedType(); - for (const auto &member : members) - if (member.name == fieldName && member.type == resultType) - return success(); - emitOpError("result type must match the union field type"); - return failure(); - }) - .Default([this](auto &) { - emitOpError("input type must be UnionType or UnpackedUnionType"); - return failure(); - }); + auto type = getStructFieldType(getInput().getType().getNestedType(), + getFieldNameAttr()); + /// checks if the type of the result is exactly equal to the type of the + /// referring union field + if (!type) + return emitOpError() << "union field " << getFieldNameAttr() + << " which does not exist in " << getInput().getType(); + if (type != getType()) + return emitOpError() << "result type " << getType() + << " must match union field type " << type; + return success(); } //===----------------------------------------------------------------------===// @@ -1092,6 +1081,20 @@ DeletionKind BlockingAssignOp::removeBlockingUses( return DeletionKind::Delete; } +LogicalResult BlockingAssignOp::canonicalize(BlockingAssignOp op, + PatternRewriter &rewriter) { + if (auto refOp = op.getDst().getDefiningOp()) { + auto input = refOp.getInput(); + auto read = rewriter.create(op->getLoc(), input); + auto value = op.getSrc(); + auto inject = rewriter.create( + op->getLoc(), read, refOp.getFieldNameAttr(), value); + op->setOperands({input, inject}); + return success(); + } + return failure(); +} + //===----------------------------------------------------------------------===// // ReadOp //===----------------------------------------------------------------------===// diff --git a/test/Dialect/Moore/canonicalizers.mlir b/test/Dialect/Moore/canonicalizers.mlir index 2754fcef860c..9ed309fe7383 100644 --- a/test/Dialect/Moore/canonicalizers.mlir +++ b/test/Dialect/Moore/canonicalizers.mlir @@ -192,6 +192,19 @@ func.func @StructExtractFold2(%arg0: !moore.i17, %arg1: !moore.i42) -> (!moore.i return %1, %2 : !moore.i17, !moore.i42 } +// CHECK-LABEL: func.func @structExtractRefLower2Inject +func.func @structExtractRefLower2Inject(%arg0: !moore.ref>) -> (!moore.ref>) { + // CHECK-NEXT: [[C42:%.+]] = moore.constant 42 + // CHECK-NEXT: [[TMP1:%.+]] = moore.read %arg0 : > + // CHECK-NEXT: [[TMP2:%.+]] = moore.struct_inject [[TMP1]], "a", [[C42]] : struct<{a: i32, b: i32}>, i32 + // CHECK-NEXT: moore.blocking_assign %arg0, [[TMP2]] : struct<{a: i32, b: i32}> + // CHECK-NEXT: return %arg0 : !moore.ref> + %0 = moore.constant 42 : i32 + %1 = moore.struct_extract_ref %arg0, "a" : > -> + moore.blocking_assign %1, %0 : i32 + return %arg0 : !moore.ref> +} + // CHECK-LABEL: func.func @StructInjectFold1 func.func @StructInjectFold1(%arg0: !moore.struct<{a: i32, b: i32}>) -> (!moore.struct<{a: i32, b: i32}>) { // CHECK-NEXT: [[C42:%.+]] = moore.constant 42