Skip to content

Commit

Permalink
Merge pull request #130 from Xilinx/liangta.equal
Browse files Browse the repository at this point in the history
PDLL: Add "+" operator
  • Loading branch information
ge28boj authored Mar 14, 2024
2 parents d80bcf3 + 41f836d commit 8bc91d3
Show file tree
Hide file tree
Showing 9 changed files with 259 additions and 4 deletions.
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/PDL/IR/Builtins.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
#ifndef MLIR_DIALECT_PDL_IR_BUILTINS_H_
#define MLIR_DIALECT_PDL_IR_BUILTINS_H_

#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"

namespace mlir {
class PDLPatternModule;
class Attribute;
Expand All @@ -29,6 +33,8 @@ Attribute addEntryToDictionaryAttr(PatternRewriter &rewriter,
Attribute createArrayAttr(PatternRewriter &rewriter);
Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr,
Attribute element);
LogicalResult add(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args);
} // namespace builtin
} // namespace pdl
} // namespace mlir
Expand Down
67 changes: 67 additions & 0 deletions mlir/lib/Dialect/PDL/IR/Builtins.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
#include <cassert>
#include <cstdint>
#include <llvm/ADT/APFloat.h>
#include <llvm/ADT/APInt.h>
#include <llvm/ADT/ArrayRef.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/ErrorHandling.h>
#include <mlir/Dialect/PDL/IR/Builtins.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Support/LogicalResult.h>

using namespace mlir;

Expand Down Expand Up @@ -39,6 +50,61 @@ mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter,
values.push_back(element);
return rewriter.getArrayAttr(values);
}

LogicalResult add(mlir::PatternRewriter &rewriter, mlir::PDLResultList &results,
llvm::ArrayRef<mlir::PDLValue> args) {
assert(args.size() == 2 && "Expected 2 arguments");
auto lhsAttr = args[0].cast<Attribute>();
auto rhsAttr = args[1].cast<Attribute>();

// Integer
if (auto lhsIntAttr = dyn_cast_or_null<IntegerAttr>(lhsAttr)) {
auto rhsIntAttr = dyn_cast_or_null<IntegerAttr>(rhsAttr);
if (!rhsIntAttr || lhsIntAttr.getType() != rhsIntAttr.getType())
return failure();

auto integerType = lhsIntAttr.getType();

bool isOverflow;
llvm::APInt resultAPInt;
if (integerType.isUnsignedInteger() || integerType.isSignlessInteger()) {
resultAPInt =
lhsIntAttr.getValue().uadd_ov(rhsIntAttr.getValue(), isOverflow);
} else {
resultAPInt =
lhsIntAttr.getValue().sadd_ov(rhsIntAttr.getValue(), isOverflow);
}

if (isOverflow) {
return failure();
}

results.push_back(rewriter.getIntegerAttr(integerType, resultAPInt));
return success();
}

// Float
if (auto lhsFloatAttr = dyn_cast_or_null<FloatAttr>(lhsAttr)) {
auto rhsFloatAttr = dyn_cast_or_null<FloatAttr>(rhsAttr);
if (!rhsFloatAttr || lhsFloatAttr.getType() != rhsFloatAttr.getType())
return failure();

APFloat lhsVal = lhsFloatAttr.getValue();
APFloat rhsVal = rhsFloatAttr.getValue();
APFloat resultVal(lhsVal);
auto floatType = lhsFloatAttr.getType();

bool isOverflow =
resultVal.add(rhsVal, llvm::APFloatBase::rmNearestTiesToEven);
if (isOverflow) {
return failure();
}

results.push_back(rewriter.getFloatAttr(floatType, resultVal));
return success();
}
return failure();
}
} // namespace builtin

void registerBuiltins(PDLPatternModule &pdlPattern) {
Expand All @@ -52,5 +118,6 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
createArrayAttr);
pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttr",
addElemToArrayAttr);
pdlPattern.registerConstraintFunctionWithResults("__builtin_add", add);
}
} // namespace mlir::pdl
3 changes: 2 additions & 1 deletion mlir/lib/Tools/PDLL/Parser/Lexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ Token Lexer::lexToken() {
return formToken(Token::l_square, tokStart);
case ']':
return formToken(Token::r_square, tokStart);

case '+':
return formToken(Token::add, tokStart);
case '<':
return formToken(Token::less, tokStart);
case '>':
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Tools/PDLL/Parser/Lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class Token {
equal_arrow,
semicolon,
/// Paired punctuation.
add,
less,
greater,
l_brace,
Expand Down
28 changes: 25 additions & 3 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,7 @@ class Parser {
ast::UserRewriteDecl *addEntryToDictionaryAttr;
ast::UserRewriteDecl *createArrayAttr;
ast::UserRewriteDecl *addElemToArrayAttr;
ast::UserConstraintDecl *add;
} builtins{};
};
} // namespace
Expand Down Expand Up @@ -631,8 +632,9 @@ T *Parser::declareBuiltin(StringRef name, ArrayRef<StringRef> argNames,
}
popDeclScope();

auto *constraintDecl = T::createNative(ctx, ast::Name::create(ctx, name, loc),
args, results, {}, attrTy);
auto *constraintDecl =
T::createNative(ctx, ast::Name::create(ctx, name, loc), args, results, {},
createUserConstraintRewriteResultType(results));
curDeclScope->add(constraintDecl);
return constraintDecl;
}
Expand All @@ -648,6 +650,9 @@ void Parser::declareBuiltins() {
builtins.addElemToArrayAttr = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_addElemToArrayAttr", {"attr", "element"},
/*returnsAttr=*/true);
builtins.add =
declareBuiltin<ast::UserConstraintDecl>("__builtin_add", {"lhs", "rhs"},
/*returnsAttr=*/true);
}

FailureOr<ast::Module *> Parser::parseModule() {
Expand Down Expand Up @@ -1909,7 +1914,24 @@ FailureOr<ast::Expr *> Parser::parseEqualityExpr() {

FailureOr<ast::Expr *> Parser::parseRelationExpr() { return parseAddSubExpr(); }

FailureOr<ast::Expr *> Parser::parseAddSubExpr() { return parseMulDivExpr(); }
FailureOr<ast::Expr *> Parser::parseAddSubExpr() {
auto lhs = parseMulDivExpr();
if (failed(lhs))
return failure();

switch (curToken.getKind()) {
case Token::add: {
consumeToken();
auto rhs = parseMulDivExpr();
if (failed(rhs))
return failure();
SmallVector<ast::Expr *> args{*lhs, *rhs};
return createBuiltinCall(curToken.getLoc(), builtins.add, args);
}
default:
return lhs;
}
}

FailureOr<ast::Expr *> Parser::parseMulDivExpr() {
return parseLogicalNotExpr();
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,13 @@ Pattern RewriteMultiplyElementsArrayAttr {
replace root with newRoot;
};
}

// -----
// CHECK-LABEL: pdl.pattern @TestAdd : benefit(0) {
// CHECK: %[[VAL_0:.*]] = attribute = 4 : i32
// CHECK: %[[VAL_1:.*]] = attribute = 5 : i32
// CHECK: apply_native_constraint "__builtin_add"(%[[VAL_0]], %[[VAL_1]] : !pdl.attribute, !pdl.attribute) : !pdl.attribute
Pattern TestAdd {
let val : Attr = attr<"4 : i32"> + attr<"5 : i32">;
replace op<test.simple> with op<test.success>;
}
20 changes: 20 additions & 0 deletions mlir/test/mlir-pdll/Parser/expr-failure.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -449,3 +449,23 @@ Pattern {
// CHECK: expected `>` after type literal
let foo = type<"";
}

//===----------------------------------------------------------------------===//
// Builtins
//===----------------------------------------------------------------------===//

// -----

Pattern {
// CHECK: expected expression
+
erase _: Op;
}

// -----

Pattern {
// CHECK: expected expression
let resultAttr : Attr = attr<"4 : i32"> +
erase _: Op;
}
15 changes: 15 additions & 0 deletions mlir/test/mlir-pdll/Parser/expr.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,18 @@ Pattern {

erase _: Op;
}

//===----------------------------------------------------------------------===//
// Builtins
//===----------------------------------------------------------------------===//

// -----

// CHECK: Module {{.*}}
// CHECK: UserConstraintDecl {{.*}} Name<__builtin_add> ResultType<Attr>
Pattern {
let resultAttr : Attr = attr<"4 : f32"> + attr<"5 : i32">;
erase _: Op;
}


113 changes: 113 additions & 0 deletions mlir/unittests/Dialect/PDL/BuiltinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,15 @@

#include "mlir/Dialect/PDL/IR/Builtins.h"
#include "gmock/gmock.h"
#include <cstdint>
#include <gtest/gtest.h>
#include <llvm/ADT/APFloat.h>
#include <llvm/ADT/APInt.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/IR/Region.h>

using namespace mlir;
using namespace mlir::pdl;
Expand All @@ -20,6 +28,13 @@ class TestPatternRewriter : public PatternRewriter {
TestPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx) {}
};

class TestPDLResultList : public PDLResultList {
public:
TestPDLResultList(unsigned maxNumResults) : PDLResultList(maxNumResults) {}
/// Return the list of PDL results.
MutableArrayRef<PDLValue> getResults() { return results; }
};

class BuiltinTest : public ::testing::Test {
public:
MLIRContext ctx;
Expand Down Expand Up @@ -69,4 +84,102 @@ TEST_F(BuiltinTest, addElemToArrayAttr) {
cast<DictionaryAttr>(*cast<ArrayAttr>(updatedArrAttr).begin());
EXPECT_EQ(dictInsideArrAttr, dict);
}

TEST_F(BuiltinTest, add) {
auto onei16 = rewriter.getI16IntegerAttr(1);
auto onei32 = rewriter.getI32IntegerAttr(1);
auto onei8 = rewriter.getI8IntegerAttr(1);
auto largesti8 = rewriter.getI8IntegerAttr(-1);

// check signless integer overflow
{
TestPDLResultList results(1);
EXPECT_TRUE(onei8.getType().isSignlessInteger());

EXPECT_TRUE(builtin::add(rewriter, results, {onei8, largesti8}).failed());
}

// check correctness of result
{
TestPDLResultList results(1);
EXPECT_TRUE(builtin::add(rewriter, results, {onei16, onei16}).succeeded());

PDLValue result = results.getResults()[0];
EXPECT_EQ(
cast<IntegerAttr>(result.cast<Attribute>()).getValue().getSExtValue(),
2);
}

IntegerType Uint8 = rewriter.getIntegerType(8, false);
auto oneUint8 = rewriter.getIntegerAttr(Uint8, APInt(8, 1, false));
auto largestUint8 = rewriter.getIntegerAttr(Uint8, APInt(8, 255, false));

// check unsigned integer overflow
{
TestPDLResultList results(1);
EXPECT_TRUE(
builtin::add(rewriter, results, {oneUint8, largestUint8}).failed());
}

IntegerType SInt8 = rewriter.getIntegerType(8, true);
auto oneSInt8 = rewriter.getIntegerAttr(SInt8, APInt(8, 1, true));
auto largestSInt8 = rewriter.getIntegerAttr(SInt8, APInt(8, 127, true));

// check signed integer overflow
{
TestPDLResultList results(1);
EXPECT_TRUE(
builtin::add(rewriter, results, {oneSInt8, largestSInt8}).failed());
}

// check integer types mismatch
{
TestPDLResultList results(1);
EXPECT_TRUE(builtin::add(rewriter, results, {onei16, onei32}).failed());
}

auto onef16 = rewriter.getF16FloatAttr(1.0);
auto onef32 = rewriter.getF32FloatAttr(1.0);
auto zerof32 = rewriter.getF32FloatAttr(0.0);
auto negzerof32 = rewriter.getF32FloatAttr(-0.0);
auto zerof64 = rewriter.getF64FloatAttr(0.0);

auto maxValF16 = rewriter.getF16FloatAttr(
llvm::APFloat::getLargest(llvm::APFloat::IEEEhalf()).convertToFloat());

// check float overflow
{
TestPDLResultList results(1);
EXPECT_TRUE(builtin::add(rewriter, results, {onef16, maxValF16}).failed());
}

// check correctness of result
{
TestPDLResultList results(1);
EXPECT_TRUE(builtin::add(rewriter, results, {onef32, onef32}).succeeded());

PDLValue result = results.getResults()[0];
EXPECT_EQ(
cast<FloatAttr>(result.cast<Attribute>()).getValue().convertToFloat(),
2.0);
}

// check correctness of result
{
TestPDLResultList results(1);
EXPECT_TRUE(
builtin::add(rewriter, results, {zerof32, negzerof32}).succeeded());

PDLValue result = results.getResults()[0];
EXPECT_EQ(
cast<FloatAttr>(result.cast<Attribute>()).getValue().convertToFloat(),
0.0);
}

// check float types mismatch
{
TestPDLResultList results(1);
EXPECT_TRUE(builtin::add(rewriter, results, {zerof32, zerof64}).failed());
}
}
} // namespace

0 comments on commit 8bc91d3

Please sign in to comment.