From 186a8cdcca38106faef7eb0ddf85f87844c0929b Mon Sep 17 00:00:00 2001 From: morrySnow Date: Fri, 15 Nov 2024 20:09:25 +0800 Subject: [PATCH] [fix](Nereids) simplify comparison predicate do wrong cast related PR #41151 lead to analysis error ``` Invalid precision and scale ``` --- .../rules/SimplifyComparisonPredicate.java | 15 +++++++-------- .../rules/SimplifyComparisonPredicateTest.java | 12 ++++++++++++ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index cc1694575ec032..cb61795865239b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -229,15 +229,15 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa left = cast.child(); DecimalV3Literal literal = (DecimalV3Literal) right; if (left.getDataType().isDecimalV3Type()) { - if (((DecimalV3Type) left.getDataType()) - .getScale() < ((DecimalV3Type) literal.getDataType()).getScale()) { + DecimalV3Type leftType = (DecimalV3Type) left.getDataType(); + DecimalV3Type literalType = (DecimalV3Type) literal.getDataType(); + if (leftType.getScale() < literalType.getScale()) { int toScale = ((DecimalV3Type) left.getDataType()).getScale(); if (comparisonPredicate instanceof EqualTo) { try { return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) - comparisonPredicate.withChildren(left, - new DecimalV3Literal((DecimalV3Type) left.getDataType(), - literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY)))); + comparisonPredicate.withChildren(left, new DecimalV3Literal( + literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY)))); } catch (ArithmeticException e) { if (left.nullable()) { // TODO: the ideal way is to return an If expr like: @@ -255,9 +255,8 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa } else if (comparisonPredicate instanceof NullSafeEqual) { try { return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) - comparisonPredicate.withChildren(left, - new DecimalV3Literal((DecimalV3Type) left.getDataType(), - literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY)))); + comparisonPredicate.withChildren(left, new DecimalV3Literal( + literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY)))); } catch (ArithmeticException e) { return BooleanLiteral.of(false); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java index 402594d68610fa..84ebd7c7250198 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java @@ -279,5 +279,17 @@ void testDecimalV3Literal() { rewrittenExpression.child(0).getDataType()); Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); Assertions.assertEquals(new BigDecimal("12.35"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); + + // left's child range smaller than right literal + leftChild = new DecimalV3Literal(new BigDecimal("1234.12")); + left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(10, 5)); + right = new DecimalV3Literal(new BigDecimal("12345.12000")); + expression = new EqualTo(left, right); + rewrittenExpression = executor.rewrite(expression, context); + Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0)); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(7, 2), + rewrittenExpression.child(0).getDataType()); + Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); + Assertions.assertEquals(new BigDecimal("12345.12"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); } }