Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PDLL: Allow to define builtin native calls #93

Merged
merged 2 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions mlir/include/mlir/Dialect/PDL/IR/Builtins.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//===- Builtins.h - Builtin functions of the PDL dialect --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines builtin functions of the PDL dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_PDL_IR_BUILTINS_H_
#define MLIR_DIALECT_PDL_IR_BUILTINS_H_

namespace mlir {
class PDLPatternModule;
class Attribute;
class PatternRewriter;

namespace pdl {
void registerBuiltins(PDLPatternModule &pdlPattern);

namespace builtin {
Attribute createDictionaryAttr(PatternRewriter &rewriter);
Attribute addEntryToDictionaryAttr(PatternRewriter &rewriter,
Attribute dictAttr, Attribute attrName,
Attribute attrEntry);
Attribute createArrayAttr(PatternRewriter &rewriter);
Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr,
Attribute element);
} // namespace builtin
} // namespace pdl
} // namespace mlir

#endif // MLIR_DIALECT_PDL_IR_BUILTINS_H_
56 changes: 56 additions & 0 deletions mlir/lib/Dialect/PDL/IR/Builtins.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#include <mlir/Dialect/PDL/IR/Builtins.h>
#include <mlir/IR/PatternMatch.h>

using namespace mlir;

namespace mlir::pdl {
namespace builtin {
mlir::Attribute createDictionaryAttr(mlir::PatternRewriter &rewriter) {
return rewriter.getDictionaryAttr({});
}

mlir::Attribute addEntryToDictionaryAttr(mlir::PatternRewriter &rewriter,
mlir::Attribute dictAttr,
mlir::Attribute attrName,
mlir::Attribute attrEntry) {
assert(isa<DictionaryAttr>(dictAttr));
auto attr = dictAttr.cast<DictionaryAttr>();
auto name = attrName.cast<StringAttr>();
std::vector<NamedAttribute> values = attr.getValue().vec();

// Remove entry if it exists in the dictionary.
llvm::erase_if(values, [&](NamedAttribute &namedAttr) {
return namedAttr.getName() == name.getValue();
});

values.push_back(rewriter.getNamedAttr(name, attrEntry));
return rewriter.getDictionaryAttr(values);
}

mlir::Attribute createArrayAttr(mlir::PatternRewriter &rewriter) {
return rewriter.getArrayAttr({});
}

mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter,
mlir::Attribute attr,
mlir::Attribute element) {
mgehre-amd marked this conversation as resolved.
Show resolved Hide resolved
assert(isa<ArrayAttr>(attr));
auto values = cast<ArrayAttr>(attr).getValue().vec();
values.push_back(element);
return rewriter.getArrayAttr(values);
}
} // namespace builtin

void registerBuiltins(PDLPatternModule &pdlPattern) {
using namespace builtin;
// See Parser::defineBuiltins()
pdlPattern.registerRewriteFunction("__builtin_createDictionaryAttr",
createDictionaryAttr);
pdlPattern.registerRewriteFunction("__builtin_addEntryToDictionaryAttr",
addEntryToDictionaryAttr);
pdlPattern.registerRewriteFunction("__builtin_createArrayAttr",
createArrayAttr);
pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttr",
addElemToArrayAttr);
}
} // namespace mlir::pdl
1 change: 1 addition & 0 deletions mlir/lib/Dialect/PDL/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRPDLDialect
Builtins.cpp
PDL.cpp
PDLTypes.cpp

Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include <mlir/Dialect/PDL/IR/Builtins.h>
#include <optional>

using namespace mlir;
Expand Down Expand Up @@ -132,6 +133,8 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
llvm::report_fatal_error(
"failed to lower PDL pattern module to the PDL Interpreter");

pdl::registerBuiltins(pdlPatterns);

// Generate the pdl bytecode.
impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
pdlModule, pdlPatterns.takeConfigs(), configMap,
Expand Down
105 changes: 77 additions & 28 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,17 @@ class Parser {
/// Pop the last decl scope from the lexer.
void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }

/// Creates a native constraint taking a set of Attr as arguments.
/// The number of arguments and their names is given by argNames.
/// The native returns an Attr when returnsAttr is true, otherwise returns
/// nothing.
template <class T>
T *declareBuiltin(StringRef name, ArrayRef<StringRef> argNames,
bool returnsAttr);

/// Register all builtin natives.
void declareBuiltins();

/// Parse the body of an AST module.
LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);

Expand Down Expand Up @@ -418,12 +429,12 @@ class Parser {
FailureOr<ast::MemberAccessExpr *>
createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);

// Create a native call with \p nativeFuncName and \p arguments.
// Create a native call with \p function and \p arguments.
// This should be accompanied by a C++ implementation of the function that
// needs to be linked and registered in passes that process PDLL files.
FailureOr<ast::DeclRefExpr *>
createNativeCall(SMRange loc, StringRef nativeFuncName,
MutableArrayRef<ast::Expr *> arguments);
FailureOr<ast::Expr *>
createBuiltinCall(SMRange loc, ast::Decl *function,
MutableArrayRef<ast::Expr *> arguments);

/// Validate the member access `name` into the given parent expression. On
/// success, this also returns the type of the member accessed.
Expand Down Expand Up @@ -578,13 +589,64 @@ class Parser {

/// The optional code completion context.
CodeCompleteContext *codeCompleteContext;

struct {
ast::UserRewriteDecl *createDictionaryAttr;
ast::UserRewriteDecl *addEntryToDictionaryAttr;
ast::UserRewriteDecl *createArrayAttr;
ast::UserRewriteDecl *addElemToArrayAttr;
} builtins{};
};
} // namespace

template <class T>
T *Parser::declareBuiltin(StringRef name, ArrayRef<StringRef> argNames,
bool returnsAttr) {
SMRange loc;
auto attrConstr = ast::ConstraintRef(
ast::AttrConstraintDecl::create(ctx, loc, nullptr), loc);

pushDeclScope();
SmallVector<ast::VariableDecl *> args;
for (auto argName : argNames) {
FailureOr<ast::VariableDecl *> arg =
createArgOrResultVariableDecl(argName, loc, attrConstr);
assert(succeeded(arg));
args.push_back(*arg);
}
SmallVector<ast::VariableDecl *> results;
if (returnsAttr) {
auto result = createArgOrResultVariableDecl("", loc, attrConstr);
assert(succeeded(result));
results.push_back(*result);
}
popDeclScope();

auto *constraintDecl = T::createNative(ctx, ast::Name::create(ctx, name, loc),
args, results, {}, attrTy);
curDeclScope->add(constraintDecl);
return constraintDecl;
}

void Parser::declareBuiltins() {
builtins.createDictionaryAttr = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_createDictionaryAttr", {}, /*returnsAttr=*/true);
builtins.addEntryToDictionaryAttr = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_addEntryToDictionaryAttr", {"attr", "attrName", "attrEntry"},
/*returnsAttr=*/true);
builtins.createArrayAttr = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_createArrayAttr", {}, /*returnsAttr=*/true);
builtins.addElemToArrayAttr = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_addElemToArrayAttr", {"attr", "element"},
/*returnsAttr=*/true);
}

FailureOr<ast::Module *> Parser::parseModule() {
SMLoc moduleLoc = curToken.getStartLoc();
pushDeclScope();

declareBuiltins();

// Parse the top-level decls of the module.
SmallVector<ast::Decl *> decls;
if (failed(parseModuleBody(decls)))
Expand Down Expand Up @@ -1874,7 +1936,7 @@ FailureOr<ast::Expr *> Parser::parseArrayAttrExpr() {
"Parsing of array attributes as constraint not supported!");

auto arrayAttrCall =
createNativeCall(curToken.getLoc(), "createArrayAttr", {});
createBuiltinCall(curToken.getLoc(), builtins.createArrayAttr, {});
if (failed(arrayAttrCall))
return failure();

Expand All @@ -1884,8 +1946,8 @@ FailureOr<ast::Expr *> Parser::parseArrayAttrExpr() {
return failure();

SmallVector<ast::Expr *> arrayAttrArgs{*arrayAttrCall, *attr};
auto elemToArrayCall = createNativeCall(
curToken.getLoc(), "addElemToArrayAttr", arrayAttrArgs);
auto elemToArrayCall = createBuiltinCall(
curToken.getLoc(), builtins.addElemToArrayAttr, arrayAttrArgs);
if (failed(elemToArrayCall))
return failure();

Expand Down Expand Up @@ -1966,7 +2028,7 @@ FailureOr<ast::Expr *> Parser::parseDictAttrExpr() {
return emitError(
"Parsing of dictionary attributes as constraint not supported!");

auto dictAttrCall = createNativeCall(loc, "createDictionaryAttr", {});
auto dictAttrCall = createBuiltinCall(loc, builtins.createDictionaryAttr, {});
if (failed(dictAttrCall))
return failure();

Expand Down Expand Up @@ -2000,8 +2062,8 @@ FailureOr<ast::Expr *> Parser::parseDictAttrExpr() {
// Create addEntryToDictionaryAttr native call.
SmallVector<ast::Expr *> arrayAttrArgs{*dictAttrCall, *stringAttrRef,
namedDecl->getValue()};
auto entryToDictionaryCall =
createNativeCall(loc, "addEntryToDictionaryAttr", arrayAttrArgs);
auto entryToDictionaryCall = createBuiltinCall(
loc, builtins.addEntryToDictionaryAttr, arrayAttrArgs);
if (failed(entryToDictionaryCall))
return failure();

Expand Down Expand Up @@ -2923,33 +2985,20 @@ Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
}

FailureOr<ast::DeclRefExpr *>
Parser::createNativeCall(SMRange loc, StringRef nativeFuncName,
MutableArrayRef<ast::Expr *> arguments) {
FailureOr<ast::Expr *>
Parser::createBuiltinCall(SMRange loc, ast::Decl *function,
MutableArrayRef<ast::Expr *> arguments) {

FailureOr<ast::Expr *> nativeFuncExpr = parseDeclRefExpr(nativeFuncName, loc);
FailureOr<ast::Expr *> nativeFuncExpr = createDeclRefExpr(loc, function);
if (failed(nativeFuncExpr))
return failure();

if (!(*nativeFuncExpr)->getType().isa<ast::RewriteType>())
return emitError(nativeFuncName + " should be defined as a rewriter.");

FailureOr<ast::CallExpr *> nativeCall =
createCallExpr(loc, *nativeFuncExpr, arguments);
if (failed(nativeCall))
return failure();

// Create a unique anonymous name declaration to use, as its name is not
// important.
std::string anonName =
llvm::formatv("{0}_{1}", nativeFuncName, anonymousDeclNameCounter++)
.str();
FailureOr<ast::VariableDecl *> varDecl = defineVariableDecl(
anonName, loc, (*nativeCall)->getType(), *nativeCall, {});
if (failed(varDecl))
return failure();

return createDeclRefExpr(loc, *varDecl);
return *nativeCall;
}

FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
Expand Down
Loading
Loading