From 5a8f65d4bb98220f9f728cb8501bc477c65a9bdf Mon Sep 17 00:00:00 2001 From: Toshiya Kobayashi Date: Tue, 19 Dec 2023 12:24:00 +0900 Subject: [PATCH] [KIE-775] drools executable-model fails with a bind variable to a calculation result of int and BigDecimal --- .../ArithmeticCoercedExpression.java | 4 ++ .../generator/drlxparse/ConstraintParser.java | 50 +++++++++++++++---- .../expressiontyper/ExpressionTyper.java | 30 ++++++++--- .../bigdecimaltest/BigDecimalTest.java | 41 +++++++++++++++ 4 files changed, 110 insertions(+), 15 deletions(-) diff --git a/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ArithmeticCoercedExpression.java b/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ArithmeticCoercedExpression.java index 3cdea08ab37..765a6b313e2 100644 --- a/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ArithmeticCoercedExpression.java +++ b/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ArithmeticCoercedExpression.java @@ -52,6 +52,10 @@ public ArithmeticCoercedExpression(TypedExpression left, TypedExpression right, this.operator = operator; } + /* + * This coercion only deals with String vs Numeric types. + * BigDecimal arithmetic operation is handled by ExpressionTyper.convertArithmeticBinaryToMethodCall() + */ public ArithmeticCoercedExpressionResult coerce() { if (!requiresCoercion()) { diff --git a/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ConstraintParser.java b/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ConstraintParser.java index 2a85da6e2f0..36b8f67a293 100644 --- a/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ConstraintParser.java +++ b/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ConstraintParser.java @@ -95,6 +95,9 @@ import static org.drools.model.codegen.execmodel.generator.ConstraintUtil.GREATER_THAN_PREFIX; import static org.drools.model.codegen.execmodel.generator.ConstraintUtil.LESS_OR_EQUAL_PREFIX; import static org.drools.model.codegen.execmodel.generator.ConstraintUtil.LESS_THAN_PREFIX; +import static org.drools.model.codegen.execmodel.generator.expressiontyper.ExpressionTyper.convertArithmeticBinaryToMethodCall; +import static org.drools.model.codegen.execmodel.generator.expressiontyper.ExpressionTyper.getBinaryTypeAfterConversion; +import static org.drools.model.codegen.execmodel.generator.expressiontyper.ExpressionTyper.shouldConvertArithmeticBinaryToMethodCall; import static org.drools.util.StringUtils.lcFirstForBean; import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.THIS_PLACEHOLDER; import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.createConstraintCompiler; @@ -197,11 +200,23 @@ private void logWarnIfNoReactOnCausedByVariableFromDifferentPattern(DrlxParseRes } private void addDeclaration(DrlxExpression drlx, SingleDrlxParseSuccess singleResult, String bindId) { - DeclarationSpec decl = context.addDeclaration( bindId, singleResult.getLeftExprTypeBeforeCoercion() ); + DeclarationSpec decl = context.addDeclaration(bindId, getDeclarationType(drlx, singleResult)); if (drlx.getExpr() instanceof NameExpr) { decl.setBoundVariable( PrintUtil.printNode(drlx.getExpr()) ); } else if (drlx.getExpr() instanceof BinaryExpr) { - decl.setBoundVariable(PrintUtil.printNode(drlx.getExpr().asBinaryExpr().getLeft())); + Expression leftMostExpression = getLeftMostExpression(drlx.getExpr().asBinaryExpr()); + decl.setBoundVariable(PrintUtil.printNode(leftMostExpression)); + if (singleResult.getExpr() instanceof MethodCallExpr) { + // BinaryExpr was converted to MethodCallExpr. Create a TypedExpression for the leftmost expression of the BinaryExpr + ExpressionTyperContext expressionTyperContext = new ExpressionTyperContext(); + ExpressionTyper expressionTyper = new ExpressionTyper(context, singleResult.getPatternType(), bindId, false, expressionTyperContext); + TypedExpressionResult leftTypedExpressionResult = expressionTyper.toTypedExpression(leftMostExpression); + Optional optLeft = leftTypedExpressionResult.getTypedExpression(); + if (!optLeft.isPresent()) { + throw new IllegalStateException("Cannot create TypedExpression for " + drlx.getExpr().asBinaryExpr().getLeft()); + } + singleResult.setBoundExpr(optLeft.get()); + } } decl.setBelongingPatternDescr(context.getCurrentPatternDescr()); singleResult.setExprBinding( bindId ); @@ -211,6 +226,24 @@ private void addDeclaration(DrlxExpression drlx, SingleDrlxParseSuccess singleRe } } + private static Class getDeclarationType(DrlxExpression drlx, SingleDrlxParseSuccess singleResult) { + if (drlx.getBind() != null && drlx.getExpr() instanceof EnclosedExpr) { + // in case of enclosed, bind type should be the calculation result type + // If drlx.getBind() == null, a bind variable is inside the enclosed expression, leave it to the default behavior + return (Class)singleResult.getExprType(); + } else { + return singleResult.getLeftExprTypeBeforeCoercion(); + } + } + + private Expression getLeftMostExpression(BinaryExpr binaryExpr) { + Expression left = binaryExpr.getLeft(); + if (left instanceof BinaryExpr) { + return getLeftMostExpression((BinaryExpr) left); + } + return left; + } + /* This is the entry point for Constraint Transformation from a parsed MVEL constraint to a Java Expression @@ -657,17 +690,16 @@ private DrlxParseResult parseBinaryExpr(BinaryExpr binaryExpr, Class patternT Expression combo; - boolean arithmeticExpr = isArithmeticOperator(operator); boolean isBetaConstraint = right.getExpression() != null && hasDeclarationFromOtherPattern( expressionTyperContext ); boolean requiresSplit = operator == BinaryExpr.Operator.AND && binaryExpr.getRight() instanceof HalfBinaryExpr && !isBetaConstraint; + Type exprType = isBooleanOperator( operator ) ? boolean.class : left.getType(); + if (equalityExpr) { combo = getEqualityExpression( left, right, operator ).expression; - } else if (arithmeticExpr && (left.isBigDecimal())) { - ConstraintCompiler constraintCompiler = createConstraintCompiler(this.context, of(patternType)); - CompiledExpressionResult compiledExpressionResult = constraintCompiler.compileExpression(binaryExpr); - - combo = compiledExpressionResult.getExpression(); + } else if (shouldConvertArithmeticBinaryToMethodCall(operator, left.getType(), right.getType())) { + combo = convertArithmeticBinaryToMethodCall(binaryExpr, of(patternType), this.context); + exprType = getBinaryTypeAfterConversion(left.getType(), right.getType()); } else { if (left.getExpression() == null || right.getExpression() == null) { return new DrlxParseFail(new ParseExpressionErrorResult(drlxExpr)); @@ -695,7 +727,7 @@ private DrlxParseResult parseBinaryExpr(BinaryExpr binaryExpr, Class patternT constraintType = Index.ConstraintType.FORALL_SELF_JOIN; } - return new SingleDrlxParseSuccess(patternType, bindingId, combo, isBooleanOperator( operator ) ? boolean.class : left.getType()) + return new SingleDrlxParseSuccess(patternType, bindingId, combo, exprType) .setDecodeConstraintType( constraintType ) .setUsedDeclarations( expressionTyperContext.getUsedDeclarations() ) .setUsedDeclarationsOnLeft( usedDeclarationsOnLeft ) diff --git a/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/expressiontyper/ExpressionTyper.java b/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/expressiontyper/ExpressionTyper.java index 328aa5466e5..8fdfb696c46 100644 --- a/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/expressiontyper/ExpressionTyper.java +++ b/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/expressiontyper/ExpressionTyper.java @@ -235,7 +235,14 @@ private Optional toTypedExpressionRec(Expression drlxExpr) { right = coerced.getCoercedRight(); final BinaryExpr combo = new BinaryExpr(left.getExpression(), right.getExpression(), operator); - return of(new TypedExpression(combo, left.getType())); + + if (shouldConvertArithmeticBinaryToMethodCall(operator, left.getType(), right.getType())) { + Expression expression = convertArithmeticBinaryToMethodCall(combo, of(typeCursor), ruleContext); + java.lang.reflect.Type binaryType = getBinaryTypeAfterConversion(left.getType(), right.getType()); + return of(new TypedExpression(expression, binaryType)); + } else { + return of(new TypedExpression(combo, left.getType())); + } } if (drlxExpr instanceof HalfBinaryExpr) { @@ -806,11 +813,12 @@ private TypedExpressionCursor binaryExpr(BinaryExpr binaryExpr) { TypedExpression rightTypedExpression = right.getTypedExpression() .orElseThrow(() -> new NoSuchElementException("TypedExpressionResult doesn't contain TypedExpression!")); binaryExpr.setRight(rightTypedExpression.getExpression()); - java.lang.reflect.Type binaryType = getBinaryType(leftTypedExpression, rightTypedExpression, binaryExpr.getOperator()); - if (shouldConvertArithmeticBinaryToMethodCall(binaryExpr.getOperator(), binaryType)) { + if (shouldConvertArithmeticBinaryToMethodCall(binaryExpr.getOperator(), leftTypedExpression.getType(), rightTypedExpression.getType())) { Expression compiledExpression = convertArithmeticBinaryToMethodCall(binaryExpr, leftTypedExpression.getOriginalPatternType(), ruleContext); + java.lang.reflect.Type binaryType = getBinaryTypeAfterConversion(leftTypedExpression.getType(), rightTypedExpression.getType()); return new TypedExpressionCursor(compiledExpression, binaryType); } else { + java.lang.reflect.Type binaryType = getBinaryType(leftTypedExpression, rightTypedExpression, binaryExpr.getOperator()); return new TypedExpressionCursor(binaryExpr, binaryType); } } @@ -819,7 +827,7 @@ private TypedExpressionCursor binaryExpr(BinaryExpr binaryExpr) { * Converts arithmetic binary expression (including coercion) to method call using ConstraintCompiler. * This method can be generic, so we may centralize the calls in drools-model */ - private static Expression convertArithmeticBinaryToMethodCall(BinaryExpr binaryExpr, Optional> originalPatternType, RuleContext ruleContext) { + public static Expression convertArithmeticBinaryToMethodCall(BinaryExpr binaryExpr, Optional> originalPatternType, RuleContext ruleContext) { ConstraintCompiler constraintCompiler = createConstraintCompiler(ruleContext, originalPatternType); CompiledExpressionResult compiledExpressionResult = constraintCompiler.compileExpression(printNode(binaryExpr)); return compiledExpressionResult.getExpression(); @@ -828,8 +836,15 @@ private static Expression convertArithmeticBinaryToMethodCall(BinaryExpr binaryE /* * BigDecimal arithmetic operations should be converted to method calls. We may also apply this to BigInteger. */ - private static boolean shouldConvertArithmeticBinaryToMethodCall(BinaryExpr.Operator operator, java.lang.reflect.Type type) { - return isArithmeticOperator(operator) && type.equals(BigDecimal.class); + public static boolean shouldConvertArithmeticBinaryToMethodCall(BinaryExpr.Operator operator, java.lang.reflect.Type leftType, java.lang.reflect.Type rightType) { + return isArithmeticOperator(operator) && (leftType.equals(BigDecimal.class) || rightType.equals(BigDecimal.class)); + } + + /* + * After arithmetic to method call conversion, BigDecimal should take precedence regardless of left or right. We may also apply this to BigInteger. + */ + public static java.lang.reflect.Type getBinaryTypeAfterConversion(java.lang.reflect.Type leftType, java.lang.reflect.Type rightType) { + return (leftType.equals(BigDecimal.class) || rightType.equals(BigDecimal.class)) ? BigDecimal.class : leftType; } private java.lang.reflect.Type getBinaryType(TypedExpression leftTypedExpression, TypedExpression rightTypedExpression, Operator operator) { @@ -936,6 +951,9 @@ private void promoteBigDecimalParameters(MethodCallExpr methodCallExpr, Class[] Expression argumentExpression = methodCallExpr.getArgument(i); if (argumentType != actualArgumentType) { + // unbind the original argumentExpression first, otherwise setArgument() will remove the argumentExpression from coercedExpression.childrenNodes + // It will result in failing to find DrlNameExpr in AST at DrlsParseUtil.transformDrlNameExprToNameExpr() + methodCallExpr.replace(argumentExpression, new NameExpr("placeholder")); Expression coercedExpression = new BigDecimalArgumentCoercion().coercedArgument(argumentType, actualArgumentType, argumentExpression); methodCallExpr.setArgument(i, coercedExpression); } diff --git a/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/bigdecimaltest/BigDecimalTest.java b/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/bigdecimaltest/BigDecimalTest.java index 0cc90ca3179..407462de392 100644 --- a/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/bigdecimaltest/BigDecimalTest.java +++ b/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/bigdecimaltest/BigDecimalTest.java @@ -772,4 +772,45 @@ public void bigDecimalCoercionInNestedMethodArgument_shouldNotFailToBuild() { public static String intToString(int value) { return Integer.toString(value); } + + @Test + public void bindVariableToBigDecimalCoercion2Operands_shouldBindCorrectResult() { + bindVariableToBigDecimalCoercion("$var : (1000 * value1)"); + } + + @Test + public void bindVariableToBigDecimalCoercion3Operands_shouldBindCorrectResult() { + bindVariableToBigDecimalCoercion("$var : (100000 * value1 / 100)"); + } + + @Test + public void bindVariableToBigDecimalCoercion3OperandsWithParentheses_shouldBindCorrectResult() { + bindVariableToBigDecimalCoercion("$var : ((100000 * value1) / 100)"); + } + + private void bindVariableToBigDecimalCoercion(String binding) { + // KIE-775 + String str = + "package org.drools.modelcompiler.bigdecimals\n" + + "import " + BDFact.class.getCanonicalName() + ";\n" + + "global java.util.List result;\n" + + "rule R1\n" + + " when\n" + + " BDFact( " + binding + " )\n" + + " then\n" + + " result.add($var);\n" + + "end"; + + KieSession ksession = getKieSession(str); + List result = new ArrayList<>(); + ksession.setGlobal("result", result); + + BDFact bdFact = new BDFact(); + bdFact.setValue1(new BigDecimal("80")); + + ksession.insert(bdFact); + ksession.fireAllRules(); + + assertThat(result).contains(new BigDecimal("80000")); + } }