From 629f85b5ad8b5db0ac113eea864f74d20616e4dd Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 15 Feb 2024 17:20:32 -0800 Subject: [PATCH] Add mangled comprehension variables as identifier declaration to the environment PiperOrigin-RevId: 607507168 --- .../java/dev/cel/optimizer/MutableAst.java | 26 ++++++- .../optimizers/SubexpressionOptimizer.java | 59 ++++++++++++--- .../dev/cel/optimizer/MutableAstTest.java | 9 ++- .../SubexpressionOptimizerTest.java | 72 ++++++++++++++++--- 4 files changed, 143 insertions(+), 23 deletions(-) diff --git a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java index 5798006c..ae2f328e 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java +++ b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java @@ -23,6 +23,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.errorprone.annotations.Immutable; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelSource; @@ -200,10 +201,11 @@ public CelAbstractSyntaxTree renumberIdsConsecutively(CelAbstractSyntaxTree ast) * @param newIdentPrefix Prefix to use for new identifier names. For example, providing @c will * produce @c0, @c1, @c2... as new names. */ - public CelAbstractSyntaxTree mangleComprehensionIdentifierNames( + public MangledComprehensionAst mangleComprehensionIdentifierNames( CelAbstractSyntaxTree ast, String newIdentPrefix) { int iterCount; CelNavigableAst newNavigableAst = CelNavigableAst.fromAst(ast); + ImmutableSet.Builder mangledComprehensionIdents = ImmutableSet.builder(); for (iterCount = 0; iterCount < iterationLimit; iterCount++) { CelNavigableExpr comprehensionNode = newNavigableAst @@ -223,6 +225,7 @@ public CelAbstractSyntaxTree mangleComprehensionIdentifierNames( String iterVar = comprehensionExpr.comprehension().iterVar(); int comprehensionNestingLevel = countComprehensionNestingLevel(comprehensionNode); String mangledVarName = newIdentPrefix + comprehensionNestingLevel; + mangledComprehensionIdents.add(mangledVarName); CelExpr.Builder mutatedComprehensionExpr = mangleIdentsInComprehensionExpr( @@ -251,7 +254,7 @@ public CelAbstractSyntaxTree mangleComprehensionIdentifierNames( throw new IllegalStateException("Max iteration count reached."); } - return newNavigableAst.getAst(); + return MangledComprehensionAst.of(newNavigableAst.getAst(), mangledComprehensionIdents.build()); } /** @@ -575,6 +578,25 @@ private static int countComprehensionNestingLevel(CelNavigableExpr comprehension return nestedLevel; } + /** + * Intermediate value class to store the mangled identifiers for iteration variable in the + * comprehension. + */ + @AutoValue + public abstract static class MangledComprehensionAst { + + /** AST after the iteration variables have been mangled. */ + public abstract CelAbstractSyntaxTree ast(); + + /** Set of identifiers with the iteration variable mangled. */ + public abstract ImmutableSet mangledComprehensionIdents(); + + private static MangledComprehensionAst of( + CelAbstractSyntaxTree ast, ImmutableSet mangledComprehensionIdents) { + return new AutoValue_MutableAst_MangledComprehensionAst(ast, mangledComprehensionIdents); + } + } + /** * Intermediate value class to store the generated CelExpr for the bind macro and the macro call * information. diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java index b62267bd..28b5e02a 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java @@ -33,6 +33,7 @@ import dev.cel.common.CelSource.Extension.Component; import dev.cel.common.CelSource.Extension.Version; import dev.cel.common.CelValidationException; +import dev.cel.common.CelVarDecl; import dev.cel.common.ast.CelExpr; import dev.cel.common.ast.CelExpr.CelCall; import dev.cel.common.ast.CelExpr.CelIdent; @@ -45,6 +46,7 @@ import dev.cel.common.types.SimpleType; import dev.cel.optimizer.CelAstOptimizer; import dev.cel.optimizer.MutableAst; +import dev.cel.optimizer.MutableAst.MangledComprehensionAst; import dev.cel.parser.Operator; import java.util.ArrayList; import java.util.HashMap; @@ -90,10 +92,8 @@ public class SubexpressionOptimizer implements CelAstOptimizer { stream(Operator.values()).map(Operator::getFunction), stream(Standard.Function.values()).map(Standard.Function::getFunction)) .collect(toImmutableSet()); - private static final Extension CEL_BLOCK_AST_EXTENSION_TAG = Extension.create("cel_block", Version.of(1L, 1L), Component.COMPONENT_RUNTIME); - private final SubexpressionOptimizerOptions cseOptions; private final MutableAst mutableAst; @@ -125,10 +125,13 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock( // Retain the original expected result type, so that it can be reset in celBuilder at the end of // the optimization pass. CelType resultType = navigableAst.getAst().getResultType(); - CelAbstractSyntaxTree astToModify = + MangledComprehensionAst mangledComprehensionAst = mutableAst.mangleComprehensionIdentifierNames( navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX); + CelAbstractSyntaxTree astToModify = mangledComprehensionAst.ast(); CelSource sourceToModify = astToModify.getSource(); + ImmutableSet mangledIdentDecls = + newMangledIdentDecls(celBuilder, mangledComprehensionAst); int blockIdentifierIndex = 0; int iterCount; @@ -187,6 +190,9 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock( return astToModify; } + // Add all mangled comprehension identifiers to the environment, so that the subexpressions can + // retain context to them. + celBuilder.addVarDeclarations(mangledIdentDecls); // Type-check all sub-expressions then add them as block identifiers to the CEL environment addBlockIdentsToEnv(celBuilder, subexpressions); @@ -254,10 +260,47 @@ private static void addBlockIdentsToEnv(CelBuilder celBuilder, List sub } } + private static ImmutableSet newMangledIdentDecls( + CelBuilder celBuilder, MangledComprehensionAst mangledComprehensionAst) { + if (mangledComprehensionAst.mangledComprehensionIdents().isEmpty()) { + return ImmutableSet.of(); + } + CelAbstractSyntaxTree ast = mangledComprehensionAst.ast(); + try { + ast = celBuilder.build().check(ast).getAst(); + } catch (CelValidationException e) { + throw new IllegalStateException("Failed to type-check mangled AST.", e); + } + + ImmutableSet.Builder mangledVarDecls = ImmutableSet.builder(); + for (String ident : mangledComprehensionAst.mangledComprehensionIdents()) { + CelExpr mangledIdentExpr = + CelNavigableAst.fromAst(ast) + .getRoot() + .allNodes() + .filter(node -> node.getKind().equals(Kind.IDENT)) + .map(CelNavigableExpr::expr) + .filter(expr -> expr.ident().name().equals(ident)) + .findAny() + .orElse(null); + if (mangledIdentExpr == null) { + break; + } + + CelType mangledIdentType = + ast.getType(mangledIdentExpr.id()).orElseThrow(() -> new NoSuchElementException("?")); + mangledVarDecls.add(CelVarDecl.newVarDeclaration(ident, mangledIdentType)); + } + + return mangledVarDecls.build(); + } + private CelAbstractSyntaxTree optimizeUsingCelBind(CelNavigableAst navigableAst) { CelAbstractSyntaxTree astToModify = - mutableAst.mangleComprehensionIdentifierNames( - navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX); + mutableAst + .mangleComprehensionIdentifierNames( + navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX) + .ast(); CelSource sourceToModify = astToModify.getSource(); int bindIdentifierIndex = 0; @@ -526,9 +569,9 @@ public abstract static class Builder { /** * Rewrites the optimized AST using cel.@block call instead of cascaded cel.bind macros, aimed - * to produce a more compact AST. {@link com.google.api.expr.SourceInfo.Extension} field will - * be populated in the AST to inform that special runtime support is required to evaluate the - * optimized expression. + * to produce a more compact AST. {@link CelSource.Extension} field will be populated in the + * AST to inform that special runtime support is required to evaluate the optimized + * expression. */ public abstract Builder enableCelBlock(boolean value); diff --git a/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java index bd9d122c..9c32520b 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java @@ -681,7 +681,8 @@ public void comprehension_replaceLoopStep() throws Exception { public void mangleComprehensionVariable_singleMacro() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("[false].exists(i, i)").getAst(); - CelAbstractSyntaxTree mangledAst = MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c"); + CelAbstractSyntaxTree mangledAst = + MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c").ast(); assertThat(mangledAst.getExpr().toString()) .isEqualTo( @@ -741,7 +742,8 @@ public void mangleComprehensionVariable_singleMacro() throws Exception { public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("[x].exists(x, [x].exists(x, x == 1))").getAst(); - CelAbstractSyntaxTree mangledAst = MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c"); + CelAbstractSyntaxTree mangledAst = + MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c").ast(); assertThat(mangledAst.getExpr().toString()) .isEqualTo( @@ -858,7 +860,8 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw public void mangleComprehensionVariable_hasMacro_noOp() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("has(msg.single_int64)").getAst(); - CelAbstractSyntaxTree mangledAst = MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c"); + CelAbstractSyntaxTree mangledAst = + MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c").ast(); assertThat(CEL_UNPARSER.unparse(mangledAst)).isEqualTo("has(msg.single_int64)"); assertThat( diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index ae8bc7f1..cf7cdd6d 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -91,6 +91,8 @@ public class SubexpressionOptimizerTest { CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) // Similarly, this is a test only decl (index0 -> @index0) .addVarDeclarations( + CelVarDecl.newVarDeclaration("c0", SimpleType.DYN), + CelVarDecl.newVarDeclaration("c1", SimpleType.DYN), CelVarDecl.newVarDeclaration("index0", SimpleType.DYN), CelVarDecl.newVarDeclaration("index1", SimpleType.DYN), CelVarDecl.newVarDeclaration("index2", SimpleType.DYN), @@ -506,8 +508,9 @@ private enum CseTestCase { "size([\"foo\", \"bar\"].map(x, [x + x, x + x]).map(x, [x + x, x + x])) == 2", "size([\"foo\", \"bar\"].map(@c1, cel.bind(@r0, @c1 + @c1, [@r0, @r0]))" + ".map(@c0, cel.bind(@r1, @c0 + @c0, [@r1, @r1]))) == 2", - "Currently Unsupported"), // TODO: Handle comprehension variables that fall - // outside the cel.block scope + "cel.@block([@c1 + @c1, @c0 + @c0], " + + "size([\"foo\", \"bar\"].map(@c1, [@index0, @index0])" + + ".map(@c0, [@index1, @index1])) == 2)"), PRESENCE_TEST( "has({'a': true}.a) && {'a':true}['a']", "cel.bind(@r0, {\"a\": true}, has(@r0.a) && @r0[\"a\"])", @@ -683,10 +686,6 @@ public void cse_withCelBind_macroMapUnpopulated(@TestParameter CseTestCase testC @Test public void cse_withCelBlock_macroMapPopulated(@TestParameter CseTestCase testCase) throws Exception { - if (testCase.equals(CseTestCase.MACRO_SHADOWED_VARIABLE_2)) { - // TODO: Handle comprehension variables that fall outside the cel.block scope - return; - } CelOptimizer celOptimizer = newCseOptimizer( SubexpressionOptimizerOptions.newBuilder() @@ -709,10 +708,6 @@ public void cse_withCelBlock_macroMapPopulated(@TestParameter CseTestCase testCa @Test public void cse_withCelBlock_macroMapUnpopulated(@TestParameter CseTestCase testCase) throws Exception { - if (testCase.equals(CseTestCase.MACRO_SHADOWED_VARIABLE_2)) { - // TODO: Handle comprehension variables that fall outside the cel.block scope - return; - } CelOptimizer celOptimizer = newCseOptimizer( SubexpressionOptimizerOptions.newBuilder() @@ -732,6 +727,32 @@ public void cse_withCelBlock_macroMapUnpopulated(@TestParameter CseTestCase test .isEqualTo(true); } + @Test + public void celBlock_nestedComprehension_iterVarReferencedAcrossComprehensions() + throws Exception { + String nestedComprehension = + "[\"foo\"].map(x, [[\"bar\"], [x + x, x + x]] + [\"bar\"].map(y, [x + y, [\"baz\"].map(z," + + " [x + y + z, x + y, x + y + z])])) == [[[\"bar\"], [\"foofoo\", \"foofoo\"]," + + " [\"foobar\", [[\"foobarbaz\", \"foobar\", \"foobarbaz\"]]]]]"; + CelOptimizer celOptimizer = + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(true) + .enableCelBlock(true) + .build()); + CelAbstractSyntaxTree ast = CEL.compile(nestedComprehension).getAst(); + + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); + + assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(true); + assertThat(CEL_UNPARSER.unparse(optimizedAst)) + .isEqualTo( + "cel.@block([@c0 + @c0, [\"bar\"], @c0 + @c1, @index2 + @c2], [\"foo\"].map(@c0," + + " [@index1, [@index0, @index0]] + @index1.map(@c1, [@index2, [\"baz\"].map(@c2," + + " [@index3, @index2, @index3])])) == [[@index1, [\"foofoo\", \"foofoo\"]," + + " [\"foobar\", [[\"foobarbaz\", \"foobar\", \"foobarbaz\"]]]]])"); + } + @Test public void cse_resultTypeSet_celBlockOptimizationSuccess() throws Exception { Cel cel = newCelBuilder().setResultType(SimpleType.BOOL).build(); @@ -1264,6 +1285,37 @@ public void lazyEval_multipleBlockIndices_cascaded() throws Exception { assertThat(invocation.get()).isEqualTo(1); } + @Test + @SuppressWarnings("Immutable") // Test only + public void lazyEval_nestedComprehension_indexReferencedInNestedScopes() throws Exception { + AtomicInteger invocation = new AtomicInteger(); + CelRuntime celRuntime = + CelRuntimeFactory.standardCelRuntimeBuilder() + .addMessageTypes(TestAllTypes.getDescriptor()) + .addFunctionBindings( + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + })) + .build(); + // Equivalent of [true, false, true].map(c0, [c0].map(c1, [c0, c1, true])) + CelAbstractSyntaxTree ast = + compileUsingInternalFunctions( + "cel.block([c0, c1, get_true()], [index2, false, index2].map(c0, [c0].map(c1, [index0," + + " index1, index2]))) == [[[true, true, true]], [[false, false, true]], [[true," + + " true, true]]]"); + + boolean result = (boolean) celRuntime.createProgram(ast).eval(); + + assertThat(result).isTrue(); + // Even though the function get_true() is referenced across different comprehension scopes, + // it still gets memoized only once. + assertThat(invocation.get()).isEqualTo(1); + } + @Test @TestParameters("{source: 'cel.block([])'}") @TestParameters("{source: 'cel.block([1])'}")