Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FXML.1923: PDLL support for native constraints with attribute results #24

Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def PDL_ApplyNativeConstraintOp
let description = [{
`pdl.apply_native_constraint` operations apply a native C++ constraint, that
has been registered externally with the consumer of PDL, to a given set of
entities.
entities and optionally return a number of values..

Example:

Expand All @@ -46,7 +46,8 @@ def PDL_ApplyNativeConstraintOp
}];

let arguments = (ins StrAttr:$name, Variadic<PDL_AnyType>:$args);
let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
let results = (outs Variadic<PDL_AnyType>:$results);
let assemblyFormat = "$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict";
let hasVerifier = 1;
}

Expand Down
8 changes: 5 additions & 3 deletions mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
let description = [{
`pdl_interp.apply_constraint` operations apply a generic constraint, that
has been registered with the interpreter, with a given set of positional
values. On success, this operation branches to the true destination,
otherwise the false destination is taken.
values. The constraint function may return any number of results.
On success, this operation branches to the true destination, otherwise
the false destination is taken.

Example:

Expand All @@ -101,8 +102,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
}];

let arguments = (ins StrAttr:$name, Variadic<PDL_AnyType>:$args);
let results = (outs Variadic<PDL_AnyType>:$results);
let assemblyFormat = [{
$name `(` $args `:` type($args) `)` attr-dict `->` successors
$name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict `->` successors
}];
}

Expand Down
35 changes: 33 additions & 2 deletions mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ struct PatternLowering {
/// A mapping between pattern operations and the corresponding configuration
/// set.
DenseMap<Operation *, PDLPatternConfigSet *> *configMap;

/// A mapping between constraint questions that refer to values created by
/// constraints and the temporary placeholder values created for them.
DenseMap<std::pair<ConstraintQuestion *, unsigned>, Value> substitutions;
};
} // namespace

Expand Down Expand Up @@ -364,6 +368,20 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
loc, rawTypeAttr.cast<ArrayAttr>());
break;
}
case Predicates::ConstraintResultPos: {
// At this point in time the corresponding pdl.ApplyNativeConstraint op has
// been deleted and the new pdl_interp.ApplyConstraint has not been created
// yet. To enable use of results created by these operations we build a
// placeholder value that will be replaced when the actual
// pdl_interp.ApplyConstraint operation is created.
Comment on lines +372 to +376

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know enough about PDLInterp, could you give me an example when this is relevant?

Copy link
Author

@martin-luecke martin-luecke Apr 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my description on top:

This is required in all scenarios where a native constraint returns a result that is used on the rhs
because said result has to be an input the the pdl_interp rewriter function that is created for the pattern.
However, the call to this function is always generated before the pdl_interp.apply_constraint operation is created.

An example for the output of this pass is the following:
PDL input:

module {
  pdl.pattern @simple : benefit(1) {
    %0 = operands
    %1 = operation "test.op"(%0 : !pdl.range<value>) 
    %attr = apply_native_constraint "NativeConstraint"(%1 : !pdl.operation) : !pdl.attribute 
    rewrite %1 {
      %3 = operation "test.success"  {"someAttr" = %attr}
      replace %1 with %3
    }
  }
}

The PDL_interp output is the following:

module {
  pdl_interp.func @matcher(%arg0: !pdl.operation) {
    pdl_interp.check_operation_name of %arg0 is "test.op" -> ^bb2, ^bb1
  ^bb1:  // 4 preds: ^bb0, ^bb2, ^bb3, ^bb4
    pdl_interp.finalize
  ^bb2:  // pred: ^bb0
    pdl_interp.check_result_count of %arg0 is 0 -> ^bb3, ^bb1
  ^bb3:  // pred: ^bb2
    %0 = pdl_interp.apply_constraint "NativeConstraint"(%arg0 : !pdl.operation) : !pdl.attribute -> ^bb4, ^bb1
  ^bb4:  // pred: ^bb3
    pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%0, %arg0 : !pdl.attribute, !pdl.operation) : benefit(1), generatedOps(["test.success"]), loc([%arg0]), root("test.op") -> ^bb1
  }
  module @rewriters {
    pdl_interp.func @pdl_generated_rewriter(%arg0: !pdl.attribute, %arg1: !pdl.operation) {
      %0 = pdl_interp.get_results of %arg1 : !pdl.range<value>
      %1 = pdl_interp.get_value_type of %0 : !pdl.range<type>
      %2 = pdl_interp.create_operation "test.success" {"someAttr" = %arg0}  -> (%1 : !pdl.range<type>)
      pdl_interp.erase %arg1
      pdl_interp.finalize
    }
  }
}

The pdl_interp.record_match operation and the block it is in is not the last thing this pass generates for a matcher. This is done so the operation that is the last check before a successful match can have this block (^bb4) as successor (e.g. %0 = pdl_interp.apply_constraint here). But the pdl_interp.record_match needs the result of %0 = pdl_interp.apply_constraint as operand.
So there is a cyclic dependency here, which we solve using the placeholder.

In short:

  • %0 = pdl_interp.apply_constraint needs the success branch as successor.
  • pdl_interp.record_match is created with the success branch and needs the result %0 as operand

auto *constrResPos = cast<ConstraintPosition>(pos);
Value placeholderValue = builder.create<pdl_interp::CreateAttributeOp>(
loc, StringAttr::get(builder.getContext(), "placeholder"));
substitutions[{constrResPos->getQuestion(), constrResPos->getIndex()}] =
placeholderValue;
value = placeholderValue;
break;
}
default:
llvm_unreachable("Generating unknown Position getter");
break;
Expand Down Expand Up @@ -447,8 +465,21 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
}
case Predicates::ConstraintQuestion: {
auto *cstQuestion = cast<ConstraintQuestion>(question);
builder.create<pdl_interp::ApplyConstraintOp>(loc, cstQuestion->getName(),
args, success, failure);
auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
success, failure);
// Replace the generated placeholders with the results of the constraint and
// erase them
for (auto result : llvm::enumerate(applyConstraintOp.getResults())) {
std::pair<ConstraintQuestion *, unsigned> substitutionKey = {
cstQuestion, result.index()};
// Check if there are substitutions to perform. If the result is never
// used no substitutions will have been generated.
if (substitutions.count(substitutionKey)) {
substitutions[substitutionKey].replaceAllUsesWith(result.value());
substitutions[substitutionKey].getDefiningOp()->erase();
}
}
break;
}
default:
Expand Down
49 changes: 42 additions & 7 deletions mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ enum Kind : unsigned {
OperandPos,
OperandGroupPos,
AttributePos,
ConstraintResultPos,
ResultPos,
ResultGroupPos,
TypePos,
Expand Down Expand Up @@ -187,6 +188,25 @@ struct AttributeLiteralPosition
using PredicateBase::PredicateBase;
};

//===----------------------------------------------------------------------===//
// ConstraintPosition

struct ConstraintQuestion;

/// A position describing the result of a native constraint. It saves the
/// corresponding ConstraintQuestion and result index to enable referring
/// back to them
struct ConstraintPosition
: public PredicateBase<ConstraintPosition, Position,
std::pair<ConstraintQuestion *, unsigned>,
Predicates::ConstraintResultPos> {
using PredicateBase::PredicateBase;

ConstraintQuestion *getQuestion() const { return key.first; }

unsigned getIndex() const { return key.second; }
};

//===----------------------------------------------------------------------===//
// ForEachPosition

Expand Down Expand Up @@ -443,11 +463,13 @@ struct AttributeQuestion
: public PredicateBase<AttributeQuestion, Qualifier, void,
Predicates::AttributeQuestion> {};

/// Apply a parameterized constraint to multiple position values.
/// Apply a parameterized constraint to multiple position values and possibly
/// produce results.
struct ConstraintQuestion
: public PredicateBase<ConstraintQuestion, Qualifier,
std::tuple<StringRef, ArrayRef<Position *>>,
Predicates::ConstraintQuestion> {
: public PredicateBase<
ConstraintQuestion, Qualifier,
std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>>,
Predicates::ConstraintQuestion> {
using Base::Base;

/// Return the name of the constraint.
Expand All @@ -456,11 +478,15 @@ struct ConstraintQuestion
/// Return the arguments of the constraint.
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }

/// Return the result types of the constraint.
ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }

/// Construct an instance with the given storage allocator.
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
KeyTy key) {
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
alloc.copyInto(std::get<1>(key))});
alloc.copyInto(std::get<1>(key)),
alloc.copyInto(std::get<2>(key))});
}
};

Expand Down Expand Up @@ -513,6 +539,7 @@ class PredicateUniquer : public StorageUniquer {
// Register the types of Positions with the uniquer.
registerParametricStorageType<AttributePosition>();
registerParametricStorageType<AttributeLiteralPosition>();
registerParametricStorageType<ConstraintPosition>();
registerParametricStorageType<ForEachPosition>();
registerParametricStorageType<OperandPosition>();
registerParametricStorageType<OperandGroupPosition>();
Expand Down Expand Up @@ -575,6 +602,12 @@ class PredicateBuilder {
return OperationPosition::get(uniquer, p);
}

// Returns a position for a new value created by a constraint.
ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
unsigned index) {
return ConstraintPosition::get(uniquer, std::make_pair(q, index));
}

/// Returns an attribute position for an attribute of the given operation.
Position *getAttribute(OperationPosition *p, StringRef name) {
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
Expand Down Expand Up @@ -660,8 +693,10 @@ class PredicateBuilder {
}

/// Create a predicate that applies a generic constraint.
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos) {
return {ConstraintQuestion::get(uniquer, std::make_tuple(name, pos)),
Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
ArrayRef<Type> resultTypes) {
return {ConstraintQuestion::get(uniquer,
std::make_tuple(name, args, resultTypes)),
TrueAnswer::get(uniquer)};
}

Expand Down
12 changes: 10 additions & 2 deletions mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,16 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
// Push the constraint to the furthest position.
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
comparePosDepth);
PredicateBuilder::Predicate pred =
builder.getConstraint(op.getName(), allPositions);
ResultRange results = op.getResults();
PredicateBuilder::Predicate pred = builder.getConstraint(
op.getName(), allPositions, SmallVector<Type>(results.getTypes()));

// for each result register a position so it can be used later
for (auto result : llvm::enumerate(results)) {
ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
ConstraintPosition *pos = builder.getConstraintPosition(q, result.index());
inputs[result.value()] = pos;
}
predList.emplace_back(pos, pred);
}

Expand Down
47 changes: 41 additions & 6 deletions mlir/lib/Rewrite/ByteCode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -768,10 +768,27 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {

void Generator::generate(pdl_interp::ApplyConstraintOp op,
ByteCodeWriter &writer) {
assert(constraintToMemIndex.count(op.getName()) &&
"expected index for constraint function");
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
/// Constraints that should return a value have to be registered as rewrites
/// If the constraint and rewrite of similar name are registered the
/// constraint fun takes precedence
ResultRange results = op.getResults();
if (results.size() == 0 && constraintToMemIndex.count(op.getName()) != 0) {
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
} else if (results.size() > 0 &&
externalRewriterToMemIndex.count(op.getName()) != 0) {
writer.append(OpCode::ApplyConstraint,
externalRewriterToMemIndex[op.getName()]);
} else {
assert(true && "expected index for constraint function, make sure it is "
"registered properly. Note that native constraints with "
"results have to be registered as native rewriters.");
martin-luecke marked this conversation as resolved.
Show resolved Hide resolved
}
writer.appendPDLValueList(op.getArgs());
writer.append(ByteCodeField(results.size()));
for (Value result : results) {
// TODO: Handle result ranges
writer.append(result);
}
writer.append(op.getSuccessors());
}
void Generator::generate(pdl_interp::ApplyRewriteOp op,
Expand Down Expand Up @@ -1405,7 +1422,7 @@ class ByteCodeRewriteResultList : public PDLResultList {

void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
ByteCodeField fun_idx = read();
SmallVector<PDLValue, 16> args;
readList<PDLValue>(args);

Expand All @@ -1414,8 +1431,26 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
llvm::interleaveComma(args, llvm::dbgs());
});

// Invoke the constraint and jump to the proper destination.
selectJump(succeeded(constraintFn(rewriter, args)));
ByteCodeField numResults = read();
if (numResults == 0) {
const PDLConstraintFunction &constraintFn = constraintFunctions[fun_idx];
LogicalResult rewriteResult = constraintFn(rewriter, args);
// Depending on the constraint jump to the proper destination.
selectJump(succeeded(rewriteResult));
} else {
const PDLRewriteFunction &constraintFn = rewriteFunctions[fun_idx];
ByteCodeRewriteResultList results(numResults);
LogicalResult rewriteResult = constraintFn(rewriter, results, args);
assert(results.getResults().size() == numResults &&
"native PDL rewrite function returned unexpected number of results");

for (PDLValue &result : results.getResults()) {
LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
memory[read()] = result.getAsOpaquePointer();
}
// Depending on the constraint jump to the proper destination.
selectJump(succeeded(rewriteResult));
}
}

LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
Expand Down
6 changes: 0 additions & 6 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1356,12 +1356,6 @@ FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
if (failed(parseToken(Token::semicolon,
"expected `;` after native declaration")))
return failure();
// TODO: PDL should be able to support constraint results in certain
// situations, we should revise this.
if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
return emitError(
"native Constraints currently do not support returning results");
}
return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
}

Expand Down
28 changes: 28 additions & 0 deletions mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,34 @@ module @constraints {

// -----

// CHECK-LABEL: module @constraint_with_result
module @constraint_with_result {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
pdl.pattern : benefit(1) {
%root = operation
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
rewrite %root with "rewriter"(%attr : !pdl.attribute)
}
}

// -----

// CHECK-LABEL: module @constraint_with_unused_result
module @constraint_with_unused_result {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation)
pdl.pattern : benefit(1) {
%root = operation
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
rewrite %root with "rewriter"
}
}

// -----

// CHECK-LABEL: module @inputs
module @inputs {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Dialect/PDL/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,24 @@ pdl.pattern @apply_rewrite_with_no_results : benefit(1) {

// -----

pdl.pattern @apply_constraint_with_no_results : benefit(1) {
%root = operation
apply_native_constraint "NativeConstraint"(%root : !pdl.operation)
rewrite %root with "rewriter"
}

// -----

pdl.pattern @apply_constraint_with_results : benefit(1) {
%root = operation
%attr = apply_native_constraint "NativeConstraint"(%root : !pdl.operation) : !pdl.attribute
rewrite %root {
apply_native_rewrite "NativeRewrite"(%attr : !pdl.attribute)
}
}

// -----

pdl.pattern @attribute_with_dict : benefit(1) {
%root = operation
rewrite %root {
Expand Down
31 changes: 31 additions & 0 deletions mlir/test/Rewrite/pdl-bytecode.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,37 @@ module @ir attributes { test.apply_constraint_2 } {

// -----

module @patterns {
pdl_interp.func @matcher(%root : !pdl.operation) {
pdl_interp.check_operation_name of %root is "test.op" -> ^bb0, ^end

^bb0:
%attr = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute -> ^pat, ^end

^pat:
pdl_interp.record_match @rewriters::@success(%root, %attr : !pdl.operation, !pdl.attribute) : benefit(1), loc([%root]) -> ^end

^end:
pdl_interp.finalize
}

module @rewriters {
pdl_interp.func @success(%root : !pdl.operation, %attr : !pdl.attribute) {
%op = pdl_interp.create_operation "test.success" {"attr" = %attr}
pdl_interp.erase %root
pdl_interp.finalize
}
}
}

// CHECK-LABEL: test.apply_constraint_3
// CHECK: "test.success"() {attr = "test.success"}
module @ir attributes { test.apply_constraint_3 } {
"test.op"() : () -> ()
}

// -----

//===----------------------------------------------------------------------===//
// pdl_interp::ApplyRewriteOp
//===----------------------------------------------------------------------===//
Expand Down
Loading