Skip to content

Commit

Permalink
[KIE-775] drools executable-model fails with a bind variable to a cal…
Browse files Browse the repository at this point in the history
…culation result of int and BigDecimal (apache#5636)
  • Loading branch information
tkobayas authored and rgdoliveira committed Jan 16, 2024
1 parent aa6b676 commit 00e99f1
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ public ArithmeticCoercedExpression(TypedExpression left, TypedExpression right,
this.operator = operator;
}

/*
* This coercion only deals with String vs Numeric types.
* BigDecimal arithmetic operation is handled by ExpressionTyper.convertArithmeticBinaryToMethodCall()
*/
public ArithmeticCoercedExpressionResult coerce() {

if (!requiresCoercion()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@
import static org.drools.model.codegen.execmodel.generator.ConstraintUtil.GREATER_THAN_PREFIX;
import static org.drools.model.codegen.execmodel.generator.ConstraintUtil.LESS_OR_EQUAL_PREFIX;
import static org.drools.model.codegen.execmodel.generator.ConstraintUtil.LESS_THAN_PREFIX;
import static org.drools.model.codegen.execmodel.generator.expressiontyper.ExpressionTyper.convertArithmeticBinaryToMethodCall;
import static org.drools.model.codegen.execmodel.generator.expressiontyper.ExpressionTyper.getBinaryTypeAfterConversion;
import static org.drools.model.codegen.execmodel.generator.expressiontyper.ExpressionTyper.shouldConvertArithmeticBinaryToMethodCall;
import static org.drools.util.StringUtils.lcFirstForBean;
import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.THIS_PLACEHOLDER;
import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.createConstraintCompiler;
Expand Down Expand Up @@ -197,11 +200,23 @@ private void logWarnIfNoReactOnCausedByVariableFromDifferentPattern(DrlxParseRes
}

private void addDeclaration(DrlxExpression drlx, SingleDrlxParseSuccess singleResult, String bindId) {
DeclarationSpec decl = context.addDeclaration( bindId, singleResult.getLeftExprTypeBeforeCoercion() );
DeclarationSpec decl = context.addDeclaration(bindId, getDeclarationType(drlx, singleResult));
if (drlx.getExpr() instanceof NameExpr) {
decl.setBoundVariable( PrintUtil.printNode(drlx.getExpr()) );
} else if (drlx.getExpr() instanceof BinaryExpr) {
decl.setBoundVariable(PrintUtil.printNode(drlx.getExpr().asBinaryExpr().getLeft()));
Expression leftMostExpression = getLeftMostExpression(drlx.getExpr().asBinaryExpr());
decl.setBoundVariable(PrintUtil.printNode(leftMostExpression));
if (singleResult.getExpr() instanceof MethodCallExpr) {
// BinaryExpr was converted to MethodCallExpr. Create a TypedExpression for the leftmost expression of the BinaryExpr
ExpressionTyperContext expressionTyperContext = new ExpressionTyperContext();
ExpressionTyper expressionTyper = new ExpressionTyper(context, singleResult.getPatternType(), bindId, false, expressionTyperContext);
TypedExpressionResult leftTypedExpressionResult = expressionTyper.toTypedExpression(leftMostExpression);
Optional<TypedExpression> optLeft = leftTypedExpressionResult.getTypedExpression();
if (!optLeft.isPresent()) {
throw new IllegalStateException("Cannot create TypedExpression for " + drlx.getExpr().asBinaryExpr().getLeft());
}
singleResult.setBoundExpr(optLeft.get());
}
}
decl.setBelongingPatternDescr(context.getCurrentPatternDescr());
singleResult.setExprBinding( bindId );
Expand All @@ -211,6 +226,24 @@ private void addDeclaration(DrlxExpression drlx, SingleDrlxParseSuccess singleRe
}
}

private static Class<?> getDeclarationType(DrlxExpression drlx, SingleDrlxParseSuccess singleResult) {
if (drlx.getBind() != null && drlx.getExpr() instanceof EnclosedExpr) {
// in case of enclosed, bind type should be the calculation result type
// If drlx.getBind() == null, a bind variable is inside the enclosed expression, leave it to the default behavior
return (Class<?>)singleResult.getExprType();
} else {
return singleResult.getLeftExprTypeBeforeCoercion();
}
}

private Expression getLeftMostExpression(BinaryExpr binaryExpr) {
Expression left = binaryExpr.getLeft();
if (left instanceof BinaryExpr) {
return getLeftMostExpression((BinaryExpr) left);
}
return left;
}

/*
This is the entry point for Constraint Transformation from a parsed MVEL constraint
to a Java Expression
Expand Down Expand Up @@ -657,17 +690,16 @@ private DrlxParseResult parseBinaryExpr(BinaryExpr binaryExpr, Class<?> patternT

Expression combo;

boolean arithmeticExpr = isArithmeticOperator(operator);
boolean isBetaConstraint = right.getExpression() != null && hasDeclarationFromOtherPattern( expressionTyperContext );
boolean requiresSplit = operator == BinaryExpr.Operator.AND && binaryExpr.getRight() instanceof HalfBinaryExpr && !isBetaConstraint;

Type exprType = isBooleanOperator( operator ) ? boolean.class : left.getType();

if (equalityExpr) {
combo = getEqualityExpression( left, right, operator ).expression;
} else if (arithmeticExpr && (left.isBigDecimal())) {
ConstraintCompiler constraintCompiler = createConstraintCompiler(this.context, of(patternType));
CompiledExpressionResult compiledExpressionResult = constraintCompiler.compileExpression(binaryExpr);

combo = compiledExpressionResult.getExpression();
} else if (shouldConvertArithmeticBinaryToMethodCall(operator, left.getType(), right.getType())) {
combo = convertArithmeticBinaryToMethodCall(binaryExpr, of(patternType), this.context);
exprType = getBinaryTypeAfterConversion(left.getType(), right.getType());
} else {
if (left.getExpression() == null || right.getExpression() == null) {
return new DrlxParseFail(new ParseExpressionErrorResult(drlxExpr));
Expand Down Expand Up @@ -695,7 +727,7 @@ private DrlxParseResult parseBinaryExpr(BinaryExpr binaryExpr, Class<?> patternT
constraintType = Index.ConstraintType.FORALL_SELF_JOIN;
}

return new SingleDrlxParseSuccess(patternType, bindingId, combo, isBooleanOperator( operator ) ? boolean.class : left.getType())
return new SingleDrlxParseSuccess(patternType, bindingId, combo, exprType)
.setDecodeConstraintType( constraintType )
.setUsedDeclarations( expressionTyperContext.getUsedDeclarations() )
.setUsedDeclarationsOnLeft( usedDeclarationsOnLeft )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,14 @@ private Optional<TypedExpression> toTypedExpressionRec(Expression drlxExpr) {
right = coerced.getCoercedRight();

final BinaryExpr combo = new BinaryExpr(left.getExpression(), right.getExpression(), operator);
return of(new TypedExpression(combo, left.getType()));

if (shouldConvertArithmeticBinaryToMethodCall(operator, left.getType(), right.getType())) {
Expression expression = convertArithmeticBinaryToMethodCall(combo, of(typeCursor), ruleContext);
java.lang.reflect.Type binaryType = getBinaryTypeAfterConversion(left.getType(), right.getType());
return of(new TypedExpression(expression, binaryType));
} else {
return of(new TypedExpression(combo, left.getType()));
}
}

if (drlxExpr instanceof HalfBinaryExpr) {
Expand Down Expand Up @@ -806,11 +813,12 @@ private TypedExpressionCursor binaryExpr(BinaryExpr binaryExpr) {
TypedExpression rightTypedExpression = right.getTypedExpression()
.orElseThrow(() -> new NoSuchElementException("TypedExpressionResult doesn't contain TypedExpression!"));
binaryExpr.setRight(rightTypedExpression.getExpression());
java.lang.reflect.Type binaryType = getBinaryType(leftTypedExpression, rightTypedExpression, binaryExpr.getOperator());
if (shouldConvertArithmeticBinaryToMethodCall(binaryExpr.getOperator(), binaryType)) {
if (shouldConvertArithmeticBinaryToMethodCall(binaryExpr.getOperator(), leftTypedExpression.getType(), rightTypedExpression.getType())) {
Expression compiledExpression = convertArithmeticBinaryToMethodCall(binaryExpr, leftTypedExpression.getOriginalPatternType(), ruleContext);
java.lang.reflect.Type binaryType = getBinaryTypeAfterConversion(leftTypedExpression.getType(), rightTypedExpression.getType());
return new TypedExpressionCursor(compiledExpression, binaryType);
} else {
java.lang.reflect.Type binaryType = getBinaryType(leftTypedExpression, rightTypedExpression, binaryExpr.getOperator());
return new TypedExpressionCursor(binaryExpr, binaryType);
}
}
Expand All @@ -819,7 +827,7 @@ private TypedExpressionCursor binaryExpr(BinaryExpr binaryExpr) {
* Converts arithmetic binary expression (including coercion) to method call using ConstraintCompiler.
* This method can be generic, so we may centralize the calls in drools-model
*/
private static Expression convertArithmeticBinaryToMethodCall(BinaryExpr binaryExpr, Optional<Class<?>> originalPatternType, RuleContext ruleContext) {
public static Expression convertArithmeticBinaryToMethodCall(BinaryExpr binaryExpr, Optional<Class<?>> originalPatternType, RuleContext ruleContext) {
ConstraintCompiler constraintCompiler = createConstraintCompiler(ruleContext, originalPatternType);
CompiledExpressionResult compiledExpressionResult = constraintCompiler.compileExpression(printNode(binaryExpr));
return compiledExpressionResult.getExpression();
Expand All @@ -828,8 +836,15 @@ private static Expression convertArithmeticBinaryToMethodCall(BinaryExpr binaryE
/*
* BigDecimal arithmetic operations should be converted to method calls. We may also apply this to BigInteger.
*/
private static boolean shouldConvertArithmeticBinaryToMethodCall(BinaryExpr.Operator operator, java.lang.reflect.Type type) {
return isArithmeticOperator(operator) && type.equals(BigDecimal.class);
public static boolean shouldConvertArithmeticBinaryToMethodCall(BinaryExpr.Operator operator, java.lang.reflect.Type leftType, java.lang.reflect.Type rightType) {
return isArithmeticOperator(operator) && (leftType.equals(BigDecimal.class) || rightType.equals(BigDecimal.class));
}

/*
* After arithmetic to method call conversion, BigDecimal should take precedence regardless of left or right. We may also apply this to BigInteger.
*/
public static java.lang.reflect.Type getBinaryTypeAfterConversion(java.lang.reflect.Type leftType, java.lang.reflect.Type rightType) {
return (leftType.equals(BigDecimal.class) || rightType.equals(BigDecimal.class)) ? BigDecimal.class : leftType;
}

private java.lang.reflect.Type getBinaryType(TypedExpression leftTypedExpression, TypedExpression rightTypedExpression, Operator operator) {
Expand Down Expand Up @@ -936,6 +951,9 @@ private void promoteBigDecimalParameters(MethodCallExpr methodCallExpr, Class[]
Expression argumentExpression = methodCallExpr.getArgument(i);

if (argumentType != actualArgumentType) {
// unbind the original argumentExpression first, otherwise setArgument() will remove the argumentExpression from coercedExpression.childrenNodes
// It will result in failing to find DrlNameExpr in AST at DrlsParseUtil.transformDrlNameExprToNameExpr()
methodCallExpr.replace(argumentExpression, new NameExpr("placeholder"));
Expression coercedExpression = new BigDecimalArgumentCoercion().coercedArgument(argumentType, actualArgumentType, argumentExpression);
methodCallExpr.setArgument(i, coercedExpression);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -772,4 +772,45 @@ public void bigDecimalCoercionInNestedMethodArgument_shouldNotFailToBuild() {
public static String intToString(int value) {
return Integer.toString(value);
}

@Test
public void bindVariableToBigDecimalCoercion2Operands_shouldBindCorrectResult() {
bindVariableToBigDecimalCoercion("$var : (1000 * value1)");
}

@Test
public void bindVariableToBigDecimalCoercion3Operands_shouldBindCorrectResult() {
bindVariableToBigDecimalCoercion("$var : (100000 * value1 / 100)");
}

@Test
public void bindVariableToBigDecimalCoercion3OperandsWithParentheses_shouldBindCorrectResult() {
bindVariableToBigDecimalCoercion("$var : ((100000 * value1) / 100)");
}

private void bindVariableToBigDecimalCoercion(String binding) {
// KIE-775
String str =
"package org.drools.modelcompiler.bigdecimals\n" +
"import " + BDFact.class.getCanonicalName() + ";\n" +
"global java.util.List result;\n" +
"rule R1\n" +
" when\n" +
" BDFact( " + binding + " )\n" +
" then\n" +
" result.add($var);\n" +
"end";

KieSession ksession = getKieSession(str);
List<BigDecimal> result = new ArrayList<>();
ksession.setGlobal("result", result);

BDFact bdFact = new BDFact();
bdFact.setValue1(new BigDecimal("80"));

ksession.insert(bdFact);
ksession.fireAllRules();

assertThat(result).contains(new BigDecimal("80000"));
}
}

0 comments on commit 00e99f1

Please sign in to comment.