diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td index d5b5d2be1c6a31..c98ffbcdeb9116 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -35,7 +35,7 @@ def PDL_ApplyNativeConstraintOp let description = [{ `pdl.apply_native_constraint` operations apply a native C++ constraint, that has been registered externally with the consumer of PDL, to a given set of - entities. + entities and optionally return a number of values.. Example: @@ -46,7 +46,8 @@ def PDL_ApplyNativeConstraintOp }]; let arguments = (ins StrAttr:$name, Variadic:$args); - let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict"; + let results = (outs Variadic:$results); + let assemblyFormat = "$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict"; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td index 47c96ac25bca29..3cc25876dc0a9e 100644 --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -88,8 +88,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> { let description = [{ `pdl_interp.apply_constraint` operations apply a generic constraint, that has been registered with the interpreter, with a given set of positional - values. On success, this operation branches to the true destination, - otherwise the false destination is taken. + values. The constraint function may return any number of results. + On success, this operation branches to the true destination, otherwise + the false destination is taken. Example: @@ -101,8 +102,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> { }]; let arguments = (ins StrAttr:$name, Variadic:$args); + let results = (outs Variadic:$results); let assemblyFormat = [{ - $name `(` $args `:` type($args) `)` attr-dict `->` successors + $name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict `->` successors }]; } diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 600ace48827346..7bc1d514b14c3d 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -1525,6 +1525,20 @@ class PDLPatternModule { std::forward(constraintFn))); } + /// Register a constraint function that produces results with PDL. A + /// constraint function with results uses the same registry as + /// rewrite functions. It may be specified as follows: + /// + /// * `LogicalResult (PatternRewriter &, PDLResultList &, + /// ArrayRef)` + /// + /// In this overload the arguments of the constraint function are passed via + /// the low-level PDLValue form, and the results are manually appended to + /// the given result list. + /// + void registerConstraintFunctionWithResults(StringRef name, + PDLRewriteFunction constraintFn); + /// Register a rewrite function with PDL. A rewrite function may be specified /// in one of two ways: /// diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp index fc0c845b69987e..ed6863eb9d8511 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -148,6 +148,10 @@ struct PatternLowering { /// A mapping between pattern operations and the corresponding configuration /// set. DenseMap *configMap; + + /// A mapping between constraint questions that refer to values created by + /// constraints and the temporary placeholder values created for them. + DenseMap, Value> substitutions; }; } // namespace @@ -364,6 +368,20 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { loc, rawTypeAttr.cast()); break; } + case Predicates::ConstraintResultPos: { + // At this point in time the corresponding pdl.ApplyNativeConstraint op has + // been deleted and the new pdl_interp.ApplyConstraint has not been created + // yet. To enable use of results created by these operations we build a + // placeholder value that will be replaced when the actual + // pdl_interp.ApplyConstraint operation is created. + auto *constrResPos = cast(pos); + Value placeholderValue = builder.create( + loc, StringAttr::get(builder.getContext(), "placeholder")); + substitutions[{constrResPos->getQuestion(), constrResPos->getIndex()}] = + placeholderValue; + value = placeholderValue; + break; + } default: llvm_unreachable("Generating unknown Position getter"); break; @@ -447,8 +465,21 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, } case Predicates::ConstraintQuestion: { auto *cstQuestion = cast(question); - builder.create(loc, cstQuestion->getName(), - args, success, failure); + auto applyConstraintOp = builder.create( + loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args, + success, failure); + // Replace the generated placeholders with the results of the constraint and + // erase them + for (auto result : llvm::enumerate(applyConstraintOp.getResults())) { + std::pair substitutionKey = { + cstQuestion, result.index()}; + // Check if there are substitutions to perform. If the result is never + // used no substitutions will have been generated. + if (substitutions.count(substitutionKey)) { + substitutions[substitutionKey].replaceAllUsesWith(result.value()); + substitutions[substitutionKey].getDefiningOp()->erase(); + } + } break; } default: diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h index 1027ed00757ce3..20304bd35dc28b 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h @@ -47,6 +47,7 @@ enum Kind : unsigned { OperandPos, OperandGroupPos, AttributePos, + ConstraintResultPos, ResultPos, ResultGroupPos, TypePos, @@ -187,6 +188,25 @@ struct AttributeLiteralPosition using PredicateBase::PredicateBase; }; +//===----------------------------------------------------------------------===// +// ConstraintPosition + +struct ConstraintQuestion; + +/// A position describing the result of a native constraint. It saves the +/// corresponding ConstraintQuestion and result index to enable referring +/// back to them +struct ConstraintPosition + : public PredicateBase, + Predicates::ConstraintResultPos> { + using PredicateBase::PredicateBase; + + ConstraintQuestion *getQuestion() const { return key.first; } + + unsigned getIndex() const { return key.second; } +}; + //===----------------------------------------------------------------------===// // ForEachPosition @@ -447,11 +467,13 @@ struct AttributeQuestion : public PredicateBase {}; -/// Apply a parameterized constraint to multiple position values. +/// Apply a parameterized constraint to multiple position values and possibly +/// produce results. struct ConstraintQuestion - : public PredicateBase>, - Predicates::ConstraintQuestion> { + : public PredicateBase< + ConstraintQuestion, Qualifier, + std::tuple, ArrayRef>, + Predicates::ConstraintQuestion> { using Base::Base; /// Return the name of the constraint. @@ -460,11 +482,15 @@ struct ConstraintQuestion /// Return the arguments of the constraint. ArrayRef getArgs() const { return std::get<1>(key); } + /// Return the result types of the constraint. + ArrayRef getResultTypes() const { return std::get<2>(key); } + /// Construct an instance with the given storage allocator. static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc, KeyTy key) { return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)), - alloc.copyInto(std::get<1>(key))}); + alloc.copyInto(std::get<1>(key)), + alloc.copyInto(std::get<2>(key))}); } }; @@ -517,6 +543,7 @@ class PredicateUniquer : public StorageUniquer { // Register the types of Positions with the uniquer. registerParametricStorageType(); registerParametricStorageType(); + registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); @@ -579,6 +606,12 @@ class PredicateBuilder { return OperationPosition::get(uniquer, p); } + // Returns a position for a new value created by a constraint. + ConstraintPosition *getConstraintPosition(ConstraintQuestion *q, + unsigned index) { + return ConstraintPosition::get(uniquer, std::make_pair(q, index)); + } + /// Returns an attribute position for an attribute of the given operation. Position *getAttribute(OperationPosition *p, StringRef name) { return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name)); @@ -664,8 +697,10 @@ class PredicateBuilder { } /// Create a predicate that applies a generic constraint. - Predicate getConstraint(StringRef name, ArrayRef pos) { - return {ConstraintQuestion::get(uniquer, std::make_tuple(name, pos)), + Predicate getConstraint(StringRef name, ArrayRef args, + ArrayRef resultTypes) { + return {ConstraintQuestion::get(uniquer, + std::make_tuple(name, args, resultTypes)), TrueAnswer::get(uniquer)}; } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 034291440ad2c3..1d04e427c0167f 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -272,8 +272,16 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, // Push the constraint to the furthest position. Position *pos = *std::max_element(allPositions.begin(), allPositions.end(), comparePosDepth); - PredicateBuilder::Predicate pred = - builder.getConstraint(op.getName(), allPositions); + ResultRange results = op.getResults(); + PredicateBuilder::Predicate pred = builder.getConstraint( + op.getName(), allPositions, SmallVector(results.getTypes())); + + // for each result register a position so it can be used later + for (auto result : llvm::enumerate(results)) { + ConstraintQuestion *q = cast(pred.first); + ConstraintPosition *pos = builder.getConstraintPosition(q, result.index()); + inputs[result.value()] = pos; + } predList.emplace_back(pos, pred); } diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 052696d5cb13a6..46b5c1e6852de6 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -204,6 +204,15 @@ void PDLPatternModule::registerConstraintFunction( constraintFunctions.try_emplace(name, std::move(constraintFn)); } +void PDLPatternModule::registerConstraintFunctionWithResults( + StringRef name, PDLRewriteFunction constraintFn) { + // TODO: Is it possible to diagnose when `name` is already registered to + // a function that is not equivalent to `rewriteFn`? + // Allow existing mappings in the case multiple patterns depend on the same + // rewrite. + registerRewriteFunction(name, std::move(constraintFn)); +} + void PDLPatternModule::registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn) { // TODO: Is it possible to diagnose when `name` is already registered to diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index 1ea4bef5402a74..516784f274cc04 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -769,10 +769,28 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) { void Generator::generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer) { - assert(constraintToMemIndex.count(op.getName()) && - "expected index for constraint function"); - writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]); + /// Constraints that should return a value have to be registered as rewrites + /// If the constraint and rewrite of similar name are registered the + /// constraint fun takes precedence + ResultRange results = op.getResults(); + if (results.size() == 0 && constraintToMemIndex.count(op.getName()) != 0) { + writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]); + } else if (results.size() > 0 && + externalRewriterToMemIndex.count(op.getName()) != 0) { + writer.append(OpCode::ApplyConstraint, + externalRewriterToMemIndex[op.getName()]); + } else { + assert(true && "expected index for constraint function, make sure it is " + "registered properly. Note that native constraints with " + "results have to be registered using " + "PDLPatternModule::registerConstraintFunctionWithResults."); + } writer.appendPDLValueList(op.getArgs()); + writer.append(ByteCodeField(results.size())); + for (Value result : results) { + // TODO: Handle result ranges + writer.append(result); + } writer.append(op.getSuccessors()); } void Generator::generate(pdl_interp::ApplyRewriteOp op, @@ -1406,7 +1424,7 @@ class ByteCodeRewriteResultList : public PDLResultList { void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); - const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; + ByteCodeField fun_idx = read(); SmallVector args; readList(args); @@ -1415,8 +1433,26 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { llvm::interleaveComma(args, llvm::dbgs()); }); - // Invoke the constraint and jump to the proper destination. - selectJump(succeeded(constraintFn(rewriter, args))); + ByteCodeField numResults = read(); + if (numResults == 0) { + const PDLConstraintFunction &constraintFn = constraintFunctions[fun_idx]; + LogicalResult rewriteResult = constraintFn(rewriter, args); + // Depending on the constraint jump to the proper destination. + selectJump(succeeded(rewriteResult)); + } else { + const PDLRewriteFunction &constraintFn = rewriteFunctions[fun_idx]; + ByteCodeRewriteResultList results(numResults); + LogicalResult rewriteResult = constraintFn(rewriter, results, args); + assert(results.getResults().size() == numResults && + "native PDL rewrite function returned unexpected number of results"); + + for (PDLValue &result : results.getResults()) { + LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); + memory[read()] = result.getAsOpaquePointer(); + } + // Depending on the constraint jump to the proper destination. + selectJump(succeeded(rewriteResult)); + } } LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index 044aa612b67fbb..67168f6177e841 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -1359,12 +1359,6 @@ FailureOr Parser::parseUserNativeConstraintOrRewriteDecl( if (failed(parseToken(Token::semicolon, "expected `;` after native declaration"))) return failure(); - // TODO: PDL should be able to support constraint results in certain - // situations, we should revise this. - if (std::is_same::value && !results.empty()) { - return emitError( - "native Constraints currently do not support returning results"); - } return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); } diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir index b94451c4a08689..14445beadef304 100644 --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -79,6 +79,34 @@ module @constraints { // ----- +// CHECK-LABEL: module @constraint_with_result +module @constraint_with_result { + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]] + // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute) + pdl.pattern : benefit(1) { + %root = operation + %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute + rewrite %root with "rewriter"(%attr : !pdl.attribute) + } +} + +// ----- + +// CHECK-LABEL: module @constraint_with_unused_result +module @constraint_with_unused_result { + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]] + // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation) + pdl.pattern : benefit(1) { + %root = operation + %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute + rewrite %root with "rewriter" + } +} + +// ----- + // CHECK-LABEL: module @inputs module @inputs { // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir index 6e6da5cce446ae..20e40deea5f863 100644 --- a/mlir/test/Dialect/PDL/ops.mlir +++ b/mlir/test/Dialect/PDL/ops.mlir @@ -134,6 +134,24 @@ pdl.pattern @apply_rewrite_with_no_results : benefit(1) { // ----- +pdl.pattern @apply_constraint_with_no_results : benefit(1) { + %root = operation + apply_native_constraint "NativeConstraint"(%root : !pdl.operation) + rewrite %root with "rewriter" +} + +// ----- + +pdl.pattern @apply_constraint_with_results : benefit(1) { + %root = operation + %attr = apply_native_constraint "NativeConstraint"(%root : !pdl.operation) : !pdl.attribute + rewrite %root { + apply_native_rewrite "NativeRewrite"(%attr : !pdl.attribute) + } +} + +// ----- + pdl.pattern @attribute_with_dict : benefit(1) { %root = operation rewrite %root { diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir index 57bec8ce370736..9d7166243f1cd5 100644 --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -72,6 +72,37 @@ module @ir attributes { test.apply_constraint_2 } { // ----- +module @patterns { + pdl_interp.func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.op" -> ^bb0, ^end + + ^bb0: + %attr = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root, %attr : !pdl.operation, !pdl.attribute) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + pdl_interp.func @success(%root : !pdl.operation, %attr : !pdl.attribute) { + %op = pdl_interp.create_operation "test.success" {"attr" = %attr} + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.apply_constraint_3 +// CHECK: "test.success"() {attr = "test.success"} +module @ir attributes { test.apply_constraint_3 } { + "test.op"() : () -> () +} + +// ----- + //===----------------------------------------------------------------------===// // pdl_interp::ApplyRewriteOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp index daa1c371f27c92..6c86c4befbb001 100644 --- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp +++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp @@ -30,6 +30,15 @@ static LogicalResult customMultiEntityVariadicConstraint( return success(); } +// Custom constraint that returns a value +static LogicalResult customResultConstraint(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef args) { + StringAttr customAttr = rewriter.getStringAttr("test.success"); + results.push_back(customAttr); + return success(); +} + // Custom creator invoked from PDL. static Operation *customCreate(PatternRewriter &rewriter, Operation *op) { return rewriter.create(OperationState(op->getLoc(), "test.success")); @@ -102,6 +111,9 @@ struct TestPDLByteCodePass customMultiEntityConstraint); pdlPattern.registerConstraintFunction("multi_entity_var_constraint", customMultiEntityVariadicConstraint); + pdlPattern.registerRewriteFunction("check_op_and_get_attr_constr", + customResultConstraint); + pdlPattern.registerRewriteFunction("creator", customCreate); pdlPattern.registerRewriteFunction("var_creator", customVariadicResultCreate); diff --git a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll index 18877b4bcc50ec..48747d3fa2e681 100644 --- a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll @@ -158,8 +158,3 @@ Pattern { // CHECK: expected `;` after native declaration Constraint Foo() [{}] - -// ----- - -// CHECK: native Constraints currently do not support returning results -Constraint Foo() -> Op; diff --git a/mlir/test/mlir-pdll/Parser/constraint.pdll b/mlir/test/mlir-pdll/Parser/constraint.pdll index 1c0a015ab4a7b4..e2a52ff130cc84 100644 --- a/mlir/test/mlir-pdll/Parser/constraint.pdll +++ b/mlir/test/mlir-pdll/Parser/constraint.pdll @@ -12,6 +12,14 @@ Constraint Foo() [{ /* Native Code */ }]; // ----- +// Test that native constraints support returning results. + +// CHECK: Module +// CHECK: `-UserConstraintDecl {{.*}} Name ResultType +Constraint Foo() -> Attr; + +// ----- + // CHECK: Module // CHECK: `-UserConstraintDecl {{.*}} Name ResultType // CHECK: `Inputs`