From dbcc0d0a94ee0a33aca6555991595e9e3ea91733 Mon Sep 17 00:00:00 2001 From: Martin Paul Luecke Date: Mon, 17 Apr 2023 12:28:47 +0100 Subject: [PATCH 01/12] Add PDLL parsing support for native constraints with results --- mlir/lib/Tools/PDLL/Parser/Parser.cpp | 6 ------ mlir/test/mlir-pdll/Parser/constraint-failure.pdll | 5 ----- mlir/test/mlir-pdll/Parser/constraint.pdll | 8 ++++++++ 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index 3af285e5152bfc..c1a1abb437f4b8 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -1356,12 +1356,6 @@ FailureOr Parser::parseUserNativeConstraintOrRewriteDecl( if (failed(parseToken(Token::semicolon, "expected `;` after native declaration"))) return failure(); - // TODO: PDL should be able to support constraint results in certain - // situations, we should revise this. - if (std::is_same::value && !results.empty()) { - return emitError( - "native Constraints currently do not support returning results"); - } return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); } diff --git a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll index 18877b4bcc50ec..48747d3fa2e681 100644 --- a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll @@ -158,8 +158,3 @@ Pattern { // CHECK: expected `;` after native declaration Constraint Foo() [{}] - -// ----- - -// CHECK: native Constraints currently do not support returning results -Constraint Foo() -> Op; diff --git a/mlir/test/mlir-pdll/Parser/constraint.pdll b/mlir/test/mlir-pdll/Parser/constraint.pdll index 1c0a015ab4a7b4..e2a52ff130cc84 100644 --- a/mlir/test/mlir-pdll/Parser/constraint.pdll +++ b/mlir/test/mlir-pdll/Parser/constraint.pdll @@ -12,6 +12,14 @@ Constraint Foo() [{ /* Native Code */ }]; // ----- +// Test that native constraints support returning results. + +// CHECK: Module +// CHECK: `-UserConstraintDecl {{.*}} Name ResultType +Constraint Foo() -> Attr; + +// ----- + // CHECK: Module // CHECK: `-UserConstraintDecl {{.*}} Name ResultType // CHECK: `Inputs` From 752a4c5b9f86ea34130646a1df61aff49e38ded7 Mon Sep 17 00:00:00 2001 From: Martin Paul Luecke Date: Mon, 17 Apr 2023 12:58:00 +0100 Subject: [PATCH 02/12] Add PDL support for representing native constraints with results --- mlir/include/mlir/Dialect/PDL/IR/PDLOps.td | 5 +++-- mlir/test/Dialect/PDL/ops.mlir | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td index c85687e199b745..047c9027056f36 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -35,7 +35,7 @@ def PDL_ApplyNativeConstraintOp let description = [{ `pdl.apply_native_constraint` operations apply a native C++ constraint, that has been registered externally with the consumer of PDL, to a given set of - entities. + entities and optionally return a number of values.. Example: @@ -46,7 +46,8 @@ def PDL_ApplyNativeConstraintOp }]; let arguments = (ins StrAttr:$name, Variadic:$args); - let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict"; + let results = (outs Variadic:$results); + let assemblyFormat = "$name `(` $args `:` type($args) `)` (`->` `(` type($results)^ `)`)? attr-dict"; let hasVerifier = 1; } diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir index c15f53d4500de9..30ac11d98c1dab 100644 --- a/mlir/test/Dialect/PDL/ops.mlir +++ b/mlir/test/Dialect/PDL/ops.mlir @@ -134,6 +134,24 @@ pdl.pattern @apply_rewrite_with_no_results : benefit(1) { // ----- +pdl.pattern @apply_constraint_with_no_results : benefit(1) { + %root = operation + apply_native_constraint "NativeConstraint"(%root : !pdl.operation) + rewrite %root with "rewriter" +} + +// ----- + +pdl.pattern @apply_constraint_with_results : benefit(1) { + %root = operation + %attr = apply_native_constraint "NativeConstraint"(%root : !pdl.operation) -> (!pdl.attribute) + rewrite %root { + apply_native_rewrite "NativeRewrite"(%attr : !pdl.attribute) + } +} + +// ----- + pdl.pattern @attribute_with_dict : benefit(1) { %root = operation rewrite %root { From 4e7b37c38a66cb470ee30a26d4fc295af8a9f372 Mon Sep 17 00:00:00 2001 From: Martin Paul Luecke Date: Mon, 17 Apr 2023 14:46:58 +0100 Subject: [PATCH 03/12] Add PDL_interp support for representing native constraints with results --- .../mlir/Dialect/PDLInterp/IR/PDLInterpOps.td | 8 +++-- mlir/test/Rewrite/pdl-bytecode.mlir | 31 +++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td index 96d631bd474a49..0f4292bf055f98 100644 --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -88,8 +88,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> { let description = [{ `pdl_interp.apply_constraint` operations apply a generic constraint, that has been registered with the interpreter, with a given set of positional - values. On success, this operation branches to the true destination, - otherwise the false destination is taken. + values. The constraint function may return any number of results. + On success, this operation branches to the true destination, otherwise + the false destination is taken. Example: @@ -101,8 +102,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> { }]; let arguments = (ins StrAttr:$name, Variadic:$args); + let results = (outs Variadic:$results); let assemblyFormat = [{ - $name `(` $args `:` type($args) `)` attr-dict `->` successors + $name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict `->` successors }]; } diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir index 565874fe0ebcee..14fd008d1ff424 100644 --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -72,6 +72,37 @@ module @ir attributes { test.apply_constraint_2 } { // ----- +module @patterns { + pdl_interp.func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.op" -> ^bb0, ^end + + ^bb0: + %attr = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root, %attr : !pdl.operation, !pdl.attribute) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + pdl_interp.func @success(%root : !pdl.operation, %attr : !pdl.attribute) { + %op = pdl_interp.create_operation "test.success" {"attr" = %attr} + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.apply_constraint_3 +// CHECK: "test.success"() {attr = "test.success"} +module @ir attributes { test.apply_constraint_3 } { + "test.op"() : () -> () +} + +// ----- + //===----------------------------------------------------------------------===// // pdl_interp::ApplyRewriteOp //===----------------------------------------------------------------------===// From c082d54d4d468d7c752fdb895e9a8abf35b5a471 Mon Sep 17 00:00:00 2001 From: Martin Paul Luecke Date: Mon, 17 Apr 2023 14:48:51 +0100 Subject: [PATCH 04/12] Add test for pdl_interp.apply_constraint with results --- mlir/test/lib/Rewrite/TestPDLByteCode.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp index daa1c371f27c92..6c86c4befbb001 100644 --- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp +++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp @@ -30,6 +30,15 @@ static LogicalResult customMultiEntityVariadicConstraint( return success(); } +// Custom constraint that returns a value +static LogicalResult customResultConstraint(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef args) { + StringAttr customAttr = rewriter.getStringAttr("test.success"); + results.push_back(customAttr); + return success(); +} + // Custom creator invoked from PDL. static Operation *customCreate(PatternRewriter &rewriter, Operation *op) { return rewriter.create(OperationState(op->getLoc(), "test.success")); @@ -102,6 +111,9 @@ struct TestPDLByteCodePass customMultiEntityConstraint); pdlPattern.registerConstraintFunction("multi_entity_var_constraint", customMultiEntityVariadicConstraint); + pdlPattern.registerRewriteFunction("check_op_and_get_attr_constr", + customResultConstraint); + pdlPattern.registerRewriteFunction("creator", customCreate); pdlPattern.registerRewriteFunction("var_creator", customVariadicResultCreate); From a238ae9a0920401b33d11b12ce4b75482701d23e Mon Sep 17 00:00:00 2001 From: Martin Paul Luecke Date: Tue, 18 Apr 2023 07:20:24 +0100 Subject: [PATCH 05/12] change pdl.apply_native_constraint printing format --- mlir/include/mlir/Dialect/PDL/IR/PDLOps.td | 2 +- .../PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir | 14 ++++++++++++++ mlir/test/Dialect/PDL/ops.mlir | 2 +- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td index 047c9027056f36..d12204b991a268 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -47,7 +47,7 @@ def PDL_ApplyNativeConstraintOp let arguments = (ins StrAttr:$name, Variadic:$args); let results = (outs Variadic:$results); - let assemblyFormat = "$name `(` $args `:` type($args) `)` (`->` `(` type($results)^ `)`)? attr-dict"; + let assemblyFormat = "$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict"; let hasVerifier = 1; } diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir index b94451c4a08689..423d9183e222a4 100644 --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -79,6 +79,20 @@ module @constraints { // ----- +// CHECK-LABEL: module @constraint_with_result +module @constraint_with_result { + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]] + // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute) + pdl.pattern : benefit(1) { + %root = operation + %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute + rewrite %root with "rewriter"(%attr : !pdl.attribute) + } +} + +// ----- + // CHECK-LABEL: module @inputs module @inputs { // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir index 30ac11d98c1dab..0f6e98fe591b43 100644 --- a/mlir/test/Dialect/PDL/ops.mlir +++ b/mlir/test/Dialect/PDL/ops.mlir @@ -144,7 +144,7 @@ pdl.pattern @apply_constraint_with_no_results : benefit(1) { pdl.pattern @apply_constraint_with_results : benefit(1) { %root = operation - %attr = apply_native_constraint "NativeConstraint"(%root : !pdl.operation) -> (!pdl.attribute) + %attr = apply_native_constraint "NativeConstraint"(%root : !pdl.operation) : !pdl.attribute rewrite %root { apply_native_rewrite "NativeRewrite"(%attr : !pdl.attribute) } From ca0d7d7b8c557083cc82f29809e7ec670f07fc8d Mon Sep 17 00:00:00 2001 From: Martin Paul Luecke Date: Tue, 18 Apr 2023 07:56:56 +0100 Subject: [PATCH 06/12] PDLToPDLInterpPass: Add support for native constraints with results --- .../PDLToPDLInterp/PDLToPDLInterp.cpp | 34 ++++++++++++- .../lib/Conversion/PDLToPDLInterp/Predicate.h | 49 ++++++++++++++++--- .../PDLToPDLInterp/PredicateTree.cpp | 12 ++++- 3 files changed, 84 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp index fdc95ab7a820af..5b99894e91a9b3 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -148,6 +148,10 @@ struct PatternLowering { /// A mapping between pattern operations and the corresponding configuration /// set. DenseMap *configMap; + + /// A mapping between constraint questions that refer to values created by + /// constraints and the temporary placeholder values created for them. + DenseMap, Value> substitutions; }; } // namespace @@ -364,6 +368,20 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { loc, rawTypeAttr.cast()); break; } + case Predicates::ConstraintResultPos: { + // At this point in time the corresponding pdl.ApplyNativeConstraint op has + // been deleted and the new pdl_interp.ApplyConstraint has not been created + // yet. To enable use of results created by these operations we build a + // placeholder value that will be replaced when the actual + // pdl_interp.ApplyConstraint operation is created. + auto *constrResPos = cast(pos); + Value placeholderValue = builder.create( + loc, StringAttr::get(builder.getContext(), "placeholder")); + substitutions[{constrResPos->getQuestion(), constrResPos->getIndex()}] = + placeholderValue; + value = placeholderValue; + break; + } default: llvm_unreachable("Generating unknown Position getter"); break; @@ -447,8 +465,20 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, } case Predicates::ConstraintQuestion: { auto *cstQuestion = cast(question); - builder.create(loc, cstQuestion->getName(), - args, success, failure); + auto applyConstraintOp = builder.create( + loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args, + success, failure); + // Replace the generated placeholders with the results of the constraint and + // erase them + for (auto result : llvm::enumerate(applyConstraintOp.getResults())) { + std::pair substitutionKey = { + cstQuestion, result.index()}; + assert( + substitutions.count(substitutionKey) && + "expected a placeholder value for a native constraint with results"); + substitutions[substitutionKey].replaceAllUsesWith(result.value()); + substitutions[substitutionKey].getDefiningOp()->erase(); + } break; } default: diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h index 81a11529b97c29..63a2961c3a595b 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h @@ -47,6 +47,7 @@ enum Kind : unsigned { OperandPos, OperandGroupPos, AttributePos, + ConstraintResultPos, ResultPos, ResultGroupPos, TypePos, @@ -187,6 +188,25 @@ struct AttributeLiteralPosition using PredicateBase::PredicateBase; }; +//===----------------------------------------------------------------------===// +// ConstraintPosition + +struct ConstraintQuestion; + +/// A position describing the result of a native constraint. It saves the +/// corresponding ConstraintQuestion and result index to enable referring +/// back to them +struct ConstraintPosition + : public PredicateBase, + Predicates::ConstraintResultPos> { + using PredicateBase::PredicateBase; + + ConstraintQuestion *getQuestion() const { return key.first; } + + unsigned getIndex() const { return key.second; } +}; + //===----------------------------------------------------------------------===// // ForEachPosition @@ -443,11 +463,13 @@ struct AttributeQuestion : public PredicateBase {}; -/// Apply a parameterized constraint to multiple position values. +/// Apply a parameterized constraint to multiple position values and possibly +/// produce results. struct ConstraintQuestion - : public PredicateBase>, - Predicates::ConstraintQuestion> { + : public PredicateBase< + ConstraintQuestion, Qualifier, + std::tuple, ArrayRef>, + Predicates::ConstraintQuestion> { using Base::Base; /// Return the name of the constraint. @@ -456,11 +478,15 @@ struct ConstraintQuestion /// Return the arguments of the constraint. ArrayRef getArgs() const { return std::get<1>(key); } + /// Return the result types of the constraint. + ArrayRef getResultTypes() const { return std::get<2>(key); } + /// Construct an instance with the given storage allocator. static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc, KeyTy key) { return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)), - alloc.copyInto(std::get<1>(key))}); + alloc.copyInto(std::get<1>(key)), + alloc.copyInto(std::get<2>(key))}); } }; @@ -513,6 +539,7 @@ class PredicateUniquer : public StorageUniquer { // Register the types of Positions with the uniquer. registerParametricStorageType(); registerParametricStorageType(); + registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); @@ -575,6 +602,12 @@ class PredicateBuilder { return OperationPosition::get(uniquer, p); } + // Returns a position for a new value created by a constraint. + ConstraintPosition *getConstraintPosition(ConstraintQuestion *q, + unsigned index) { + return ConstraintPosition::get(uniquer, std::make_pair(q, index)); + } + /// Returns an attribute position for an attribute of the given operation. Position *getAttribute(OperationPosition *p, StringRef name) { return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name)); @@ -660,8 +693,10 @@ class PredicateBuilder { } /// Create a predicate that applies a generic constraint. - Predicate getConstraint(StringRef name, ArrayRef pos) { - return {ConstraintQuestion::get(uniquer, std::make_tuple(name, pos)), + Predicate getConstraint(StringRef name, ArrayRef args, + ArrayRef resultTypes) { + return {ConstraintQuestion::get(uniquer, + std::make_tuple(name, args, resultTypes)), TrueAnswer::get(uniquer)}; } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 422182e9702421..68b70f8d575e29 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -272,8 +272,16 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, // Push the constraint to the furthest position. Position *pos = *std::max_element(allPositions.begin(), allPositions.end(), comparePosDepth); - PredicateBuilder::Predicate pred = - builder.getConstraint(op.getName(), allPositions); + ResultRange results = op.getResults(); + PredicateBuilder::Predicate pred = builder.getConstraint( + op.getName(), allPositions, SmallVector(results.getTypes())); + + // for each result register a position so it can be used later + for (auto result : llvm::enumerate(results)) { + ConstraintQuestion *q = cast(pred.first); + ConstraintPosition *pos = builder.getConstraintPosition(q, result.index()); + inputs[result.value()] = pos; + } predList.emplace_back(pos, pred); } From bea4b5736caaf1a1807dea42f42001faede189f9 Mon Sep 17 00:00:00 2001 From: Martin Paul Luecke Date: Tue, 18 Apr 2023 08:25:49 +0100 Subject: [PATCH 07/12] PDL Bytecode generator + interpreter: Added support for constraints with results --- mlir/lib/Rewrite/ByteCode.cpp | 47 ++++++++++++++++++++++++++++++----- 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index b4a57f35370167..3264bbf1d9fd03 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -768,10 +768,27 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) { void Generator::generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer) { - assert(constraintToMemIndex.count(op.getName()) && - "expected index for constraint function"); - writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]); + /// Constraints that should return a value have to be registered as rewrites + /// If the constraint and rewrite of similar name are registered the + /// constraint fun takes precedence + ResultRange results = op.getResults(); + if (results.size() == 0 && constraintToMemIndex.count(op.getName()) != 0) { + writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]); + } else if (results.size() > 0 && + externalRewriterToMemIndex.count(op.getName()) != 0) { + writer.append(OpCode::ApplyConstraint, + externalRewriterToMemIndex[op.getName()]); + } else { + assert(true && "expected index for constraint function, make sure it is " + "registered properly. Note that native constraints with " + "results have to be registered as native rewriters."); + } writer.appendPDLValueList(op.getArgs()); + writer.append(ByteCodeField(results.size())); + for (Value result : results) { + // TODO: Handle result ranges + writer.append(result); + } writer.append(op.getSuccessors()); } void Generator::generate(pdl_interp::ApplyRewriteOp op, @@ -1405,7 +1422,7 @@ class ByteCodeRewriteResultList : public PDLResultList { void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); - const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; + ByteCodeField fun_idx = read(); SmallVector args; readList(args); @@ -1414,8 +1431,26 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { llvm::interleaveComma(args, llvm::dbgs()); }); - // Invoke the constraint and jump to the proper destination. - selectJump(succeeded(constraintFn(rewriter, args))); + ByteCodeField numResults = read(); + if (numResults == 0) { + const PDLConstraintFunction &constraintFn = constraintFunctions[fun_idx]; + LogicalResult rewriteResult = constraintFn(rewriter, args); + // Depending on the constraint jump to the proper destination. + selectJump(succeeded(rewriteResult)); + } else { + const PDLRewriteFunction &constraintFn = rewriteFunctions[fun_idx]; + ByteCodeRewriteResultList results(numResults); + LogicalResult rewriteResult = constraintFn(rewriter, results, args); + assert(results.getResults().size() == numResults && + "native PDL rewrite function returned unexpected number of results"); + + for (PDLValue &result : results.getResults()) { + LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); + memory[read()] = result.getAsOpaquePointer(); + } + // Depending on the constraint jump to the proper destination. + selectJump(succeeded(rewriteResult)); + } } LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { From 5b4fa7b6b6743033e4b570692c694c831c43353f Mon Sep 17 00:00:00 2001 From: Martin Paul Luecke Date: Thu, 20 Apr 2023 10:42:16 +0100 Subject: [PATCH 08/12] Fix for constraints with unused results --- .../Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp | 11 ++++++----- .../PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp index 5b99894e91a9b3..2c08f3483454ac 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -473,11 +473,12 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, for (auto result : llvm::enumerate(applyConstraintOp.getResults())) { std::pair substitutionKey = { cstQuestion, result.index()}; - assert( - substitutions.count(substitutionKey) && - "expected a placeholder value for a native constraint with results"); - substitutions[substitutionKey].replaceAllUsesWith(result.value()); - substitutions[substitutionKey].getDefiningOp()->erase(); + // Check if there are substitutions to perform. If the result is never + // used no substitutions will have been generated. + if (substitutions.count(substitutionKey)) { + substitutions[substitutionKey].replaceAllUsesWith(result.value()); + substitutions[substitutionKey].getDefiningOp()->erase(); + } } break; } diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir index 423d9183e222a4..14445beadef304 100644 --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -93,6 +93,20 @@ module @constraint_with_result { // ----- +// CHECK-LABEL: module @constraint_with_unused_result +module @constraint_with_unused_result { + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]] + // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation) + pdl.pattern : benefit(1) { + %root = operation + %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute + rewrite %root with "rewriter" + } +} + +// ----- + // CHECK-LABEL: module @inputs module @inputs { // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) From 050190d2f522132511633f1f0f5227ac9e60d25e Mon Sep 17 00:00:00 2001 From: Martin Paul Luecke Date: Tue, 25 Apr 2023 11:06:26 +0100 Subject: [PATCH 09/12] Add registry specific to native constraints with results --- mlir/include/mlir/IR/PatternMatch.h | 27 +++++++++++++++++++++++++++ mlir/lib/IR/PatternMatch.cpp | 9 +++++++++ 2 files changed, 36 insertions(+) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index e257b67ad9d8ef..dca6bc1a6b3701 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -1457,6 +1457,33 @@ class PDLPatternModule { std::forward(constraintFn))); } + /// Register a constraint function that produces results with PDL. A + /// constraint function with results uses the same type and registry as + /// rewrite functions. It may be specified in one of two ways: + /// + /// * `void (PatternRewriter &, PDLResultList &, ArrayRef)` + /// + /// In this overload the arguments of the constraint function are passed via + /// the low-level PDLValue form, and the results are manually appended to + /// the given result list. + /// + /// * `ResultT (PatternRewriter &, ValueTs... values)` + /// + /// In this form the arguments and result of the constraint function are + /// passed via the expected high level C++ type. In this form, the framework + /// will automatically unwrap the PDLValues arguments and convert them to + /// the expected ValueTs. It will also automatically handle the processing + /// and packaging of the result value to the result list. For more + /// information see the registering of rewrite functions below. + void registerConstraintFunctionWithResults(StringRef name, + PDLRewriteFunction constraintFn); + template + void registerConstraintFunctionWithResults(StringRef name, + RewriteFnT &&constraintFn) { + registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn( + std::forward(constraintFn))); + } + /// Register a rewrite function with PDL. A rewrite function may be specified /// in one of two ways: /// diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index d2de65e7694bab..588ece6c506b2d 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -204,6 +204,15 @@ void PDLPatternModule::registerConstraintFunction( constraintFunctions.try_emplace(name, std::move(constraintFn)); } +void PDLPatternModule::registerConstraintFunctionWithResults( + StringRef name, PDLRewriteFunction constraintFn) { + // TODO: Is it possible to diagnose when `name` is already registered to + // a function that is not equivalent to `rewriteFn`? + // Allow existing mappings in the case multiple patterns depend on the same + // rewrite. + registerRewriteFunction(name, std::move(constraintFn)); +} + void PDLPatternModule::registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn) { // TODO: Is it possible to diagnose when `name` is already registered to From 8f3c8810f949612cc4d2e98ad4cc1ecfcc0a79ab Mon Sep 17 00:00:00 2001 From: Martin Paul Luecke Date: Wed, 26 Apr 2023 10:26:49 +0100 Subject: [PATCH 10/12] Adjust assert message for unregistered constraints --- mlir/lib/Rewrite/ByteCode.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index 3264bbf1d9fd03..03024659fafef3 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -781,7 +781,8 @@ void Generator::generate(pdl_interp::ApplyConstraintOp op, } else { assert(true && "expected index for constraint function, make sure it is " "registered properly. Note that native constraints with " - "results have to be registered as native rewriters."); + "results have to be registered using " + "PDLPatternModule::registerConstraintFunctionWithResults."); } writer.appendPDLValueList(op.getArgs()); writer.append(ByteCodeField(results.size())); From 9b3272715b38cef34078916d95a24ade2cead23e Mon Sep 17 00:00:00 2001 From: Martin Paul Luecke Date: Wed, 26 Apr 2023 10:37:23 +0100 Subject: [PATCH 11/12] constrain the allowed type of constraint functions with results --- mlir/include/mlir/IR/PatternMatch.h | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index dca6bc1a6b3701..86d74ad277f329 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -1459,30 +1459,17 @@ class PDLPatternModule { /// Register a constraint function that produces results with PDL. A /// constraint function with results uses the same type and registry as - /// rewrite functions. It may be specified in one of two ways: + /// rewrite functions. It may be specified as follows: /// - /// * `void (PatternRewriter &, PDLResultList &, ArrayRef)` + /// * `LogicalResult (PatternRewriter &, PDLResultList &, + /// ArrayRef)` /// /// In this overload the arguments of the constraint function are passed via /// the low-level PDLValue form, and the results are manually appended to /// the given result list. /// - /// * `ResultT (PatternRewriter &, ValueTs... values)` - /// - /// In this form the arguments and result of the constraint function are - /// passed via the expected high level C++ type. In this form, the framework - /// will automatically unwrap the PDLValues arguments and convert them to - /// the expected ValueTs. It will also automatically handle the processing - /// and packaging of the result value to the result list. For more - /// information see the registering of rewrite functions below. void registerConstraintFunctionWithResults(StringRef name, PDLRewriteFunction constraintFn); - template - void registerConstraintFunctionWithResults(StringRef name, - RewriteFnT &&constraintFn) { - registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn( - std::forward(constraintFn))); - } /// Register a rewrite function with PDL. A rewrite function may be specified /// in one of two ways: From b98b75e149dd4f2113108a5765d580927238b85d Mon Sep 17 00:00:00 2001 From: Martin Paul Luecke Date: Wed, 26 Apr 2023 10:39:07 +0100 Subject: [PATCH 12/12] adjust comments --- mlir/include/mlir/IR/PatternMatch.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 86d74ad277f329..b99313a44efa1d 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -1458,7 +1458,7 @@ class PDLPatternModule { } /// Register a constraint function that produces results with PDL. A - /// constraint function with results uses the same type and registry as + /// constraint function with results uses the same registry as /// rewrite functions. It may be specified as follows: /// /// * `LogicalResult (PatternRewriter &, PDLResultList &,