Skip to content

Commit

Permalink
[DROOLS-7324] fix executable model generation when setting negative B… (
Browse files Browse the repository at this point in the history
apache#4958)

* [DROOLS-7324] fix executable model generation when setting negative BigDecimal literal value

* wip

* wip

(cherry picked from commit 2eed163)
  • Loading branch information
mariofusco committed Feb 16, 2023
1 parent b598eb8 commit 03d1f40
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -492,4 +492,29 @@ public void testBigDecimalLiteralWithBinding() {

Assertions.assertThat(result).containsExactly(new BigDecimal("0"));
}

@Test
public void testModifyWithNegativeBigDecimal() {
// DROOLS-7324
String str =
"package org.drools.modelcompiler.bigdecimals\n" +
"import " + BdHolder.class.getCanonicalName() + ";\n" +
"global java.util.List result;\n" +
"rule R1 dialect \"mvel\" when\n" +
" $bd : BdHolder(bd1 > 5)\n" +
"then\n" +
" modify($bd) { bd1 = -1 }\n" +
"end";

KieSession ksession = getKieSession(str);
List<BigDecimal> result = new ArrayList<>();
ksession.setGlobal("result", result);

BdHolder holder = new BdHolder();
holder.setBd1(new BigDecimal("10"));
ksession.insert(holder);
int fires = ksession.fireAllRules();

assertThat(fires).isEqualTo(1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.github.javaparser.ast.expr.IntegerLiteralExpr;
import com.github.javaparser.ast.expr.LongLiteralExpr;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.UnaryExpr;
import org.drools.mvel.parser.ast.visitor.DrlGenericVisitor;
import org.drools.mvelcompiler.ast.BigDecimalConvertedExprT;
import org.drools.mvelcompiler.ast.IntegerLiteralExpressionT;
Expand All @@ -36,55 +37,73 @@
/**
* Used when you need to reprocess the RHS after having processed the LHS
*/
public class ReProcessRHSPhase implements DrlGenericVisitor<Optional<TypedExpression>, Void> {
public class ReProcessRHSPhase implements DrlGenericVisitor<Optional<TypedExpression>, ReProcessRHSPhase.Context> {

private TypedExpression lhs;
private MvelCompilerContext mvelCompilerContext;

static class Context {
private UnaryExpr unaryExpr;

Context withUnaryExpr(UnaryExpr unaryExpr) {
this.unaryExpr = unaryExpr;
return this;
}

Optional<UnaryExpr> getUnaryExpr() {
return Optional.ofNullable(unaryExpr);
}
}

ReProcessRHSPhase(MvelCompilerContext mvelCompilerContext) {
this.mvelCompilerContext = mvelCompilerContext;
}

public Optional<TypedExpression> invoke(TypedExpression rhs, TypedExpression lhs) {
this.lhs = lhs;
return Optional.ofNullable(rhs).flatMap(r -> r.toJavaExpression().accept(this, null));
return Optional.ofNullable(rhs).flatMap(r -> r.toJavaExpression().accept(this, new Context()));
}

@Override
public Optional<TypedExpression> defaultMethod(Node n, Void context) {
public Optional<TypedExpression> defaultMethod(Node n, ReProcessRHSPhase.Context context) {
return Optional.empty();
}

@Override
public Optional<TypedExpression> visit(BinaryExpr n, Void arg) {
return convertWhenLHSISBigDecimal(() -> new UnalteredTypedExpression(n));
public Optional<TypedExpression> visit(UnaryExpr n, ReProcessRHSPhase.Context context) {
return n.getExpression().accept(this, context.withUnaryExpr(n));
}

@Override
public Optional<TypedExpression> visit(BinaryExpr n, ReProcessRHSPhase.Context context) {
return convertWhenLHSISBigDecimal(() -> new UnalteredTypedExpression(n), context);
}

@Override
public Optional<TypedExpression> visit(IntegerLiteralExpr n, Void arg) {
return convertWhenLHSISBigDecimal(() -> new IntegerLiteralExpressionT(n));
public Optional<TypedExpression> visit(IntegerLiteralExpr n, ReProcessRHSPhase.Context context) {
return convertWhenLHSISBigDecimal(() -> new IntegerLiteralExpressionT(n), context);
}

@Override
public Optional<TypedExpression> visit(LongLiteralExpr n, Void arg) {
return convertWhenLHSISBigDecimal(() -> new LongLiteralExpressionT(n));
public Optional<TypedExpression> visit(LongLiteralExpr n, ReProcessRHSPhase.Context context) {
return convertWhenLHSISBigDecimal(() -> new LongLiteralExpressionT(n), context);
}

@Override
public Optional<TypedExpression> visit(NameExpr n, Void arg) {
public Optional<TypedExpression> visit(NameExpr n, ReProcessRHSPhase.Context context) {
if(mvelCompilerContext
.findDeclarations(n.toString())
.filter(d -> d.getClazz() != BigDecimal.class)
.isPresent()) { // avoid wrapping BigDecimal declarations
return convertWhenLHSISBigDecimal(() -> new UnalteredTypedExpression(n));
return convertWhenLHSISBigDecimal(() -> new UnalteredTypedExpression(n), context);
} else {
return Optional.empty();
}
}

private Optional<TypedExpression> convertWhenLHSISBigDecimal(Supplier<TypedExpression> conversionFunction) {
private Optional<TypedExpression> convertWhenLHSISBigDecimal(Supplier<TypedExpression> conversionFunction, ReProcessRHSPhase.Context context) {
return lhs.getType()
.filter(BigDecimal.class::equals)
.flatMap(t -> Optional.of(new BigDecimalConvertedExprT(conversionFunction.get())));
.flatMap(t -> Optional.of(new BigDecimalConvertedExprT(conversionFunction.get(), context.getUnaryExpr())));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.ObjectCreationExpr;
import com.github.javaparser.ast.expr.UnaryExpr;

import static com.github.javaparser.ast.NodeList.nodeList;

Expand All @@ -32,8 +33,15 @@ public class BigDecimalConvertedExprT implements TypedExpression {
private final TypedExpression value;
private final Type type = BigDecimal.class;

private final Optional<UnaryExpr> unaryExpr;

public BigDecimalConvertedExprT(TypedExpression value) {
this(value, Optional.empty());
}

public BigDecimalConvertedExprT(TypedExpression value, Optional<UnaryExpr> unaryExpr) {
this.value = value;
this.unaryExpr = unaryExpr;
}

@Override
Expand All @@ -43,10 +51,9 @@ public Optional<Type> getType() {

@Override
public Node toJavaExpression() {

return new ObjectCreationExpr(null,
StaticJavaParser.parseClassOrInterfaceType(type.getTypeName()),
nodeList((Expression) value.toJavaExpression()));
Expression expr = (Expression) value.toJavaExpression();
Expression arg = unaryExpr.map(u -> (Expression) new UnaryExpr(expr, u.getOperator())).orElse(expr);
return new ObjectCreationExpr(null, StaticJavaParser.parseClassOrInterfaceType(type.getTypeName()), nodeList(arg));
}

@Override
Expand Down

0 comments on commit 03d1f40

Please sign in to comment.