diff --git a/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/bigdecimaltest/BigDecimalTest.java b/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/bigdecimaltest/BigDecimalTest.java index 9f011cf353a..8c23ef59a58 100644 --- a/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/bigdecimaltest/BigDecimalTest.java +++ b/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/bigdecimaltest/BigDecimalTest.java @@ -492,4 +492,29 @@ public void testBigDecimalLiteralWithBinding() { Assertions.assertThat(result).containsExactly(new BigDecimal("0")); } + + @Test + public void testModifyWithNegativeBigDecimal() { + // DROOLS-7324 + String str = + "package org.drools.modelcompiler.bigdecimals\n" + + "import " + BdHolder.class.getCanonicalName() + ";\n" + + "global java.util.List result;\n" + + "rule R1 dialect \"mvel\" when\n" + + " $bd : BdHolder(bd1 > 5)\n" + + "then\n" + + " modify($bd) { bd1 = -1 }\n" + + "end"; + + KieSession ksession = getKieSession(str); + List result = new ArrayList<>(); + ksession.setGlobal("result", result); + + BdHolder holder = new BdHolder(); + holder.setBd1(new BigDecimal("10")); + ksession.insert(holder); + int fires = ksession.fireAllRules(); + + assertThat(fires).isEqualTo(1); + } } diff --git a/drools-model/drools-mvel-compiler/src/main/java/org/drools/mvelcompiler/ReProcessRHSPhase.java b/drools-model/drools-mvel-compiler/src/main/java/org/drools/mvelcompiler/ReProcessRHSPhase.java index 63f8fd8b14b..d036af6e30e 100644 --- a/drools-model/drools-mvel-compiler/src/main/java/org/drools/mvelcompiler/ReProcessRHSPhase.java +++ b/drools-model/drools-mvel-compiler/src/main/java/org/drools/mvelcompiler/ReProcessRHSPhase.java @@ -25,6 +25,7 @@ import com.github.javaparser.ast.expr.IntegerLiteralExpr; import com.github.javaparser.ast.expr.LongLiteralExpr; import com.github.javaparser.ast.expr.NameExpr; +import com.github.javaparser.ast.expr.UnaryExpr; import org.drools.mvel.parser.ast.visitor.DrlGenericVisitor; import org.drools.mvelcompiler.ast.BigDecimalConvertedExprT; import org.drools.mvelcompiler.ast.IntegerLiteralExpressionT; @@ -36,55 +37,73 @@ /** * Used when you need to reprocess the RHS after having processed the LHS */ -public class ReProcessRHSPhase implements DrlGenericVisitor, Void> { +public class ReProcessRHSPhase implements DrlGenericVisitor, ReProcessRHSPhase.Context> { private TypedExpression lhs; private MvelCompilerContext mvelCompilerContext; + static class Context { + private UnaryExpr unaryExpr; + + Context withUnaryExpr(UnaryExpr unaryExpr) { + this.unaryExpr = unaryExpr; + return this; + } + + Optional getUnaryExpr() { + return Optional.ofNullable(unaryExpr); + } + } + ReProcessRHSPhase(MvelCompilerContext mvelCompilerContext) { this.mvelCompilerContext = mvelCompilerContext; } public Optional invoke(TypedExpression rhs, TypedExpression lhs) { this.lhs = lhs; - return Optional.ofNullable(rhs).flatMap(r -> r.toJavaExpression().accept(this, null)); + return Optional.ofNullable(rhs).flatMap(r -> r.toJavaExpression().accept(this, new Context())); } @Override - public Optional defaultMethod(Node n, Void context) { + public Optional defaultMethod(Node n, ReProcessRHSPhase.Context context) { return Optional.empty(); } @Override - public Optional visit(BinaryExpr n, Void arg) { - return convertWhenLHSISBigDecimal(() -> new UnalteredTypedExpression(n)); + public Optional visit(UnaryExpr n, ReProcessRHSPhase.Context context) { + return n.getExpression().accept(this, context.withUnaryExpr(n)); + } + + @Override + public Optional visit(BinaryExpr n, ReProcessRHSPhase.Context context) { + return convertWhenLHSISBigDecimal(() -> new UnalteredTypedExpression(n), context); } @Override - public Optional visit(IntegerLiteralExpr n, Void arg) { - return convertWhenLHSISBigDecimal(() -> new IntegerLiteralExpressionT(n)); + public Optional visit(IntegerLiteralExpr n, ReProcessRHSPhase.Context context) { + return convertWhenLHSISBigDecimal(() -> new IntegerLiteralExpressionT(n), context); } @Override - public Optional visit(LongLiteralExpr n, Void arg) { - return convertWhenLHSISBigDecimal(() -> new LongLiteralExpressionT(n)); + public Optional visit(LongLiteralExpr n, ReProcessRHSPhase.Context context) { + return convertWhenLHSISBigDecimal(() -> new LongLiteralExpressionT(n), context); } @Override - public Optional visit(NameExpr n, Void arg) { + public Optional visit(NameExpr n, ReProcessRHSPhase.Context context) { if(mvelCompilerContext .findDeclarations(n.toString()) .filter(d -> d.getClazz() != BigDecimal.class) .isPresent()) { // avoid wrapping BigDecimal declarations - return convertWhenLHSISBigDecimal(() -> new UnalteredTypedExpression(n)); + return convertWhenLHSISBigDecimal(() -> new UnalteredTypedExpression(n), context); } else { return Optional.empty(); } } - private Optional convertWhenLHSISBigDecimal(Supplier conversionFunction) { + private Optional convertWhenLHSISBigDecimal(Supplier conversionFunction, ReProcessRHSPhase.Context context) { return lhs.getType() .filter(BigDecimal.class::equals) - .flatMap(t -> Optional.of(new BigDecimalConvertedExprT(conversionFunction.get()))); + .flatMap(t -> Optional.of(new BigDecimalConvertedExprT(conversionFunction.get(), context.getUnaryExpr()))); } } diff --git a/drools-model/drools-mvel-compiler/src/main/java/org/drools/mvelcompiler/ast/BigDecimalConvertedExprT.java b/drools-model/drools-mvel-compiler/src/main/java/org/drools/mvelcompiler/ast/BigDecimalConvertedExprT.java index 733e36f0a1f..057d09701dc 100644 --- a/drools-model/drools-mvel-compiler/src/main/java/org/drools/mvelcompiler/ast/BigDecimalConvertedExprT.java +++ b/drools-model/drools-mvel-compiler/src/main/java/org/drools/mvelcompiler/ast/BigDecimalConvertedExprT.java @@ -24,6 +24,7 @@ import com.github.javaparser.ast.Node; import com.github.javaparser.ast.expr.Expression; import com.github.javaparser.ast.expr.ObjectCreationExpr; +import com.github.javaparser.ast.expr.UnaryExpr; import static com.github.javaparser.ast.NodeList.nodeList; @@ -32,8 +33,15 @@ public class BigDecimalConvertedExprT implements TypedExpression { private final TypedExpression value; private final Type type = BigDecimal.class; + private final Optional unaryExpr; + public BigDecimalConvertedExprT(TypedExpression value) { + this(value, Optional.empty()); + } + + public BigDecimalConvertedExprT(TypedExpression value, Optional unaryExpr) { this.value = value; + this.unaryExpr = unaryExpr; } @Override @@ -43,10 +51,9 @@ public Optional getType() { @Override public Node toJavaExpression() { - - return new ObjectCreationExpr(null, - StaticJavaParser.parseClassOrInterfaceType(type.getTypeName()), - nodeList((Expression) value.toJavaExpression())); + Expression expr = (Expression) value.toJavaExpression(); + Expression arg = unaryExpr.map(u -> (Expression) new UnaryExpr(expr, u.getOperator())).orElse(expr); + return new ObjectCreationExpr(null, StaticJavaParser.parseClassOrInterfaceType(type.getTypeName()), nodeList(arg)); } @Override