Skip to content

Commit

Permalink
Assign unique indices for mangled comprehension identifiers with different types
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 606758852
  • Loading branch information
l46kok authored and copybara-github committed Feb 16, 2024
1 parent 629f85b commit 1078f89
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 96 deletions.
1 change: 1 addition & 0 deletions optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ java_library(
"//common/ast",
"//common/ast:expr_factory",
"//common/navigation",
"//common/types:type_providers",
"@maven//:com_google_errorprone_error_prone_annotations",
"@maven//:com_google_guava_guava",
],
Expand Down
117 changes: 95 additions & 22 deletions optimizer/src/main/java/dev/cel/optimizer/MutableAst.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Table;
import com.google.errorprone.annotations.Immutable;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelSource;
Expand All @@ -38,9 +39,13 @@
import dev.cel.common.navigation.CelNavigableAst;
import dev.cel.common.navigation.CelNavigableExpr;
import dev.cel.common.navigation.CelNavigableExpr.TraversalOrder;
import dev.cel.common.types.CelType;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map.Entry;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.stream.Collectors;

/** MutableAst contains logic for mutating a {@link CelAbstractSyntaxTree}. */
@Immutable
Expand Down Expand Up @@ -187,45 +192,112 @@ public CelAbstractSyntaxTree renumberIdsConsecutively(CelAbstractSyntaxTree ast)
*
* <p>The expression IDs are not modified when the identifier names are changed.
*
* <p>Mangling occurs only if the iteration variable is referenced within the loop step.
*
* <p>Iteration variables in comprehensions are numbered based on their comprehension nesting
* levels. Examples:
* levels and the iteration variable's type. Examples:
*
* <ul>
* <li>{@code [true].exists(i, i) && [true].exists(j, j)} -> {@code [true].exists(@c0, @c0) &&
* [true].exists(@c0, @c0)} // Note that i,j gets replaced to the same @c0 in this example
* <li>{@code [true].exists(i, i && [true].exists(j, j))} -> {@code [true].exists(@c0, @c0 &&
* [true].exists(@c1, @c1))}
* <li>{@code [true].exists(i, i) && [true].exists(j, j)} -> {@code [true].exists(@c0:0, @c0:0)
* && [true].exists(@c0:0, @c0:0)} // Note that i and j are both replaced with the same @c0:0
* in this example, as they share the same nesting level and type.
* <li>{@code [1].exists(i, i > 0) && [1u].exists(j, j > 0u)} -> {@code [1].exists(@c0:0, @c0:0
* > 0) && [1u].exists(@c0:1, @c0:1 > 0u)}
* <li>{@code [true].exists(i, i && [true].exists(j, j))} -> {@code [true].exists(@c0:0, @c0:0
* && [true].exists(@c1:0, @c1:0))}
* </ul>
*
* @param ast AST to mutate
* @param newIdentPrefix Prefix to use for new identifier names. For example, providing @c will
* produce @c0, @c1, @c2... as new names.
* produce @c0:0, @c0:1, @c1:0, @c2:0... as new names.
*/
public MangledComprehensionAst mangleComprehensionIdentifierNames(
CelAbstractSyntaxTree ast, String newIdentPrefix) {
int iterCount;
CelNavigableAst newNavigableAst = CelNavigableAst.fromAst(ast);
ImmutableSet.Builder<String> mangledComprehensionIdents = ImmutableSet.builder();
for (iterCount = 0; iterCount < iterationLimit; iterCount++) {
LinkedHashMap<CelNavigableExpr, CelType> comprehensionsToMangle =
newNavigableAst
.getRoot()
// This is important - mangling needs to happen bottom-up to avoid stepping over
// shadowed variables that are not part of the comprehension being mangled.
.allNodes(TraversalOrder.POST_ORDER)
.filter(node -> node.getKind().equals(Kind.COMPREHENSION))
.filter(node -> !node.expr().comprehension().iterVar().startsWith(newIdentPrefix))
.filter(
node -> {
// Ensure the iter_var is actually referenced in the loop_step. If it's not, we
// can skip mangling.
String iterVar = node.expr().comprehension().iterVar();
return CelNavigableExpr.fromExpr(node.expr().comprehension().loopStep())
.allNodes()
.anyMatch(
subNode -> subNode.expr().identOrDefault().name().contains(iterVar));
})
.collect(
Collectors.toMap(
k -> k,
v -> {
String iterVar = v.expr().comprehension().iterVar();
long iterVarId =
CelNavigableExpr.fromExpr(v.expr().comprehension().loopStep())
.allNodes()
.filter(
loopStepNode ->
loopStepNode.expr().identOrDefault().name().equals(iterVar))
.map(CelNavigableExpr::id)
.findAny()
.orElseThrow(
() -> {
throw new NoSuchElementException(
"Expected iteration variable to exist in expr id: "
+ v.id());
});

return ast.getType(iterVarId)
.orElseThrow(
() ->
new NoSuchElementException(
"Checked type not present for: " + iterVarId));
},
(x, y) -> {
throw new IllegalStateException("Unexpected CelNavigableExpr collision");
},
LinkedHashMap::new));
int iterCount = 0;

// The map that we'll eventually return to the caller.
HashMap<String, CelType> mangledIdentNamesToType = new HashMap<>();
// Intermediary table used for the purposes of generating a unique mangled variable name.
Table<Integer, CelType, String> comprehensionLevelToType = HashBasedTable.create();
for (Entry<CelNavigableExpr, CelType> comprehensionEntry : comprehensionsToMangle.entrySet()) {
iterCount++;
// Refetch the comprehension node as mutating the AST could have renumbered its IDs.
CelNavigableExpr comprehensionNode =
newNavigableAst
.getRoot()
// This is important - mangling needs to happen bottom-up to avoid stepping over
// shadowed variables that are not part of the comprehension being mangled.
.allNodes(TraversalOrder.POST_ORDER)
.filter(node -> node.getKind().equals(Kind.COMPREHENSION))
.filter(node -> !node.expr().comprehension().iterVar().startsWith(newIdentPrefix))
.findAny()
.orElse(null);
if (comprehensionNode == null) {
break;
}
.orElseThrow(
() -> new NoSuchElementException("Failed to refetch mutated comprehension"));
CelType comprehensionEntryType = comprehensionEntry.getValue();

CelExpr.Builder comprehensionExpr = comprehensionNode.expr().toBuilder();
String iterVar = comprehensionExpr.comprehension().iterVar();
int comprehensionNestingLevel = countComprehensionNestingLevel(comprehensionNode);
String mangledVarName = newIdentPrefix + comprehensionNestingLevel;
mangledComprehensionIdents.add(mangledVarName);
String mangledVarName;
if (comprehensionLevelToType.contains(comprehensionNestingLevel, comprehensionEntryType)) {
mangledVarName =
comprehensionLevelToType.get(comprehensionNestingLevel, comprehensionEntryType);
} else {
// First time encountering the pair of <ComprehensionLevel, CelType>. Generate a unique
// mangled variable name for this.
int uniqueTypeIdx = comprehensionLevelToType.row(comprehensionNestingLevel).size();
mangledVarName = newIdentPrefix + comprehensionNestingLevel + ":" + uniqueTypeIdx;
comprehensionLevelToType.put(
comprehensionNestingLevel, comprehensionEntryType, mangledVarName);
}
mangledIdentNamesToType.put(mangledVarName, comprehensionEntryType);

CelExpr.Builder mutatedComprehensionExpr =
mangleIdentsInComprehensionExpr(
Expand Down Expand Up @@ -254,7 +326,8 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
throw new IllegalStateException("Max iteration count reached.");
}

return MangledComprehensionAst.of(newNavigableAst.getAst(), mangledComprehensionIdents.build());
return MangledComprehensionAst.of(
newNavigableAst.getAst(), ImmutableMap.copyOf(mangledIdentNamesToType));
}

/**
Expand Down Expand Up @@ -588,11 +661,11 @@ public abstract static class MangledComprehensionAst {
/** AST after the iteration variables have been mangled. */
public abstract CelAbstractSyntaxTree ast();

/** Set of identifiers with the iteration variable mangled. */
public abstract ImmutableSet<String> mangledComprehensionIdents();
/** Map containing the mangled identifier names to their types. */
public abstract ImmutableMap<String, CelType> mangledComprehensionIdents();

private static MangledComprehensionAst of(
CelAbstractSyntaxTree ast, ImmutableSet<String> mangledComprehensionIdents) {
CelAbstractSyntaxTree ast, ImmutableMap<String, CelType> mangledComprehensionIdents) {
return new AutoValue_MutableAst_MangledComprehensionAst(ast, mangledComprehensionIdents);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,6 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock(
navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX);
CelAbstractSyntaxTree astToModify = mangledComprehensionAst.ast();
CelSource sourceToModify = astToModify.getSource();
ImmutableSet<CelVarDecl> mangledIdentDecls =
newMangledIdentDecls(celBuilder, mangledComprehensionAst);

int blockIdentifierIndex = 0;
int iterCount;
Expand Down Expand Up @@ -192,7 +190,11 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock(

// Add all mangled comprehension identifiers to the environment, so that the subexpressions can
// retain context to them.
celBuilder.addVarDeclarations(mangledIdentDecls);
mangledComprehensionAst
.mangledComprehensionIdents()
.forEach(
(identName, type) ->
celBuilder.addVarDeclarations(CelVarDecl.newVarDeclaration(identName, type)));
// Type-check all sub-expressions then add them as block identifiers to the CEL environment
addBlockIdentsToEnv(celBuilder, subexpressions);

Expand Down Expand Up @@ -260,41 +262,6 @@ private static void addBlockIdentsToEnv(CelBuilder celBuilder, List<CelExpr> sub
}
}

/**
 * Builds variable declarations for the mangled comprehension identifiers by type-checking the
 * mangled AST and reading the checked type of each identifier.
 *
 * <p>For every mangled identifier name, the first IDENT node found with that name supplies the
 * declared type. A mangled identifier with no IDENT node in the checked AST is skipped, and the
 * remaining identifiers are still processed.
 *
 * @param celBuilder builder used to create the environment that type-checks the mangled AST
 * @param mangledComprehensionAst the mangled AST together with its mangled identifier names
 * @return one declaration per mangled identifier located in the checked AST; an empty set when
 *     there are no mangled identifiers
 * @throws IllegalStateException if the mangled AST fails to type-check
 * @throws NoSuchElementException if a located identifier has no checked type
 */
private static ImmutableSet<CelVarDecl> newMangledIdentDecls(
    CelBuilder celBuilder, MangledComprehensionAst mangledComprehensionAst) {
  if (mangledComprehensionAst.mangledComprehensionIdents().isEmpty()) {
    return ImmutableSet.of();
  }
  CelAbstractSyntaxTree ast = mangledComprehensionAst.ast();
  try {
    ast = celBuilder.build().check(ast).getAst();
  } catch (CelValidationException e) {
    throw new IllegalStateException("Failed to type-check mangled AST.", e);
  }

  // The checked AST does not change inside the loop, so build its navigable form only once.
  CelNavigableAst checkedNavigableAst = CelNavigableAst.fromAst(ast);
  ImmutableSet.Builder<CelVarDecl> mangledVarDecls = ImmutableSet.builder();
  for (String ident : mangledComprehensionAst.mangledComprehensionIdents()) {
    CelExpr mangledIdentExpr =
        checkedNavigableAst
            .getRoot()
            .allNodes()
            .filter(node -> node.getKind().equals(Kind.IDENT))
            .map(CelNavigableExpr::expr)
            .filter(expr -> expr.ident().name().equals(ident))
            .findAny()
            .orElse(null);
    if (mangledIdentExpr == null) {
      // Skip only this identifier. Breaking out of the loop here would silently drop the
      // declarations for every identifier that follows it.
      continue;
    }

    CelType mangledIdentType =
        ast.getType(mangledIdentExpr.id())
            .orElseThrow(
                () ->
                    new NoSuchElementException(
                        "Checked type not present for mangled identifier: " + ident));
    mangledVarDecls.add(CelVarDecl.newVarDeclaration(ident, mangledIdentType));
  }

  return mangledVarDecls.build();
}

private CelAbstractSyntaxTree optimizeUsingCelBind(CelNavigableAst navigableAst) {
CelAbstractSyntaxTree astToModify =
mutableAst
Expand Down
16 changes: 8 additions & 8 deletions optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ public void mangleComprehensionVariable_singleMacro() throws Exception {
assertThat(mangledAst.getExpr().toString())
.isEqualTo(
"COMPREHENSION [13] {\n"
+ " iter_var: @c0\n"
+ " iter_var: @c0:0\n"
+ " iter_range: {\n"
+ " CREATE_LIST [1] {\n"
+ " elements: {\n"
Expand Down Expand Up @@ -722,7 +722,7 @@ public void mangleComprehensionVariable_singleMacro() throws Exception {
+ " name: __result__\n"
+ " }\n"
+ " IDENT [5] {\n"
+ " name: @c0\n"
+ " name: @c0:0\n"
+ " }\n"
+ " }\n"
+ " }\n"
Expand All @@ -733,7 +733,7 @@ public void mangleComprehensionVariable_singleMacro() throws Exception {
+ " }\n"
+ " }\n"
+ "}");
assertThat(CEL_UNPARSER.unparse(mangledAst)).isEqualTo("[false].exists(@c0, @c0)");
assertThat(CEL_UNPARSER.unparse(mangledAst)).isEqualTo("[false].exists(@c0:0, @c0:0)");
assertThat(CEL.createProgram(CEL.check(mangledAst).getAst()).eval()).isEqualTo(false);
assertConsistentMacroCalls(ast);
}
Expand All @@ -748,7 +748,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw
assertThat(mangledAst.getExpr().toString())
.isEqualTo(
"COMPREHENSION [27] {\n"
+ " iter_var: @c0\n"
+ " iter_var: @c0:0\n"
+ " iter_range: {\n"
+ " CREATE_LIST [1] {\n"
+ " elements: {\n"
Expand Down Expand Up @@ -785,12 +785,12 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw
+ " name: __result__\n"
+ " }\n"
+ " COMPREHENSION [19] {\n"
+ " iter_var: @c1\n"
+ " iter_var: @c1:0\n"
+ " iter_range: {\n"
+ " CREATE_LIST [5] {\n"
+ " elements: {\n"
+ " IDENT [6] {\n"
+ " name: @c0\n"
+ " name: @c0:0\n"
+ " }\n"
+ " }\n"
+ " }\n"
Expand Down Expand Up @@ -825,7 +825,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw
+ " function: _==_\n"
+ " args: {\n"
+ " IDENT [9] {\n"
+ " name: @c1\n"
+ " name: @c1:0\n"
+ " }\n"
+ " CONSTANT [11] { value: 1 }\n"
+ " }\n"
Expand All @@ -850,7 +850,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw
+ "}");

assertThat(CEL_UNPARSER.unparse(mangledAst))
.isEqualTo("[x].exists(@c0, [@c0].exists(@c1, @c1 == 1))");
.isEqualTo("[x].exists(@c0:0, [@c0:0].exists(@c1:0, @c1:0 == 1))");
assertThat(CEL.createProgram(CEL.check(mangledAst).getAst()).eval(ImmutableMap.of("x", 1)))
.isEqualTo(true);
assertConsistentMacroCalls(ast);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ java_library(
testonly = 1,
srcs = glob(["*.java"]),
deps = [
"//:java_truth",
# "//java/com/google/testing/testsize:annotations",
"//bundle:cel",
"//common",
"//common:compiler_common",
Expand All @@ -28,16 +28,18 @@ java_library(
"//parser:operator",
"//parser:unparser",
"//runtime",
"@maven//:com_google_guava_guava",
"@maven//:com_google_testparameterinjector_test_parameter_injector",
"@maven//:junit_junit",
"@maven//:com_google_testparameterinjector_test_parameter_injector",
"//:java_truth",
"@maven//:com_google_guava_guava",
],
)

junit4_test_suites(
name = "test_suites",
sizes = [
"small",
"medium",
],
src_dir = "src/test/java",
deps = [":tests"],
Expand Down
Loading

0 comments on commit 1078f89

Please sign in to comment.