Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mangled comprehension variables as identifier declaration to the environment #244

Merged
merged 1 commit into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions optimizer/src/main/java/dev/cel/optimizer/MutableAst.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.errorprone.annotations.Immutable;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelSource;
Expand Down Expand Up @@ -200,10 +201,11 @@ public CelAbstractSyntaxTree renumberIdsConsecutively(CelAbstractSyntaxTree ast)
* @param newIdentPrefix Prefix to use for new identifier names. For example, providing @c will
* produce @c0, @c1, @c2... as new names.
*/
public CelAbstractSyntaxTree mangleComprehensionIdentifierNames(
public MangledComprehensionAst mangleComprehensionIdentifierNames(
CelAbstractSyntaxTree ast, String newIdentPrefix) {
int iterCount;
CelNavigableAst newNavigableAst = CelNavigableAst.fromAst(ast);
ImmutableSet.Builder<String> mangledComprehensionIdents = ImmutableSet.builder();
for (iterCount = 0; iterCount < iterationLimit; iterCount++) {
CelNavigableExpr comprehensionNode =
newNavigableAst
Expand All @@ -223,6 +225,7 @@ public CelAbstractSyntaxTree mangleComprehensionIdentifierNames(
String iterVar = comprehensionExpr.comprehension().iterVar();
int comprehensionNestingLevel = countComprehensionNestingLevel(comprehensionNode);
String mangledVarName = newIdentPrefix + comprehensionNestingLevel;
mangledComprehensionIdents.add(mangledVarName);

CelExpr.Builder mutatedComprehensionExpr =
mangleIdentsInComprehensionExpr(
Expand Down Expand Up @@ -251,7 +254,7 @@ public CelAbstractSyntaxTree mangleComprehensionIdentifierNames(
throw new IllegalStateException("Max iteration count reached.");
}

return newNavigableAst.getAst();
return MangledComprehensionAst.of(newNavigableAst.getAst(), mangledComprehensionIdents.build());
}

/**
Expand Down Expand Up @@ -575,6 +578,25 @@ private static int countComprehensionNestingLevel(CelNavigableExpr comprehension
return nestedLevel;
}

/**
 * Value class returned from mangling comprehension identifier names. It pairs the rewritten AST
 * with the set of mangled identifier names that were generated for the iteration variables.
 */
@AutoValue
public abstract static class MangledComprehensionAst {

  /** Packages the mangled AST together with the mangled identifier names produced for it. */
  private static MangledComprehensionAst of(
      CelAbstractSyntaxTree mangledAst, ImmutableSet<String> mangledIdents) {
    return new AutoValue_MutableAst_MangledComprehensionAst(mangledAst, mangledIdents);
  }

  /** The AST in which the comprehension iteration variables have been renamed. */
  public abstract CelAbstractSyntaxTree ast();

  /** The mangled names that now stand in for the comprehension iteration variables. */
  public abstract ImmutableSet<String> mangledComprehensionIdents();
}

/**
* Intermediate value class to store the generated CelExpr for the bind macro and the macro call
* information.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import dev.cel.common.CelSource.Extension.Component;
import dev.cel.common.CelSource.Extension.Version;
import dev.cel.common.CelValidationException;
import dev.cel.common.CelVarDecl;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.ast.CelExpr.CelCall;
import dev.cel.common.ast.CelExpr.CelIdent;
Expand All @@ -45,6 +46,7 @@
import dev.cel.common.types.SimpleType;
import dev.cel.optimizer.CelAstOptimizer;
import dev.cel.optimizer.MutableAst;
import dev.cel.optimizer.MutableAst.MangledComprehensionAst;
import dev.cel.parser.Operator;
import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -90,10 +92,8 @@ public class SubexpressionOptimizer implements CelAstOptimizer {
stream(Operator.values()).map(Operator::getFunction),
stream(Standard.Function.values()).map(Standard.Function::getFunction))
.collect(toImmutableSet());

private static final Extension CEL_BLOCK_AST_EXTENSION_TAG =
Extension.create("cel_block", Version.of(1L, 1L), Component.COMPONENT_RUNTIME);

private final SubexpressionOptimizerOptions cseOptions;
private final MutableAst mutableAst;

Expand Down Expand Up @@ -125,10 +125,13 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock(
// Retain the original expected result type, so that it can be reset in celBuilder at the end of
// the optimization pass.
CelType resultType = navigableAst.getAst().getResultType();
CelAbstractSyntaxTree astToModify =
MangledComprehensionAst mangledComprehensionAst =
mutableAst.mangleComprehensionIdentifierNames(
navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX);
CelAbstractSyntaxTree astToModify = mangledComprehensionAst.ast();
CelSource sourceToModify = astToModify.getSource();
ImmutableSet<CelVarDecl> mangledIdentDecls =
newMangledIdentDecls(celBuilder, mangledComprehensionAst);

int blockIdentifierIndex = 0;
int iterCount;
Expand Down Expand Up @@ -187,6 +190,9 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock(
return astToModify;
}

// Declare all mangled comprehension identifiers in the environment, so that extracted
// subexpressions that reference them can still be resolved.
celBuilder.addVarDeclarations(mangledIdentDecls);
// Type-check all sub-expressions then add them as block identifiers to the CEL environment
addBlockIdentsToEnv(celBuilder, subexpressions);

Expand Down Expand Up @@ -254,10 +260,47 @@ private static void addBlockIdentsToEnv(CelBuilder celBuilder, List<CelExpr> sub
}
}

/**
 * Builds variable declarations for the mangled comprehension identifiers of the given AST.
 *
 * <p>The mangled AST is type-checked with an environment built from {@code celBuilder}. For each
 * mangled identifier that is referenced by at least one ident expression, a declaration is
 * produced using the type the checker recorded for one such reference.
 *
 * @param celBuilder builder used to construct the environment that type-checks the mangled AST
 * @param mangledComprehensionAst the mangled AST along with the names of its mangled identifiers
 * @return declarations for the referenced mangled identifiers; empty when nothing was mangled
 * @throws IllegalStateException if the mangled AST fails to type-check
 * @throws NoSuchElementException if the checked AST holds no type for a referenced identifier
 */
private static ImmutableSet<CelVarDecl> newMangledIdentDecls(
    CelBuilder celBuilder, MangledComprehensionAst mangledComprehensionAst) {
  if (mangledComprehensionAst.mangledComprehensionIdents().isEmpty()) {
    return ImmutableSet.of();
  }
  CelAbstractSyntaxTree ast = mangledComprehensionAst.ast();
  try {
    ast = celBuilder.build().check(ast).getAst();
  } catch (CelValidationException e) {
    throw new IllegalStateException("Failed to type-check mangled AST.", e);
  }

  // The navigable view does not depend on the identifier being looked up, so build it once.
  CelNavigableAst checkedNavigableAst = CelNavigableAst.fromAst(ast);
  ImmutableSet.Builder<CelVarDecl> mangledVarDecls = ImmutableSet.builder();
  for (String ident : mangledComprehensionAst.mangledComprehensionIdents()) {
    CelExpr mangledIdentExpr =
        checkedNavigableAst
            .getRoot()
            .allNodes()
            .filter(node -> node.getKind().equals(Kind.IDENT))
            .map(CelNavigableExpr::expr)
            .filter(expr -> expr.ident().name().equals(ident))
            .findAny()
            .orElse(null);
    if (mangledIdentExpr == null) {
      // No ident expression references this name (e.g. an iteration variable that is never used
      // in the loop body), so there is no type to read. Skip only this identifier: stopping here
      // would leave every remaining mangled identifier undeclared.
      continue;
    }

    CelType mangledIdentType =
        ast.getType(mangledIdentExpr.id())
            .orElseThrow(
                () ->
                    new NoSuchElementException(
                        "Checked AST has no type for mangled identifier: " + ident));
    mangledVarDecls.add(CelVarDecl.newVarDeclaration(ident, mangledIdentType));
  }

  return mangledVarDecls.build();
}

private CelAbstractSyntaxTree optimizeUsingCelBind(CelNavigableAst navigableAst) {
CelAbstractSyntaxTree astToModify =
mutableAst.mangleComprehensionIdentifierNames(
navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX);
mutableAst
.mangleComprehensionIdentifierNames(
navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX)
.ast();
CelSource sourceToModify = astToModify.getSource();

int bindIdentifierIndex = 0;
Expand Down Expand Up @@ -526,9 +569,9 @@ public abstract static class Builder {

/**
* Rewrites the optimized AST using cel.@block call instead of cascaded cel.bind macros, aimed
* to produce a more compact AST. {@link com.google.api.expr.SourceInfo.Extension} field will
* be populated in the AST to inform that special runtime support is required to evaluate the
* optimized expression.
* to produce a more compact AST. {@link CelSource.Extension} field will be populated in the
* AST to inform that special runtime support is required to evaluate the optimized
* expression.
*/
public abstract Builder enableCelBlock(boolean value);

Expand Down
9 changes: 6 additions & 3 deletions optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,8 @@ public void comprehension_replaceLoopStep() throws Exception {
public void mangleComprehensionVariable_singleMacro() throws Exception {
CelAbstractSyntaxTree ast = CEL.compile("[false].exists(i, i)").getAst();

CelAbstractSyntaxTree mangledAst = MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c");
CelAbstractSyntaxTree mangledAst =
MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c").ast();

assertThat(mangledAst.getExpr().toString())
.isEqualTo(
Expand Down Expand Up @@ -741,7 +742,8 @@ public void mangleComprehensionVariable_singleMacro() throws Exception {
public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throws Exception {
CelAbstractSyntaxTree ast = CEL.compile("[x].exists(x, [x].exists(x, x == 1))").getAst();

CelAbstractSyntaxTree mangledAst = MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c");
CelAbstractSyntaxTree mangledAst =
MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c").ast();

assertThat(mangledAst.getExpr().toString())
.isEqualTo(
Expand Down Expand Up @@ -858,7 +860,8 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw
public void mangleComprehensionVariable_hasMacro_noOp() throws Exception {
CelAbstractSyntaxTree ast = CEL.compile("has(msg.single_int64)").getAst();

CelAbstractSyntaxTree mangledAst = MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c");
CelAbstractSyntaxTree mangledAst =
MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c").ast();

assertThat(CEL_UNPARSER.unparse(mangledAst)).isEqualTo("has(msg.single_int64)");
assertThat(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ public class SubexpressionOptimizerTest {
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
// Similarly, this is a test only decl (index0 -> @index0)
.addVarDeclarations(
CelVarDecl.newVarDeclaration("c0", SimpleType.DYN),
CelVarDecl.newVarDeclaration("c1", SimpleType.DYN),
CelVarDecl.newVarDeclaration("index0", SimpleType.DYN),
CelVarDecl.newVarDeclaration("index1", SimpleType.DYN),
CelVarDecl.newVarDeclaration("index2", SimpleType.DYN),
Expand Down Expand Up @@ -506,8 +508,9 @@ private enum CseTestCase {
"size([\"foo\", \"bar\"].map(x, [x + x, x + x]).map(x, [x + x, x + x])) == 2",
"size([\"foo\", \"bar\"].map(@c1, cel.bind(@r0, @c1 + @c1, [@r0, @r0]))"
+ ".map(@c0, cel.bind(@r1, @c0 + @c0, [@r1, @r1]))) == 2",
"Currently Unsupported"), // TODO: Handle comprehension variables that fall
// outside the cel.block scope
"cel.@block([@c1 + @c1, @c0 + @c0], "
+ "size([\"foo\", \"bar\"].map(@c1, [@index0, @index0])"
+ ".map(@c0, [@index1, @index1])) == 2)"),
PRESENCE_TEST(
"has({'a': true}.a) && {'a':true}['a']",
"cel.bind(@r0, {\"a\": true}, has(@r0.a) && @r0[\"a\"])",
Expand Down Expand Up @@ -683,10 +686,6 @@ public void cse_withCelBind_macroMapUnpopulated(@TestParameter CseTestCase testC
@Test
public void cse_withCelBlock_macroMapPopulated(@TestParameter CseTestCase testCase)
throws Exception {
if (testCase.equals(CseTestCase.MACRO_SHADOWED_VARIABLE_2)) {
// TODO: Handle comprehension variables that fall outside the cel.block scope
return;
}
CelOptimizer celOptimizer =
newCseOptimizer(
SubexpressionOptimizerOptions.newBuilder()
Expand All @@ -709,10 +708,6 @@ public void cse_withCelBlock_macroMapPopulated(@TestParameter CseTestCase testCa
@Test
public void cse_withCelBlock_macroMapUnpopulated(@TestParameter CseTestCase testCase)
throws Exception {
if (testCase.equals(CseTestCase.MACRO_SHADOWED_VARIABLE_2)) {
// TODO: Handle comprehension variables that fall outside the cel.block scope
return;
}
CelOptimizer celOptimizer =
newCseOptimizer(
SubexpressionOptimizerOptions.newBuilder()
Expand All @@ -732,6 +727,32 @@ public void cse_withCelBlock_macroMapUnpopulated(@TestParameter CseTestCase test
.isEqualTo(true);
}

/**
 * Optimizes a three-level nested comprehension with cel.@block enabled, where inner
 * comprehensions reference the iteration variables of outer ones, then checks both the
 * evaluation result and the unparsed form of the optimized AST.
 */
@Test
public void celBlock_nestedComprehension_iterVarReferencedAcrossComprehensions()
    throws Exception {
  SubexpressionOptimizerOptions celBlockOptions =
      SubexpressionOptimizerOptions.newBuilder()
          .populateMacroCalls(true)
          .enableCelBlock(true)
          .build();
  CelOptimizer optimizer = newCseOptimizer(celBlockOptions);
  CelAbstractSyntaxTree originalAst =
      CEL.compile(
              "[\"foo\"].map(x, [[\"bar\"], [x + x, x + x]] + [\"bar\"].map(y, [x + y, [\"baz\"].map(z,"
                  + " [x + y + z, x + y, x + y + z])])) == [[[\"bar\"], [\"foofoo\", \"foofoo\"],"
                  + " [\"foobar\", [[\"foobarbaz\", \"foobar\", \"foobarbaz\"]]]]]")
          .getAst();
  String expectedUnparsed =
      "cel.@block([@c0 + @c0, [\"bar\"], @c0 + @c1, @index2 + @c2], [\"foo\"].map(@c0,"
          + " [@index1, [@index0, @index0]] + @index1.map(@c1, [@index2, [\"baz\"].map(@c2,"
          + " [@index3, @index2, @index3])])) == [[@index1, [\"foofoo\", \"foofoo\"],"
          + " [\"foobar\", [[\"foobarbaz\", \"foobar\", \"foobarbaz\"]]]]])";

  CelAbstractSyntaxTree blockAst = optimizer.optimize(originalAst);

  // The optimized expression must still evaluate to true...
  assertThat(CEL.createProgram(blockAst).eval()).isEqualTo(true);
  // ...and must unparse to the expected cel.@block form.
  assertThat(CEL_UNPARSER.unparse(blockAst)).isEqualTo(expectedUnparsed);
}

@Test
public void cse_resultTypeSet_celBlockOptimizationSuccess() throws Exception {
Cel cel = newCelBuilder().setResultType(SimpleType.BOOL).build();
Expand Down Expand Up @@ -1264,6 +1285,37 @@ public void lazyEval_multipleBlockIndices_cascaded() throws Exception {
assertThat(invocation.get()).isEqualTo(1);
}

/**
 * Evaluates a cel.block expression whose block indices are referenced from nested comprehension
 * scopes, and checks that the bound get_true function is invoked exactly once.
 */
@Test
@SuppressWarnings("Immutable") // Test only
public void lazyEval_nestedComprehension_indexReferencedInNestedScopes() throws Exception {
  // Counts how many times the get_true binding is actually executed.
  AtomicInteger getTrueCalls = new AtomicInteger();
  CelRuntime runtime =
      CelRuntimeFactory.standardCelRuntimeBuilder()
          .addMessageTypes(TestAllTypes.getDescriptor())
          .addFunctionBindings(
              CelFunctionBinding.from(
                  "get_true_overload",
                  ImmutableList.of(),
                  unusedArgs -> {
                    getTrueCalls.incrementAndGet();
                    return true;
                  }))
          .build();
  // Equivalent of [true, false, true].map(c0, [c0].map(c1, [c0, c1, true]))
  CelAbstractSyntaxTree blockAst =
      compileUsingInternalFunctions(
          "cel.block([c0, c1, get_true()], [index2, false, index2].map(c0, [c0].map(c1, [index0,"
              + " index1, index2]))) == [[[true, true, true]], [[false, false, true]], [[true,"
              + " true, true]]]");

  boolean evaluated = (boolean) runtime.createProgram(blockAst).eval();

  assertThat(evaluated).isTrue();
  // get_true() is reached from several comprehension scopes, yet its result is memoized, so the
  // binding runs only once.
  assertThat(getTrueCalls.get()).isEqualTo(1);
}

@Test
@TestParameters("{source: 'cel.block([])'}")
@TestParameters("{source: 'cel.block([1])'}")
Expand Down
Loading