Skip to content

Commit

Permalink
Merge pull request #93 from Xilinx/matthias.pdl_builtin
Browse files Browse the repository at this point in the history
PDLL: Allow to define builtin native calls
  • Loading branch information
mgehre-amd authored Feb 21, 2024
2 parents fc8ecac + 8ef9dcd commit e6b751e
Show file tree
Hide file tree
Showing 11 changed files with 277 additions and 91 deletions.
36 changes: 36 additions & 0 deletions mlir/include/mlir/Dialect/PDL/IR/Builtins.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//===- Builtins.h - Builtin functions of the PDL dialect --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines builtin functions of the PDL dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_PDL_IR_BUILTINS_H_
#define MLIR_DIALECT_PDL_IR_BUILTINS_H_

namespace mlir {
class PDLPatternModule;
class Attribute;
class PatternRewriter;

namespace pdl {
void registerBuiltins(PDLPatternModule &pdlPattern);

namespace builtin {
Attribute createDictionaryAttr(PatternRewriter &rewriter);
Attribute addEntryToDictionaryAttr(PatternRewriter &rewriter,
Attribute dictAttr, Attribute attrName,
Attribute attrEntry);
Attribute createArrayAttr(PatternRewriter &rewriter);
Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr,
Attribute element);
} // namespace builtin
} // namespace pdl
} // namespace mlir

#endif // MLIR_DIALECT_PDL_IR_BUILTINS_H_
56 changes: 56 additions & 0 deletions mlir/lib/Dialect/PDL/IR/Builtins.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#include <mlir/Dialect/PDL/IR/Builtins.h>
#include <mlir/IR/PatternMatch.h>

using namespace mlir;

namespace mlir::pdl {
namespace builtin {
mlir::Attribute createDictionaryAttr(mlir::PatternRewriter &rewriter) {
return rewriter.getDictionaryAttr({});
}

mlir::Attribute addEntryToDictionaryAttr(mlir::PatternRewriter &rewriter,
mlir::Attribute dictAttr,
mlir::Attribute attrName,
mlir::Attribute attrEntry) {
assert(isa<DictionaryAttr>(dictAttr));
auto attr = dictAttr.cast<DictionaryAttr>();
auto name = attrName.cast<StringAttr>();
std::vector<NamedAttribute> values = attr.getValue().vec();

// Remove entry if it exists in the dictionary.
llvm::erase_if(values, [&](NamedAttribute &namedAttr) {
return namedAttr.getName() == name.getValue();
});

values.push_back(rewriter.getNamedAttr(name, attrEntry));
return rewriter.getDictionaryAttr(values);
}

mlir::Attribute createArrayAttr(mlir::PatternRewriter &rewriter) {
return rewriter.getArrayAttr({});
}

mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter,
mlir::Attribute attr,
mlir::Attribute element) {
assert(isa<ArrayAttr>(attr));
auto values = cast<ArrayAttr>(attr).getValue().vec();
values.push_back(element);
return rewriter.getArrayAttr(values);
}
} // namespace builtin

void registerBuiltins(PDLPatternModule &pdlPattern) {
using namespace builtin;
// See Parser::defineBuiltins()
pdlPattern.registerRewriteFunction("__builtin_createDictionaryAttr",
createDictionaryAttr);
pdlPattern.registerRewriteFunction("__builtin_addEntryToDictionaryAttr",
addEntryToDictionaryAttr);
pdlPattern.registerRewriteFunction("__builtin_createArrayAttr",
createArrayAttr);
pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttr",
addElemToArrayAttr);
}
} // namespace mlir::pdl
1 change: 1 addition & 0 deletions mlir/lib/Dialect/PDL/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRPDLDialect
Builtins.cpp
PDL.cpp
PDLTypes.cpp

Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include <mlir/Dialect/PDL/IR/Builtins.h>
#include <optional>

using namespace mlir;
Expand Down Expand Up @@ -132,6 +133,8 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
llvm::report_fatal_error(
"failed to lower PDL pattern module to the PDL Interpreter");

pdl::registerBuiltins(pdlPatterns);

// Generate the pdl bytecode.
impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
pdlModule, pdlPatterns.takeConfigs(), configMap,
Expand Down
105 changes: 77 additions & 28 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,17 @@ class Parser {
/// Pop the last decl scope from the lexer.
void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }

/// Creates a native constraint taking a set of Attr as arguments.
/// The number of arguments and their names is given by argNames.
/// The native returns an Attr when returnsAttr is true, otherwise returns
/// nothing.
template <class T>
T *declareBuiltin(StringRef name, ArrayRef<StringRef> argNames,
bool returnsAttr);

/// Register all builtin natives.
void declareBuiltins();

/// Parse the body of an AST module.
LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);

Expand Down Expand Up @@ -418,12 +429,12 @@ class Parser {
FailureOr<ast::MemberAccessExpr *>
createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);

// Create a native call with \p nativeFuncName and \p arguments.
// Create a native call with \p function and \p arguments.
// This should be accompanied by a C++ implementation of the function that
// needs to be linked and registered in passes that process PDLL files.
FailureOr<ast::DeclRefExpr *>
createNativeCall(SMRange loc, StringRef nativeFuncName,
MutableArrayRef<ast::Expr *> arguments);
FailureOr<ast::Expr *>
createBuiltinCall(SMRange loc, ast::Decl *function,
MutableArrayRef<ast::Expr *> arguments);

/// Validate the member access `name` into the given parent expression. On
/// success, this also returns the type of the member accessed.
Expand Down Expand Up @@ -578,13 +589,64 @@ class Parser {

/// The optional code completion context.
CodeCompleteContext *codeCompleteContext;

struct {
ast::UserRewriteDecl *createDictionaryAttr;
ast::UserRewriteDecl *addEntryToDictionaryAttr;
ast::UserRewriteDecl *createArrayAttr;
ast::UserRewriteDecl *addElemToArrayAttr;
} builtins{};
};
} // namespace

template <class T>
T *Parser::declareBuiltin(StringRef name, ArrayRef<StringRef> argNames,
bool returnsAttr) {
SMRange loc;
auto attrConstr = ast::ConstraintRef(
ast::AttrConstraintDecl::create(ctx, loc, nullptr), loc);

pushDeclScope();
SmallVector<ast::VariableDecl *> args;
for (auto argName : argNames) {
FailureOr<ast::VariableDecl *> arg =
createArgOrResultVariableDecl(argName, loc, attrConstr);
assert(succeeded(arg));
args.push_back(*arg);
}
SmallVector<ast::VariableDecl *> results;
if (returnsAttr) {
auto result = createArgOrResultVariableDecl("", loc, attrConstr);
assert(succeeded(result));
results.push_back(*result);
}
popDeclScope();

auto *constraintDecl = T::createNative(ctx, ast::Name::create(ctx, name, loc),
args, results, {}, attrTy);
curDeclScope->add(constraintDecl);
return constraintDecl;
}

void Parser::declareBuiltins() {
builtins.createDictionaryAttr = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_createDictionaryAttr", {}, /*returnsAttr=*/true);
builtins.addEntryToDictionaryAttr = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_addEntryToDictionaryAttr", {"attr", "attrName", "attrEntry"},
/*returnsAttr=*/true);
builtins.createArrayAttr = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_createArrayAttr", {}, /*returnsAttr=*/true);
builtins.addElemToArrayAttr = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_addElemToArrayAttr", {"attr", "element"},
/*returnsAttr=*/true);
}

FailureOr<ast::Module *> Parser::parseModule() {
SMLoc moduleLoc = curToken.getStartLoc();
pushDeclScope();

declareBuiltins();

// Parse the top-level decls of the module.
SmallVector<ast::Decl *> decls;
if (failed(parseModuleBody(decls)))
Expand Down Expand Up @@ -1874,7 +1936,7 @@ FailureOr<ast::Expr *> Parser::parseArrayAttrExpr() {
"Parsing of array attributes as constraint not supported!");

auto arrayAttrCall =
createNativeCall(curToken.getLoc(), "createArrayAttr", {});
createBuiltinCall(curToken.getLoc(), builtins.createArrayAttr, {});
if (failed(arrayAttrCall))
return failure();

Expand All @@ -1884,8 +1946,8 @@ FailureOr<ast::Expr *> Parser::parseArrayAttrExpr() {
return failure();

SmallVector<ast::Expr *> arrayAttrArgs{*arrayAttrCall, *attr};
auto elemToArrayCall = createNativeCall(
curToken.getLoc(), "addElemToArrayAttr", arrayAttrArgs);
auto elemToArrayCall = createBuiltinCall(
curToken.getLoc(), builtins.addElemToArrayAttr, arrayAttrArgs);
if (failed(elemToArrayCall))
return failure();

Expand Down Expand Up @@ -1966,7 +2028,7 @@ FailureOr<ast::Expr *> Parser::parseDictAttrExpr() {
return emitError(
"Parsing of dictionary attributes as constraint not supported!");

auto dictAttrCall = createNativeCall(loc, "createDictionaryAttr", {});
auto dictAttrCall = createBuiltinCall(loc, builtins.createDictionaryAttr, {});
if (failed(dictAttrCall))
return failure();

Expand Down Expand Up @@ -2000,8 +2062,8 @@ FailureOr<ast::Expr *> Parser::parseDictAttrExpr() {
// Create addEntryToDictionaryAttr native call.
SmallVector<ast::Expr *> arrayAttrArgs{*dictAttrCall, *stringAttrRef,
namedDecl->getValue()};
auto entryToDictionaryCall =
createNativeCall(loc, "addEntryToDictionaryAttr", arrayAttrArgs);
auto entryToDictionaryCall = createBuiltinCall(
loc, builtins.addEntryToDictionaryAttr, arrayAttrArgs);
if (failed(entryToDictionaryCall))
return failure();

Expand Down Expand Up @@ -2923,33 +2985,20 @@ Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
}

FailureOr<ast::DeclRefExpr *>
Parser::createNativeCall(SMRange loc, StringRef nativeFuncName,
MutableArrayRef<ast::Expr *> arguments) {
FailureOr<ast::Expr *>
Parser::createBuiltinCall(SMRange loc, ast::Decl *function,
MutableArrayRef<ast::Expr *> arguments) {

FailureOr<ast::Expr *> nativeFuncExpr = parseDeclRefExpr(nativeFuncName, loc);
FailureOr<ast::Expr *> nativeFuncExpr = createDeclRefExpr(loc, function);
if (failed(nativeFuncExpr))
return failure();

if (!(*nativeFuncExpr)->getType().isa<ast::RewriteType>())
return emitError(nativeFuncName + " should be defined as a rewriter.");

FailureOr<ast::CallExpr *> nativeCall =
createCallExpr(loc, *nativeFuncExpr, arguments);
if (failed(nativeCall))
return failure();

// Create a unique anonymous name declaration to use, as its name is not
// important.
std::string anonName =
llvm::formatv("{0}_{1}", nativeFuncName, anonymousDeclNameCounter++)
.str();
FailureOr<ast::VariableDecl *> varDecl = defineVariableDecl(
anonName, loc, (*nativeCall)->getType(), *nativeCall, {});
if (failed(varDecl))
return failure();

return createDeclRefExpr(loc, *varDecl);
return *nativeCall;
}

FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
Expand Down
Loading

0 comments on commit e6b751e

Please sign in to comment.