Skip to content

Commit

Permalink
Add an option for controlling short-circuiting behavior for logical o…
Browse files Browse the repository at this point in the history
…perators.

PiperOrigin-RevId: 613276249
  • Loading branch information
l46kok authored and copybara-github committed Mar 6, 2024
1 parent 366e749 commit bb64ec7
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 12 deletions.
14 changes: 14 additions & 0 deletions common/src/main/java/dev/cel/common/CelOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ public abstract class CelOptions {

public abstract boolean disableCelStandardEquality();

public abstract boolean enableShortCircuiting();

public abstract boolean enableRegexPartialMatch();

public abstract boolean enableUnsignedComparisonAndArithmeticIsUnsigned();
Expand Down Expand Up @@ -170,6 +172,7 @@ public static Builder newBuilder() {
.enableNamespacedDeclarations(true)
// Evaluation options
.disableCelStandardEquality(true)
.enableShortCircuiting(true)
.enableRegexPartialMatch(false)
.enableUnsignedComparisonAndArithmeticIsUnsigned(false)
.enableUnsignedLongs(false)
Expand Down Expand Up @@ -364,6 +367,17 @@ public abstract static class Builder {
*/
public abstract Builder disableCelStandardEquality(boolean value);

/**
* Enable short-circuiting of the logical operator evaluation. If enabled, AND, OR, and TERNARY
* do not evaluate the entire expression once the resulting value is known from the left-hand
* side.
*
* <p>This option is enabled by default. In most cases, this should not be disabled except for
* debugging purposes or collecting results for all evaluated branches through {@link
* dev.cel.runtime.CelEvaluationListener}.
*/
public abstract Builder enableShortCircuiting(boolean value);

/**
* Treat regex {@code matches} calls as substring (unanchored) match patterns.
*
Expand Down
47 changes: 37 additions & 10 deletions runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java
Original file line number Diff line number Diff line change
Expand Up @@ -455,13 +455,22 @@ private Optional<CelAttribute> maybeContainerIndexAttribute(
private IntermediateResult evalConditional(ExecutionFrame frame, CelCall callExpr)
throws InterpreterException {
IntermediateResult condition = evalBooleanStrict(frame, callExpr.args().get(0));
if (isUnknownValue(condition.value())) {
return condition;
}
if ((boolean) condition.value()) {
return evalInternal(frame, callExpr.args().get(1));
if (celOptions.enableShortCircuiting()) {
if (isUnknownValue(condition.value())) {
return condition;
}
if ((boolean) condition.value()) {
return evalInternal(frame, callExpr.args().get(1));
}
return evalInternal(frame, callExpr.args().get(2));
} else {
IntermediateResult lhs = evalInternal(frame, callExpr.args().get(1));
IntermediateResult rhs = evalInternal(frame, callExpr.args().get(2));
if (isUnknownValue(condition.value())) {
return condition;
}
return (boolean) condition.value() ? lhs : rhs;
}
return evalInternal(frame, callExpr.args().get(2));
}

private IntermediateResult mergeBooleanUnknowns(IntermediateResult lhs, IntermediateResult rhs)
Expand All @@ -482,15 +491,33 @@ private IntermediateResult mergeBooleanUnknowns(IntermediateResult lhs, Intermed
InterpreterUtil.shortcircuitUnknownOrThrowable(lhs.value(), rhs.value()));
}

private enum ShortCircuitableOperators {
LOGICAL_OR,
LOGICAL_AND
}

private boolean canShortCircuit(IntermediateResult result, ShortCircuitableOperators operator) {
if (!celOptions.enableShortCircuiting() || !(result.value() instanceof Boolean)) {
return false;
}

Boolean value = (Boolean) result.value();
if (value && operator.equals(ShortCircuitableOperators.LOGICAL_OR)) {
return true;
}

return !value && operator.equals(ShortCircuitableOperators.LOGICAL_AND);
}

private IntermediateResult evalLogicalOr(ExecutionFrame frame, CelCall callExpr)
throws InterpreterException {
IntermediateResult left = evalBooleanNonstrict(frame, callExpr.args().get(0));
if (left.value() instanceof Boolean && (Boolean) left.value()) {
if (canShortCircuit(left, ShortCircuitableOperators.LOGICAL_OR)) {
return left;
}

IntermediateResult right = evalBooleanNonstrict(frame, callExpr.args().get(1));
if (right.value() instanceof Boolean && (Boolean) right.value()) {
if (canShortCircuit(right, ShortCircuitableOperators.LOGICAL_OR)) {
return right;
}

Expand All @@ -505,12 +532,12 @@ private IntermediateResult evalLogicalOr(ExecutionFrame frame, CelCall callExpr)
private IntermediateResult evalLogicalAnd(ExecutionFrame frame, CelCall callExpr)
throws InterpreterException {
IntermediateResult left = evalBooleanNonstrict(frame, callExpr.args().get(0));
if (left.value() instanceof Boolean && !((Boolean) left.value())) {
if (canShortCircuit(left, ShortCircuitableOperators.LOGICAL_AND)) {
return left;
}

IntermediateResult right = evalBooleanNonstrict(frame, callExpr.args().get(1));
if (right.value() instanceof Boolean && !((Boolean) right.value())) {
if (canShortCircuit(right, ShortCircuitableOperators.LOGICAL_AND)) {
return right;
}

Expand Down
141 changes: 139 additions & 2 deletions runtime/src/test/java/dev/cel/runtime/CelRuntimeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
import com.google.api.expr.v1alpha1.Constant;
import com.google.api.expr.v1alpha1.Expr;
import com.google.api.expr.v1alpha1.Type.PrimitiveType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.protobuf.Any;
import com.google.protobuf.BoolValue;
import com.google.protobuf.ByteString;
import com.google.protobuf.DescriptorProtos.FileDescriptorSet;
import com.google.protobuf.DynamicMessage;
import com.google.rpc.context.AttributeContext;
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
import com.google.testing.junit.testparameterinjector.TestParameters;
import dev.cel.bundle.Cel;
import dev.cel.bundle.CelFactory;
import dev.cel.common.CelAbstractSyntaxTree;
Expand All @@ -50,9 +53,8 @@
import java.util.concurrent.atomic.AtomicReference;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
@RunWith(TestParameterInjector.class)
public class CelRuntimeTest {

@Test
Expand Down Expand Up @@ -400,6 +402,141 @@ public void trace_withVariableResolver() throws Exception {
assertThat(result).isEqualTo("hello");
}

@Test
public void trace_shortCircuitingDisabled_logicalAndAllBranchesVisited() throws Exception {
ImmutableList.Builder<Boolean> branchResults = ImmutableList.builder();
CelEvaluationListener listener =
(expr, res) -> {
if (expr.constantOrDefault().getKind().equals(CelConstant.Kind.BOOLEAN_VALUE)) {
branchResults.add((Boolean) res);
}
};
Cel cel =
CelFactory.standardCelBuilder()
.setOptions(CelOptions.current().enableShortCircuiting(false).build())
.build();
CelAbstractSyntaxTree ast = cel.compile("false && true && false").getAst();

boolean result = (boolean) cel.createProgram(ast).trace(listener);

assertThat(result).isFalse();
assertThat(branchResults.build()).containsExactly(false, true, false);
}

@Test
public void trace_shortCircuitingDisabled_logicalAndWithUnknowns() throws Exception {
ImmutableList.Builder<Object> branchResults = ImmutableList.builder();
CelEvaluationListener listener =
(expr, res) -> {
if (expr.constantOrDefault().getKind().equals(CelConstant.Kind.BOOLEAN_VALUE)
|| expr.identOrDefault().name().equals("x")) {
branchResults.add(res);
}
};
Cel cel =
CelFactory.standardCelBuilder()
.addVar("x", SimpleType.BOOL)
.setOptions(CelOptions.current().enableShortCircuiting(false).build())
.build();
CelAbstractSyntaxTree ast = cel.compile("false && false && x").getAst();

Object unknownResult = cel.createProgram(ast).trace(listener);

assertThat(InterpreterUtil.isUnknown(unknownResult)).isTrue();
assertThat(branchResults.build()).containsExactly(false, false, unknownResult);
}

@Test
public void trace_shortCircuitingDisabled_logicalOrAllBranchesVisited() throws Exception {
ImmutableList.Builder<Boolean> branchResults = ImmutableList.builder();
CelEvaluationListener listener =
(expr, res) -> {
if (expr.constantOrDefault().getKind().equals(CelConstant.Kind.BOOLEAN_VALUE)) {
branchResults.add((Boolean) res);
}
};
Cel cel =
CelFactory.standardCelBuilder()
.setOptions(CelOptions.current().enableShortCircuiting(false).build())
.build();
CelAbstractSyntaxTree ast = cel.compile("true || false || true").getAst();

boolean result = (boolean) cel.createProgram(ast).trace(listener);

assertThat(result).isTrue();
assertThat(branchResults.build()).containsExactly(true, false, true);
}

@Test
public void trace_shortCircuitingDisabled_logicalOrWithUnknowns() throws Exception {
ImmutableList.Builder<Object> branchResults = ImmutableList.builder();
CelEvaluationListener listener =
(expr, res) -> {
if (expr.constantOrDefault().getKind().equals(CelConstant.Kind.BOOLEAN_VALUE)
|| expr.identOrDefault().name().equals("x")) {
branchResults.add(res);
}
};
Cel cel =
CelFactory.standardCelBuilder()
.addVar("x", SimpleType.BOOL)
.setOptions(CelOptions.current().enableShortCircuiting(false).build())
.build();
CelAbstractSyntaxTree ast = cel.compile("false || false || x").getAst();

Object unknownResult = cel.createProgram(ast).trace(listener);

assertThat(InterpreterUtil.isUnknown(unknownResult)).isTrue();
assertThat(branchResults.build()).containsExactly(false, false, unknownResult);
}

@Test
public void trace_shortCircuitingDisabled_ternaryAllBranchesVisited() throws Exception {
ImmutableList.Builder<Boolean> branchResults = ImmutableList.builder();
CelEvaluationListener listener =
(expr, res) -> {
if (expr.constantOrDefault().getKind().equals(CelConstant.Kind.BOOLEAN_VALUE)) {
branchResults.add((Boolean) res);
}
};
Cel cel =
CelFactory.standardCelBuilder()
.setOptions(CelOptions.current().enableShortCircuiting(false).build())
.build();
CelAbstractSyntaxTree ast = cel.compile("true ? false : true").getAst();

boolean result = (boolean) cel.createProgram(ast).trace(listener);

assertThat(result).isFalse();
assertThat(branchResults.build()).containsExactly(true, false, true);
}

@Test
@TestParameters("{source: 'false ? true : x'}")
@TestParameters("{source: 'true ? x : false'}")
@TestParameters("{source: 'x ? true : false'}")
public void trace_shortCircuitingDisabled_ternaryWithUnknowns(String source) throws Exception {
ImmutableList.Builder<Object> branchResults = ImmutableList.builder();
CelEvaluationListener listener =
(expr, res) -> {
if (expr.constantOrDefault().getKind().equals(CelConstant.Kind.BOOLEAN_VALUE)
|| expr.identOrDefault().name().equals("x")) {
branchResults.add(res);
}
};
Cel cel =
CelFactory.standardCelBuilder()
.addVar("x", SimpleType.BOOL)
.setOptions(CelOptions.current().enableShortCircuiting(false).build())
.build();
CelAbstractSyntaxTree ast = cel.compile(source).getAst();

Object unknownResult = cel.createProgram(ast).trace(listener);

assertThat(InterpreterUtil.isUnknown(unknownResult)).isTrue();
assertThat(branchResults.build()).containsExactly(false, unknownResult, true);
}

@Test
public void standardEnvironmentDisabledForRuntime_throws() throws Exception {
CelCompiler celCompiler =
Expand Down

0 comments on commit bb64ec7

Please sign in to comment.