diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 4481e56615b8bf1..5c82c041c62eeb7 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -206,7 +206,9 @@ def AnyRegion : Region, "any region">; // A region with the given number of blocks. class SizedRegion : Region< CPred<"::llvm::hasNItems($_self, " # numBlocks # ")">, - "region with " # numBlocks # " blocks">; + "region with " # numBlocks # " blocks"> { + int blocks = numBlocks; +} // A region with at least the given number of blocks. class MinSizedRegion : Region< diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td index 1ba84a5d3683d41..7f4815d865b60be 100644 --- a/mlir/test/tblgen-to-irdl/TestDialect.td +++ b/mlir/test/tblgen-to-irdl/TestDialect.td @@ -106,6 +106,17 @@ def Test_OrOp : Test_Op<"or"> { // CHECK-NEXT: irdl.operands(%[[v3]]) // CHECK-NEXT: } +// Check regions are converted correctly. +def Test_RegionsOp : Test_Op<"regions"> { + let regions = (region AnyRegion:$any_region, + SizedRegion<1>:$single_block_region); +} +// CHECK-LABEL: irdl.operation @regions { +// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.region +// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.region with size 1 +// CHECK-NEXT: irdl.regions(%[[v0]], %[[v1]]) +// CHECK-NEXT: } + // Check that various types are converted correctly. def Test_TypesOp : Test_Op<"types"> { let arguments = (ins I32:$a, diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp index d0a3552fb123daf..79ff919f634b02d 100644 --- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp +++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp @@ -338,6 +338,29 @@ Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) { return createPredicate(builder, constraint.getPredicate()); } +Value createRegionConstraint(OpBuilder &builder, tblgen::Region constraint) { + MLIRContext *ctx = builder.getContext(); + const Record &predRec = constraint.getDef(); + + if (predRec.getName() == "AnyRegion") { + ValueRange entryBlockArgs = {}; + auto op = + builder.create(UnknownLoc::get(ctx), entryBlockArgs); + return op.getResult(); + } + + if (predRec.isSubClassOf("SizedRegion")) { + ValueRange entryBlockArgs = {}; + auto ty = IntegerType::get(ctx, 32); + auto op = builder.create( + UnknownLoc::get(ctx), entryBlockArgs, + IntegerAttr::get(ty, predRec.getValueAsInt("blocks"))); + return op.getResult(); + } + + return createPredicate(builder, constraint.getPredicate()); +} + /// Returns the name of the operation without the dialect prefix. static StringRef getOperatorName(tblgen::Operator &tblgenOp) { StringRef opName = tblgenOp.getDef().getValueAsString("opName"); @@ -404,6 +427,12 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder, attrNames.push_back(StringAttr::get(ctx, namedAttr.name)); } + SmallVector regions; + for (auto namedRegion : tblgenOp.getRegions()) { + regions.push_back( + createRegionConstraint(consBuilder, namedRegion.constraint)); + } + // Create the operands and results operations. if (!operands.empty()) consBuilder.create(UnknownLoc::get(ctx), operands, @@ -414,6 +443,8 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder, if (!attributes.empty()) consBuilder.create(UnknownLoc::get(ctx), attributes, ArrayAttr::get(ctx, attrNames)); + if (!regions.empty()) + consBuilder.create(UnknownLoc::get(ctx), regions); return op; }