Skip to content

Commit

Permalink
FXML.1923: PDLL support for native constraints with attribute results (
Browse files Browse the repository at this point in the history
  • Loading branch information
martin-luecke authored and Ferdinand Lemaire committed May 4, 2023
1 parent 4325a06 commit bd44b77
Show file tree
Hide file tree
Showing 15 changed files with 255 additions and 33 deletions.
5 changes: 3 additions & 2 deletions mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def PDL_ApplyNativeConstraintOp
let description = [{
`pdl.apply_native_constraint` operations apply a native C++ constraint, that
has been registered externally with the consumer of PDL, to a given set of
entities.
entities and optionally return a number of values..

Example:

Expand All @@ -46,7 +46,8 @@ def PDL_ApplyNativeConstraintOp
}];

let arguments = (ins StrAttr:$name, Variadic<PDL_AnyType>:$args);
let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
let results = (outs Variadic<PDL_AnyType>:$results);
let assemblyFormat = "$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict";
let hasVerifier = 1;
}

Expand Down
8 changes: 5 additions & 3 deletions mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
let description = [{
`pdl_interp.apply_constraint` operations apply a generic constraint, that
has been registered with the interpreter, with a given set of positional
values. On success, this operation branches to the true destination,
otherwise the false destination is taken.
values. The constraint function may return any number of results.
On success, this operation branches to the true destination, otherwise
the false destination is taken.

Example:

Expand All @@ -101,8 +102,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
}];

let arguments = (ins StrAttr:$name, Variadic<PDL_AnyType>:$args);
let results = (outs Variadic<PDL_AnyType>:$results);
let assemblyFormat = [{
$name `(` $args `:` type($args) `)` attr-dict `->` successors
$name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict `->` successors
}];
}

Expand Down
14 changes: 14 additions & 0 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,20 @@ class PDLPatternModule {
std::forward<ConstraintFnT>(constraintFn)));
}

/// Register a constraint function that produces results with PDL. A
/// constraint function with results uses the same registry as
/// rewrite functions. It may be specified as follows:
///
/// * `LogicalResult (PatternRewriter &, PDLResultList &,
/// ArrayRef<PDLValue>)`
///
/// In this overload the arguments of the constraint function are passed via
/// the low-level PDLValue form, and the results are manually appended to
/// the given result list.
///
void registerConstraintFunctionWithResults(StringRef name,
PDLRewriteFunction constraintFn);

/// Register a rewrite function with PDL. A rewrite function may be specified
/// in one of two ways:
///
Expand Down
35 changes: 33 additions & 2 deletions mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ struct PatternLowering {
/// A mapping between pattern operations and the corresponding configuration
/// set.
DenseMap<Operation *, PDLPatternConfigSet *> *configMap;

/// A mapping between constraint questions that refer to values created by
/// constraints and the temporary placeholder values created for them.
DenseMap<std::pair<ConstraintQuestion *, unsigned>, Value> substitutions;
};
} // namespace

Expand Down Expand Up @@ -364,6 +368,20 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
loc, rawTypeAttr.cast<ArrayAttr>());
break;
}
case Predicates::ConstraintResultPos: {
// At this point in time the corresponding pdl.ApplyNativeConstraint op has
// been deleted and the new pdl_interp.ApplyConstraint has not been created
// yet. To enable use of results created by these operations we build a
// placeholder value that will be replaced when the actual
// pdl_interp.ApplyConstraint operation is created.
auto *constrResPos = cast<ConstraintPosition>(pos);
Value placeholderValue = builder.create<pdl_interp::CreateAttributeOp>(
loc, StringAttr::get(builder.getContext(), "placeholder"));
substitutions[{constrResPos->getQuestion(), constrResPos->getIndex()}] =
placeholderValue;
value = placeholderValue;
break;
}
default:
llvm_unreachable("Generating unknown Position getter");
break;
Expand Down Expand Up @@ -447,8 +465,21 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
}
case Predicates::ConstraintQuestion: {
auto *cstQuestion = cast<ConstraintQuestion>(question);
builder.create<pdl_interp::ApplyConstraintOp>(loc, cstQuestion->getName(),
args, success, failure);
auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
success, failure);
// Replace the generated placeholders with the results of the constraint and
// erase them
for (auto result : llvm::enumerate(applyConstraintOp.getResults())) {
std::pair<ConstraintQuestion *, unsigned> substitutionKey = {
cstQuestion, result.index()};
// Check if there are substitutions to perform. If the result is never
// used no substitutions will have been generated.
if (substitutions.count(substitutionKey)) {
substitutions[substitutionKey].replaceAllUsesWith(result.value());
substitutions[substitutionKey].getDefiningOp()->erase();
}
}
break;
}
default:
Expand Down
49 changes: 42 additions & 7 deletions mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ enum Kind : unsigned {
OperandPos,
OperandGroupPos,
AttributePos,
ConstraintResultPos,
ResultPos,
ResultGroupPos,
TypePos,
Expand Down Expand Up @@ -187,6 +188,25 @@ struct AttributeLiteralPosition
using PredicateBase::PredicateBase;
};

//===----------------------------------------------------------------------===//
// ConstraintPosition

struct ConstraintQuestion;

/// A position describing the result of a native constraint. It saves the
/// corresponding ConstraintQuestion and result index to enable referring
/// back to them
struct ConstraintPosition
: public PredicateBase<ConstraintPosition, Position,
std::pair<ConstraintQuestion *, unsigned>,
Predicates::ConstraintResultPos> {
using PredicateBase::PredicateBase;

ConstraintQuestion *getQuestion() const { return key.first; }

unsigned getIndex() const { return key.second; }
};

//===----------------------------------------------------------------------===//
// ForEachPosition

Expand Down Expand Up @@ -447,11 +467,13 @@ struct AttributeQuestion
: public PredicateBase<AttributeQuestion, Qualifier, void,
Predicates::AttributeQuestion> {};

/// Apply a parameterized constraint to multiple position values.
/// Apply a parameterized constraint to multiple position values and possibly
/// produce results.
struct ConstraintQuestion
: public PredicateBase<ConstraintQuestion, Qualifier,
std::tuple<StringRef, ArrayRef<Position *>>,
Predicates::ConstraintQuestion> {
: public PredicateBase<
ConstraintQuestion, Qualifier,
std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>>,
Predicates::ConstraintQuestion> {
using Base::Base;

/// Return the name of the constraint.
Expand All @@ -460,11 +482,15 @@ struct ConstraintQuestion
/// Return the arguments of the constraint.
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }

/// Return the result types of the constraint.
ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }

/// Construct an instance with the given storage allocator.
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
KeyTy key) {
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
alloc.copyInto(std::get<1>(key))});
alloc.copyInto(std::get<1>(key)),
alloc.copyInto(std::get<2>(key))});
}
};

Expand Down Expand Up @@ -517,6 +543,7 @@ class PredicateUniquer : public StorageUniquer {
// Register the types of Positions with the uniquer.
registerParametricStorageType<AttributePosition>();
registerParametricStorageType<AttributeLiteralPosition>();
registerParametricStorageType<ConstraintPosition>();
registerParametricStorageType<ForEachPosition>();
registerParametricStorageType<OperandPosition>();
registerParametricStorageType<OperandGroupPosition>();
Expand Down Expand Up @@ -579,6 +606,12 @@ class PredicateBuilder {
return OperationPosition::get(uniquer, p);
}

// Returns a position for a new value created by a constraint.
ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
unsigned index) {
return ConstraintPosition::get(uniquer, std::make_pair(q, index));
}

/// Returns an attribute position for an attribute of the given operation.
Position *getAttribute(OperationPosition *p, StringRef name) {
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
Expand Down Expand Up @@ -664,8 +697,10 @@ class PredicateBuilder {
}

/// Create a predicate that applies a generic constraint.
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos) {
return {ConstraintQuestion::get(uniquer, std::make_tuple(name, pos)),
Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
ArrayRef<Type> resultTypes) {
return {ConstraintQuestion::get(uniquer,
std::make_tuple(name, args, resultTypes)),
TrueAnswer::get(uniquer)};
}

Expand Down
12 changes: 10 additions & 2 deletions mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,16 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
// Push the constraint to the furthest position.
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
comparePosDepth);
PredicateBuilder::Predicate pred =
builder.getConstraint(op.getName(), allPositions);
ResultRange results = op.getResults();
PredicateBuilder::Predicate pred = builder.getConstraint(
op.getName(), allPositions, SmallVector<Type>(results.getTypes()));

// for each result register a position so it can be used later
for (auto result : llvm::enumerate(results)) {
ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
ConstraintPosition *pos = builder.getConstraintPosition(q, result.index());
inputs[result.value()] = pos;
}
predList.emplace_back(pos, pred);
}

Expand Down
9 changes: 9 additions & 0 deletions mlir/lib/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,15 @@ void PDLPatternModule::registerConstraintFunction(
constraintFunctions.try_emplace(name, std::move(constraintFn));
}

void PDLPatternModule::registerConstraintFunctionWithResults(
StringRef name, PDLRewriteFunction constraintFn) {
// TODO: Is it possible to diagnose when `name` is already registered to
// a function that is not equivalent to `rewriteFn`?
// Allow existing mappings in the case multiple patterns depend on the same
// rewrite.
registerRewriteFunction(name, std::move(constraintFn));
}

void PDLPatternModule::registerRewriteFunction(StringRef name,
PDLRewriteFunction rewriteFn) {
// TODO: Is it possible to diagnose when `name` is already registered to
Expand Down
48 changes: 42 additions & 6 deletions mlir/lib/Rewrite/ByteCode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -769,10 +769,28 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {

void Generator::generate(pdl_interp::ApplyConstraintOp op,
ByteCodeWriter &writer) {
assert(constraintToMemIndex.count(op.getName()) &&
"expected index for constraint function");
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
/// Constraints that should return a value have to be registered as rewrites
/// If the constraint and rewrite of similar name are registered the
/// constraint fun takes precedence
ResultRange results = op.getResults();
if (results.size() == 0 && constraintToMemIndex.count(op.getName()) != 0) {
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
} else if (results.size() > 0 &&
externalRewriterToMemIndex.count(op.getName()) != 0) {
writer.append(OpCode::ApplyConstraint,
externalRewriterToMemIndex[op.getName()]);
} else {
assert(true && "expected index for constraint function, make sure it is "
"registered properly. Note that native constraints with "
"results have to be registered using "
"PDLPatternModule::registerConstraintFunctionWithResults.");
}
writer.appendPDLValueList(op.getArgs());
writer.append(ByteCodeField(results.size()));
for (Value result : results) {
// TODO: Handle result ranges
writer.append(result);
}
writer.append(op.getSuccessors());
}
void Generator::generate(pdl_interp::ApplyRewriteOp op,
Expand Down Expand Up @@ -1406,7 +1424,7 @@ class ByteCodeRewriteResultList : public PDLResultList {

void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
ByteCodeField fun_idx = read();
SmallVector<PDLValue, 16> args;
readList<PDLValue>(args);

Expand All @@ -1415,8 +1433,26 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
llvm::interleaveComma(args, llvm::dbgs());
});

// Invoke the constraint and jump to the proper destination.
selectJump(succeeded(constraintFn(rewriter, args)));
ByteCodeField numResults = read();
if (numResults == 0) {
const PDLConstraintFunction &constraintFn = constraintFunctions[fun_idx];
LogicalResult rewriteResult = constraintFn(rewriter, args);
// Depending on the constraint jump to the proper destination.
selectJump(succeeded(rewriteResult));
} else {
const PDLRewriteFunction &constraintFn = rewriteFunctions[fun_idx];
ByteCodeRewriteResultList results(numResults);
LogicalResult rewriteResult = constraintFn(rewriter, results, args);
assert(results.getResults().size() == numResults &&
"native PDL rewrite function returned unexpected number of results");

for (PDLValue &result : results.getResults()) {
LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
memory[read()] = result.getAsOpaquePointer();
}
// Depending on the constraint jump to the proper destination.
selectJump(succeeded(rewriteResult));
}
}

LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
Expand Down
6 changes: 0 additions & 6 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1359,12 +1359,6 @@ FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
if (failed(parseToken(Token::semicolon,
"expected `;` after native declaration")))
return failure();
// TODO: PDL should be able to support constraint results in certain
// situations, we should revise this.
if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
return emitError(
"native Constraints currently do not support returning results");
}
return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
}

Expand Down
28 changes: 28 additions & 0 deletions mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,34 @@ module @constraints {

// -----

// CHECK-LABEL: module @constraint_with_result
module @constraint_with_result {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
pdl.pattern : benefit(1) {
%root = operation
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
rewrite %root with "rewriter"(%attr : !pdl.attribute)
}
}

// -----

// CHECK-LABEL: module @constraint_with_unused_result
module @constraint_with_unused_result {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation)
pdl.pattern : benefit(1) {
%root = operation
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
rewrite %root with "rewriter"
}
}

// -----

// CHECK-LABEL: module @inputs
module @inputs {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Dialect/PDL/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,24 @@ pdl.pattern @apply_rewrite_with_no_results : benefit(1) {

// -----

pdl.pattern @apply_constraint_with_no_results : benefit(1) {
%root = operation
apply_native_constraint "NativeConstraint"(%root : !pdl.operation)
rewrite %root with "rewriter"
}

// -----

pdl.pattern @apply_constraint_with_results : benefit(1) {
%root = operation
%attr = apply_native_constraint "NativeConstraint"(%root : !pdl.operation) : !pdl.attribute
rewrite %root {
apply_native_rewrite "NativeRewrite"(%attr : !pdl.attribute)
}
}

// -----

pdl.pattern @attribute_with_dict : benefit(1) {
%root = operation
rewrite %root {
Expand Down
Loading

0 comments on commit bd44b77

Please sign in to comment.