From 31a681d15f6925eca4dac2010074fccac26035be Mon Sep 17 00:00:00 2001 From: Ferdinand Lemaire Date: Tue, 9 Jul 2024 16:07:54 +0100 Subject: [PATCH] Allow log2 pdll builtin to return a floating point value even if the log2 is not exact --- mlir/lib/Dialect/PDL/IR/Builtins.cpp | 2 +- mlir/unittests/Dialect/PDL/BuiltinTest.cpp | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/PDL/IR/Builtins.cpp b/mlir/lib/Dialect/PDL/IR/Builtins.cpp index b8e1d9c87c0918..33216c1a0cef27 100644 --- a/mlir/lib/Dialect/PDL/IR/Builtins.cpp +++ b/mlir/lib/Dialect/PDL/IR/Builtins.cpp @@ -169,7 +169,7 @@ LogicalResult static unaryOp(PatternRewriter &rewriter, PDLResultList &results, } else if constexpr (T == UnaryOpKind::log2) { results.push_back(rewriter.getFloatAttr( operandFloatAttr.getType(), - (double)operandFloatAttr.getValue().getExactLog2())); + std::log2(operandFloatAttr.getValueAsDouble()))); } else if constexpr (T == UnaryOpKind::abs) { auto resultVal = operandFloatAttr.getValue(); resultVal.clearSign(); diff --git a/mlir/unittests/Dialect/PDL/BuiltinTest.cpp b/mlir/unittests/Dialect/PDL/BuiltinTest.cpp index a86c08e78e923c..7413f0bffa0d82 100644 --- a/mlir/unittests/Dialect/PDL/BuiltinTest.cpp +++ b/mlir/unittests/Dialect/PDL/BuiltinTest.cpp @@ -634,6 +634,18 @@ TEST_F(BuiltinTest, log2) { cast(result.cast()).getValue().convertToFloat(), 2.0); } + + auto threeF16 = rewriter.getF16FloatAttr(3.0); + + // check correctness + { + TestPDLResultList results(1); + EXPECT_TRUE(builtin::log2(rewriter, results, {threeF16}).succeeded()); + + PDLValue result = results.getResults()[0]; + float resultVal = cast(result.cast()).getValue().convertToFloat(); + EXPECT_TRUE(resultVal > 1.58 && resultVal < 1.59); + } } TEST_F(BuiltinTest, exp2) {