Skip to content

Commit

Permalink
Add baseline tests for constant folding applied before subexpression …
Browse files Browse the repository at this point in the history
…optimization

PiperOrigin-RevId: 610857243
  • Loading branch information
l46kok authored and copybara-github committed Mar 5, 2024
1 parent 8ca4ed4 commit 9599835
Show file tree
Hide file tree
Showing 18 changed files with 12,156 additions and 13,676 deletions.
15 changes: 12 additions & 3 deletions optimizer/src/main/java/dev/cel/optimizer/MutableAst.java
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,11 @@ public CelAbstractSyntaxTree replaceSubtree(
CelAbstractSyntaxTree ast, CelExpr newExpr, long exprIdToReplace) {
return replaceSubtreeWithNewAst(
ast,
CelAbstractSyntaxTree.newParsedAst(newExpr, CelSource.newBuilder().build()),
CelAbstractSyntaxTree.newParsedAst(
newExpr,
// Copy the macro call information to the new AST such that macro call map can be
// normalized post-replacement.
CelSource.newBuilder().addAllMacroCalls(ast.getSource().getMacroCalls()).build()),
exprIdToReplace);
}

Expand Down Expand Up @@ -571,7 +575,11 @@ private static CelSource combine(CelSource celSource1, CelSource celSource2) {
macroMap.putAll(celSource1.getMacroCalls());
macroMap.putAll(celSource2.getMacroCalls());

return CelSource.newBuilder().addAllMacroCalls(macroMap.buildOrThrow()).build();
return CelSource.newBuilder()
.addAllExtensions(celSource1.getExtensions())
.addAllExtensions(celSource2.getExtensions())
.addAllMacroCalls(macroMap.buildOrThrow())
.build();
}

/**
Expand All @@ -589,7 +597,8 @@ private CelAbstractSyntaxTree stabilizeAst(CelAbstractSyntaxTree ast, long seedE
return CelAbstractSyntaxTree.newParsedAst(newExprBuilder.build(), ast.getSource());
}

CelSource.Builder sourceBuilder = CelSource.newBuilder();
CelSource.Builder sourceBuilder =
CelSource.newBuilder().addAllExtensions(ast.getSource().getExtensions());
// Update the macro call IDs and their call IDs
for (Entry<Long, CelExpr> macroCall : ast.getSource().getMacroCalls().entrySet()) {
long macroId = macroCall.getKey();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,21 @@ public OptimizationResult optimize(CelNavigableAst navigableAst, Cel cel)
if (iterCount >= constantFoldingOptions.maxIterationLimit()) {
throw new IllegalStateException("Max iteration count reached.");
}
Optional<CelExpr> foldableExpr =
Optional<CelNavigableExpr> foldableExpr =
navigableAst
.getRoot()
.allNodes()
.filter(ConstantFoldingOptimizer::canFold)
.map(CelNavigableExpr::expr)
.filter(expr -> !visitedExprs.contains(expr))
.filter(node -> !visitedExprs.contains(node.expr()))
.findAny();
if (!foldableExpr.isPresent()) {
break;
}
visitedExprs.add(foldableExpr.get());
visitedExprs.add(foldableExpr.get().expr());

Optional<CelAbstractSyntaxTree> mutatedAst;
// Attempt to prune if it is a non-strict call
mutatedAst = maybePruneBranches(navigableAst.getAst(), foldableExpr.get());
mutatedAst = maybePruneBranches(navigableAst.getAst(), foldableExpr.get().expr());
if (!mutatedAst.isPresent()) {
// Evaluate the call then fold
mutatedAst = maybeFold(cel, navigableAst.getAst(), foldableExpr.get());
Expand Down Expand Up @@ -150,7 +149,7 @@ private static boolean canFold(CelNavigableExpr navigableExpr) {
}

if (functionName.equals(Operator.IN.getFunction())) {
return true;
return canFoldInOperator(navigableExpr);
}

// Default case: all call arguments must be constants. If the argument is a container (ex:
Expand All @@ -166,6 +165,30 @@ private static boolean canFold(CelNavigableExpr navigableExpr) {
}
}

private static boolean canFoldInOperator(CelNavigableExpr navigableExpr) {
ImmutableList<CelNavigableExpr> allIdents =
navigableExpr
.allNodes()
.filter(node -> node.getKind().equals(Kind.IDENT))
.collect(toImmutableList());
for (CelNavigableExpr identNode : allIdents) {
CelNavigableExpr parent = identNode.parent().orElse(null);
while (parent != null) {
if (parent.getKind().equals(Kind.COMPREHENSION)) {
if (parent.expr().comprehension().accuVar().equals(identNode.expr().ident().name())) {
// Prevent folding a subexpression if it contains a variable declared by a
// comprehension. The subexpression cannot be compiled without the full context of the
// surrounding comprehension.
return false;
}
}
parent = parent.parent().orElse(null);
}
}

return true;
}

private static boolean areChildrenArgConstant(CelNavigableExpr expr) {
if (expr.getKind().equals(Kind.CONSTANT)) {
return true;
Expand Down Expand Up @@ -195,10 +218,10 @@ private static boolean isNestedComprehension(CelNavigableExpr expr) {
}

private Optional<CelAbstractSyntaxTree> maybeFold(
Cel cel, CelAbstractSyntaxTree ast, CelExpr expr) throws CelOptimizationException {
Cel cel, CelAbstractSyntaxTree ast, CelNavigableExpr node) throws CelOptimizationException {
Object result;
try {
result = CelExprUtil.evaluateExpr(cel, expr);
result = CelExprUtil.evaluateExpr(cel, node.expr());
} catch (CelValidationException | CelEvaluationException e) {
throw new CelOptimizationException(
"Constant folding failure. Failed to evaluate subtree due to: " + e.getMessage(), e);
Expand All @@ -209,11 +232,11 @@ private Optional<CelAbstractSyntaxTree> maybeFold(
// ex2: optional.ofNonZeroValue(5) -> optional.of(5)
if (result instanceof Optional<?>) {
Optional<?> optResult = ((Optional<?>) result);
return maybeRewriteOptional(optResult, ast, expr);
return maybeRewriteOptional(optResult, ast, node.expr());
}

return maybeAdaptEvaluatedResult(result)
.map(celExpr -> mutableAst.replaceSubtree(ast, celExpr, expr.id()));
.map(celExpr -> mutableAst.replaceSubtree(ast, celExpr, node.id()));
}

private Optional<CelExpr> maybeAdaptEvaluatedResult(Object result) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Streams;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import dev.cel.bundle.Cel;
Expand Down Expand Up @@ -58,6 +57,7 @@
import dev.cel.parser.Operator;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -509,10 +509,11 @@ private Optional<CelNavigableExpr> findCseCandidateWithRecursionDepth(
ImmutableList<CelNavigableExpr> allNodes =
CelNavigableAst.fromAst(ast)
.getRoot()
.allNodes(TraversalOrder.POST_ORDER)
.allNodes(TraversalOrder.PRE_ORDER)
.filter(this::canEliminate)
.filter(node -> node.height() <= recursionLimit)
.filter(node -> !areSemanticallyEqual(ast.getExpr(), node.expr()))
.sorted(Comparator.comparingInt(CelNavigableExpr::height).reversed())
.collect(toImmutableList());

if (allNodes.isEmpty()) {
Expand All @@ -523,9 +524,23 @@ private Optional<CelNavigableExpr> findCseCandidateWithRecursionDepth(
if (commonSubexpr.isPresent()) {
return commonSubexpr;
}

// If there's no common subexpr, just return the one with the highest height that's still below
// the recursion limit.
return Optional.of(Iterables.getLast(allNodes));
// the recursion limit, but only if it actually needs to be extracted due to exceeding the
// recursion limit.
boolean astHasMoreExtractableSubexprs =
CelNavigableAst.fromAst(ast)
.getRoot()
.allNodes(TraversalOrder.POST_ORDER)
.filter(node -> node.height() > recursionLimit)
.anyMatch(this::canEliminate);
if (astHasMoreExtractableSubexprs) {
return Optional.of(allNodes.get(0));
}

// The height of the remaining subexpression is already below the recursion limit. No need to
// extract.
return Optional.empty();
}

private Optional<CelNavigableExpr> findCseCandidateWithCommonSubexpr(
Expand Down Expand Up @@ -705,13 +720,15 @@ public abstract static class Builder {
* <p>Note that expressions containing no common subexpressions may become a candidate for
* extraction to satisfy the max depth requirement.
*
* <p>This is a no-op if {@link #enableCelBlock} is set to false, or the configured value is
* less than 1.
* <p>This is a no-op if {@link #enableCelBlock} is set to false, the configured value is less
* than 1, or no subexpression needs to be extracted because the entire expression is already
* under the designated limit.
*
* <p>Examples:
*
* <ol>
* <li>a.b.c with depth 1 -> cel.@block([x.b, @index0.c], @index1)
* <li>a.b.c with depth 3 -> a.b.c
* <li>a.b + a.b.c.d with depth 3 -> cel.@block([a.b, @index0.c.d], @index0 + @index1)
* </ol>
*
Expand Down
31 changes: 30 additions & 1 deletion optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ public void mutableAst_macro_sourceMacroCallsPopulated() throws Exception {
}

@Test
public void mutableAst_astContainsTaggedExtension_retained() throws Exception {
public void replaceSubtree_astContainsTaggedExtension_retained() throws Exception {
CelAbstractSyntaxTree ast = CEL.compile("has(TestAllTypes{}.single_int32)").getAst();
Extension extension = Extension.create("test", Version.of(1, 1));
CelSource celSource = ast.getSource().toBuilder().addAllExtensions(extension).build();
Expand All @@ -137,6 +137,35 @@ public void mutableAst_astContainsTaggedExtension_retained() throws Exception {
assertThat(mutatedAst.getSource().getExtensions()).containsExactly(extension);
}

@Test
public void replaceSubtreeWithNewAst_astsContainTaggedExtension_retained() throws Exception {
// Setup first AST with a test extension
CelAbstractSyntaxTree ast = CEL.compile("has(TestAllTypes{}.single_int32)").getAst();
Extension extension = Extension.create("test", Version.of(1, 1));
ast =
CelAbstractSyntaxTree.newCheckedAst(
ast.getExpr(),
ast.getSource().toBuilder().addAllExtensions(extension).build(),
ast.getReferenceMap(),
ast.getTypeMap());
// Setup second AST with another test extension
CelAbstractSyntaxTree astToReplaceWith = CEL.compile("cel.bind(a, true, a)").getAst();
Extension extension2 = Extension.create("test2", Version.of(2, 2));
astToReplaceWith =
CelAbstractSyntaxTree.newCheckedAst(
astToReplaceWith.getExpr(),
astToReplaceWith.getSource().toBuilder().addAllExtensions(extension2).build(),
astToReplaceWith.getReferenceMap(),
astToReplaceWith.getTypeMap());

// Mutate the original AST with the new AST at the root
CelAbstractSyntaxTree mutatedAst =
MUTABLE_AST.replaceSubtreeWithNewAst(ast, astToReplaceWith, ast.getExpr().id());

// Expect that both the extensions are merged
assertThat(mutatedAst.getSource().getExtensions()).containsExactly(extension, extension2);
}

@Test
@TestParameters("{source: '[1].exists(x, x > 0)', expectedMacroCallSize: 1}")
@TestParameters(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import dev.cel.common.types.ListType;
import dev.cel.common.types.MapType;
import dev.cel.common.types.SimpleType;
import dev.cel.extensions.CelExtensions;
import dev.cel.extensions.CelOptionalLibrary;
import dev.cel.optimizer.CelOptimizationException;
import dev.cel.optimizer.CelOptimizer;
Expand All @@ -48,7 +49,7 @@ public class ConstantFoldingOptimizerTest {
.addVar("map_var", MapType.create(SimpleType.STRING, SimpleType.STRING))
.addMessageTypes(TestAllTypes.getDescriptor())
.setContainer("dev.cel.testing.testdata.proto3")
.addCompilerLibraries(CelOptionalLibrary.INSTANCE)
.addCompilerLibraries(CelExtensions.bindings(), CelOptionalLibrary.INSTANCE)
.addRuntimeLibraries(CelOptionalLibrary.INSTANCE)
.build();

Expand Down Expand Up @@ -159,6 +160,8 @@ public class ConstantFoldingOptimizerTest {
"{source: '{\"a\": dyn([1, 2]), \"b\": x}', expected: '{\"a\": [1, 2], \"b\": x}'}")
@TestParameters("{source: 'map_var[?\"key\"]', expected: 'map_var[?\"key\"]'}")
@TestParameters("{source: '\"abc\" in list_var', expected: '\"abc\" in list_var'}")
@TestParameters(
"{source: 'cel.bind(r0, [1, 2, 3], cel.bind(r1, 1 in r0, r1))', expected: 'true'}")
// TODO: Support folding lists with mixed types. This requires mutable lists.
// @TestParameters("{source: 'dyn([1]) + [1.0]'}")
public void constantFold_success(String source, String expected) throws Exception {
Expand Down Expand Up @@ -198,6 +201,10 @@ public void constantFold_success(String source, String expected) throws Exceptio
@TestParameters(
"{source: '[{}, {\"a\": 1}, {\"b\": 2}].filter(m, has(x.a))', expected:"
+ " '[{}, {\"a\": 1}, {\"b\": 2}].filter(m, has(x.a))'}")
@TestParameters(
"{source: 'cel.bind(r0, [1, 2, 3], cel.bind(r1, 1 in r0 && 2 in x, r1))', expected:"
+ " 'cel.bind(r0, [1, 2, 3], cel.bind(r1, 1 in r0 && 2 in x, r1))'}")
@TestParameters("{source: 'false ? false : cel.bind(a, x, a)', expected: 'cel.bind(a, x, a)'}")
public void constantFold_macros_macroCallMetadataPopulated(String source, String expected)
throws Exception {
Cel cel =
Expand All @@ -207,7 +214,7 @@ public void constantFold_macros_macroCallMetadataPopulated(String source, String
.addMessageTypes(TestAllTypes.getDescriptor())
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
.setOptions(CelOptions.current().populateMacroCalls(true).build())
.addCompilerLibraries(CelOptionalLibrary.INSTANCE)
.addCompilerLibraries(CelExtensions.bindings(), CelOptionalLibrary.INSTANCE)
.addRuntimeLibraries(CelOptionalLibrary.INSTANCE)
.build();
CelOptimizer celOptimizer =
Expand Down Expand Up @@ -241,6 +248,8 @@ public void constantFold_macros_macroCallMetadataPopulated(String source, String
@TestParameters(
"{source: '[{}, {\"a\": 1}, {\"b\": 2}].filter(m, has({\"a\": true}.a)) == "
+ " [{}, {\"a\": 1}, {\"b\": 2}]'}")
@TestParameters("{source: 'cel.bind(r0, [1, 2, 3], cel.bind(r1, 1 in r0 && 2 in r0, r1))'}")
@TestParameters("{source: 'false ? false : cel.bind(a, true, a)'}")
public void constantFold_macros_withoutMacroCallMetadata(String source) throws Exception {
Cel cel =
CelFactory.standardCelBuilder()
Expand All @@ -249,7 +258,7 @@ public void constantFold_macros_withoutMacroCallMetadata(String source) throws E
.addMessageTypes(TestAllTypes.getDescriptor())
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
.setOptions(CelOptions.current().populateMacroCalls(false).build())
.addCompilerLibraries(CelOptionalLibrary.INSTANCE)
.addCompilerLibraries(CelExtensions.bindings(), CelOptionalLibrary.INSTANCE)
.addRuntimeLibraries(CelOptionalLibrary.INSTANCE)
.build();
CelOptimizer celOptimizer =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ public void allOptimizers_producesSameEvaluationResult(
CEL.createProgram(ast)
.eval(ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L)));

CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.celOptimizer.optimize(ast);
CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.cseOptimizer.optimize(ast);

Object optimizedEvalResult =
CEL.createProgram(optimizedAst)
Expand All @@ -119,7 +119,39 @@ public void subexpression_unparsed() throws Exception {
boolean resultPrinted = false;
for (CseTestOptimizer cseTestOptimizer : CseTestOptimizer.values()) {
String optimizerName = cseTestOptimizer.name();
CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.celOptimizer.optimize(ast);
CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.cseOptimizer.optimize(ast);
if (!resultPrinted) {
Object optimizedEvalResult =
CEL.createProgram(optimizedAst)
.eval(
ImmutableMap.of(
"msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L)));
testOutput().println("Result: " + optimizedEvalResult);
resultPrinted = true;
}
try {
testOutput().printf("[%s]: %s", optimizerName, CEL_UNPARSER.unparse(optimizedAst));
} catch (RuntimeException e) {
testOutput().printf("[%s]: Unparse Error: %s", optimizerName, e);
}
testOutput().println();
}
testOutput().println();
}
}

@Test
public void constfold_before_subexpression_unparsed() throws Exception {
for (CseTestCase cseTestCase : CseTestCase.values()) {
testOutput().println("Test case: " + cseTestCase.name());
testOutput().println("Source: " + cseTestCase.source);
testOutput().println("=====>");
CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst();
boolean resultPrinted = false;
for (CseTestOptimizer cseTestOptimizer : CseTestOptimizer.values()) {
String optimizerName = cseTestOptimizer.name();
CelAbstractSyntaxTree optimizedAst =
cseTestOptimizer.cseWithConstFoldingOptimizer.optimize(ast);
if (!resultPrinted) {
Object optimizedEvalResult =
CEL.createProgram(optimizedAst)
Expand Down Expand Up @@ -149,7 +181,7 @@ public void subexpression_ast(@TestParameter CseTestOptimizer cseTestOptimizer)
testOutput().println("Source: " + cseTestCase.source);
testOutput().println("=====>");
CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst();
CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.celOptimizer.optimize(ast);
CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.cseOptimizer.optimize(ast);
testOutput().println(optimizedAst.getExpr());
}
}
Expand Down Expand Up @@ -339,10 +371,17 @@ private enum CseTestOptimizer {
BLOCK_RECURSION_DEPTH_9(
OPTIMIZER_COMMON_OPTIONS.toBuilder().subexpressionMaxRecursionDepth(9).build());

private final CelOptimizer celOptimizer;
private final CelOptimizer cseOptimizer;
private final CelOptimizer cseWithConstFoldingOptimizer;

CseTestOptimizer(SubexpressionOptimizerOptions option) {
this.celOptimizer = newCseOptimizer(CEL, option);
this.cseOptimizer = newCseOptimizer(CEL, option);
this.cseWithConstFoldingOptimizer =
CelOptimizerFactory.standardCelOptimizerBuilder(CEL)
.addAstOptimizers(
ConstantFoldingOptimizer.getInstance(),
SubexpressionOptimizer.newInstance(option))
.build();
}
}

Expand Down
Loading

0 comments on commit 9599835

Please sign in to comment.