Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assign unique indices for mangled comprehension identifiers with different types #243

Merged
1 commit merged into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ java_library(
"//common/ast",
"//common/ast:expr_factory",
"//common/navigation",
"//common/types:type_providers",
"@maven//:com_google_errorprone_error_prone_annotations",
"@maven//:com_google_guava_guava",
],
Expand Down
117 changes: 95 additions & 22 deletions optimizer/src/main/java/dev/cel/optimizer/MutableAst.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Table;
import com.google.errorprone.annotations.Immutable;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelSource;
Expand All @@ -38,9 +39,13 @@
import dev.cel.common.navigation.CelNavigableAst;
import dev.cel.common.navigation.CelNavigableExpr;
import dev.cel.common.navigation.CelNavigableExpr.TraversalOrder;
import dev.cel.common.types.CelType;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map.Entry;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.stream.Collectors;

/** MutableAst contains logic for mutating a {@link CelAbstractSyntaxTree}. */
@Immutable
Expand Down Expand Up @@ -187,45 +192,112 @@ public CelAbstractSyntaxTree renumberIdsConsecutively(CelAbstractSyntaxTree ast)
*
* <p>The expression IDs are not modified when the identifier names are changed.
*
* <p>Mangling occurs only if the iteration variable is referenced within the loop step.
*
* <p>Iteration variables in comprehensions are numbered based on their comprehension nesting
* levels. Examples:
* levels and the iteration variable's type. Examples:
*
* <ul>
* <li>{@code [true].exists(i, i) && [true].exists(j, j)} -> {@code [true].exists(@c0, @c0) &&
* [true].exists(@c0, @c0)} // Note that i,j gets replaced to the same @c0 in this example
* <li>{@code [true].exists(i, i && [true].exists(j, j))} -> {@code [true].exists(@c0, @c0 &&
* [true].exists(@c1, @c1))}
* <li>{@code [true].exists(i, i) && [true].exists(j, j)} -> {@code [true].exists(@c0:0, @c0:0)
* && [true].exists(@c0:0, @c0:0)} // Note that i,j gets replaced to the same @c0:0 in this
* example as they share the same nesting level and type.
* <li>{@code [1].exists(i, i > 0) && [1u].exists(j, j > 0u)} -> {@code [1].exists(@c0:0, @c0:0
* > 0) && [1u].exists(@c0:1, @c0:1 > 0u)}
* <li>{@code [true].exists(i, i && [true].exists(j, j))} -> {@code [true].exists(@c0:0, @c0:0
* && [true].exists(@c1:0, @c1:0))}
* </ul>
*
* @param ast AST to mutate
* @param newIdentPrefix Prefix to use for new identifier names. For example, providing @c will
* produce @c0, @c1, @c2... as new names.
* produce @c0:0, @c0:1, @c1:0, @c2:0... as new names.
*/
public MangledComprehensionAst mangleComprehensionIdentifierNames(
CelAbstractSyntaxTree ast, String newIdentPrefix) {
int iterCount;
CelNavigableAst newNavigableAst = CelNavigableAst.fromAst(ast);
ImmutableSet.Builder<String> mangledComprehensionIdents = ImmutableSet.builder();
for (iterCount = 0; iterCount < iterationLimit; iterCount++) {
LinkedHashMap<CelNavigableExpr, CelType> comprehensionsToMangle =
newNavigableAst
.getRoot()
// This is important - mangling needs to happen bottom-up to avoid stepping over
// shadowed variables that are not part of the comprehension being mangled.
.allNodes(TraversalOrder.POST_ORDER)
.filter(node -> node.getKind().equals(Kind.COMPREHENSION))
.filter(node -> !node.expr().comprehension().iterVar().startsWith(newIdentPrefix))
.filter(
node -> {
// Ensure the iter_var is actually referenced in the loop_step. If it's not, we
// can skip mangling.
String iterVar = node.expr().comprehension().iterVar();
return CelNavigableExpr.fromExpr(node.expr().comprehension().loopStep())
.allNodes()
.anyMatch(
subNode -> subNode.expr().identOrDefault().name().contains(iterVar));
})
.collect(
Collectors.toMap(
k -> k,
v -> {
String iterVar = v.expr().comprehension().iterVar();
long iterVarId =
CelNavigableExpr.fromExpr(v.expr().comprehension().loopStep())
.allNodes()
.filter(
loopStepNode ->
loopStepNode.expr().identOrDefault().name().equals(iterVar))
.map(CelNavigableExpr::id)
.findAny()
.orElseThrow(
() -> {
throw new NoSuchElementException(
"Expected iteration variable to exist in expr id: "
+ v.id());
});

return ast.getType(iterVarId)
.orElseThrow(
() ->
new NoSuchElementException(
"Checked type not present for: " + iterVarId));
},
(x, y) -> {
throw new IllegalStateException("Unexpected CelNavigableExpr collision");
},
LinkedHashMap::new));
int iterCount = 0;

// The map that we'll eventually return to the caller.
HashMap<String, CelType> mangledIdentNamesToType = new HashMap<>();
// Intermediary table used for the purposes of generating a unique mangled variable name.
Table<Integer, CelType, String> comprehensionLevelToType = HashBasedTable.create();
for (Entry<CelNavigableExpr, CelType> comprehensionEntry : comprehensionsToMangle.entrySet()) {
iterCount++;
// Refetch the comprehension node as mutating the AST could have renumbered its IDs.
CelNavigableExpr comprehensionNode =
newNavigableAst
.getRoot()
// This is important - mangling needs to happen bottom-up to avoid stepping over
// shadowed variables that are not part of the comprehension being mangled.
.allNodes(TraversalOrder.POST_ORDER)
.filter(node -> node.getKind().equals(Kind.COMPREHENSION))
.filter(node -> !node.expr().comprehension().iterVar().startsWith(newIdentPrefix))
.findAny()
.orElse(null);
if (comprehensionNode == null) {
break;
}
.orElseThrow(
() -> new NoSuchElementException("Failed to refetch mutated comprehension"));
CelType comprehensionEntryType = comprehensionEntry.getValue();

CelExpr.Builder comprehensionExpr = comprehensionNode.expr().toBuilder();
String iterVar = comprehensionExpr.comprehension().iterVar();
int comprehensionNestingLevel = countComprehensionNestingLevel(comprehensionNode);
String mangledVarName = newIdentPrefix + comprehensionNestingLevel;
mangledComprehensionIdents.add(mangledVarName);
String mangledVarName;
if (comprehensionLevelToType.contains(comprehensionNestingLevel, comprehensionEntryType)) {
mangledVarName =
comprehensionLevelToType.get(comprehensionNestingLevel, comprehensionEntryType);
} else {
// First time encountering the pair of <ComprehensionLevel, CelType>. Generate a unique
// mangled variable name for this.
int uniqueTypeIdx = comprehensionLevelToType.row(comprehensionNestingLevel).size();
mangledVarName = newIdentPrefix + comprehensionNestingLevel + ":" + uniqueTypeIdx;
comprehensionLevelToType.put(
comprehensionNestingLevel, comprehensionEntryType, mangledVarName);
}
mangledIdentNamesToType.put(mangledVarName, comprehensionEntryType);

CelExpr.Builder mutatedComprehensionExpr =
mangleIdentsInComprehensionExpr(
Expand Down Expand Up @@ -254,7 +326,8 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
throw new IllegalStateException("Max iteration count reached.");
}

return MangledComprehensionAst.of(newNavigableAst.getAst(), mangledComprehensionIdents.build());
return MangledComprehensionAst.of(
newNavigableAst.getAst(), ImmutableMap.copyOf(mangledIdentNamesToType));
}

/**
Expand Down Expand Up @@ -588,11 +661,11 @@ public abstract static class MangledComprehensionAst {
/** AST after the iteration variables have been mangled. */
public abstract CelAbstractSyntaxTree ast();

/** Set of identifiers with the iteration variable mangled. */
public abstract ImmutableSet<String> mangledComprehensionIdents();
/** Map containing the mangled identifier names to their types. */
public abstract ImmutableMap<String, CelType> mangledComprehensionIdents();

private static MangledComprehensionAst of(
CelAbstractSyntaxTree ast, ImmutableSet<String> mangledComprehensionIdents) {
CelAbstractSyntaxTree ast, ImmutableMap<String, CelType> mangledComprehensionIdents) {
return new AutoValue_MutableAst_MangledComprehensionAst(ast, mangledComprehensionIdents);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,6 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock(
navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX);
CelAbstractSyntaxTree astToModify = mangledComprehensionAst.ast();
CelSource sourceToModify = astToModify.getSource();
ImmutableSet<CelVarDecl> mangledIdentDecls =
newMangledIdentDecls(celBuilder, mangledComprehensionAst);

int blockIdentifierIndex = 0;
int iterCount;
Expand Down Expand Up @@ -192,7 +190,11 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock(

// Add all mangled comprehension identifiers to the environment, so that the subexpressions can
// retain context to them.
celBuilder.addVarDeclarations(mangledIdentDecls);
mangledComprehensionAst
.mangledComprehensionIdents()
.forEach(
(identName, type) ->
celBuilder.addVarDeclarations(CelVarDecl.newVarDeclaration(identName, type)));
// Type-check all sub-expressions then add them as block identifiers to the CEL environment
addBlockIdentsToEnv(celBuilder, subexpressions);

Expand Down Expand Up @@ -260,41 +262,6 @@ private static void addBlockIdentsToEnv(CelBuilder celBuilder, List<CelExpr> sub
}
}

private static ImmutableSet<CelVarDecl> newMangledIdentDecls(
CelBuilder celBuilder, MangledComprehensionAst mangledComprehensionAst) {
if (mangledComprehensionAst.mangledComprehensionIdents().isEmpty()) {
return ImmutableSet.of();
}
CelAbstractSyntaxTree ast = mangledComprehensionAst.ast();
try {
ast = celBuilder.build().check(ast).getAst();
} catch (CelValidationException e) {
throw new IllegalStateException("Failed to type-check mangled AST.", e);
}

ImmutableSet.Builder<CelVarDecl> mangledVarDecls = ImmutableSet.builder();
for (String ident : mangledComprehensionAst.mangledComprehensionIdents()) {
CelExpr mangledIdentExpr =
CelNavigableAst.fromAst(ast)
.getRoot()
.allNodes()
.filter(node -> node.getKind().equals(Kind.IDENT))
.map(CelNavigableExpr::expr)
.filter(expr -> expr.ident().name().equals(ident))
.findAny()
.orElse(null);
if (mangledIdentExpr == null) {
break;
}

CelType mangledIdentType =
ast.getType(mangledIdentExpr.id()).orElseThrow(() -> new NoSuchElementException("?"));
mangledVarDecls.add(CelVarDecl.newVarDeclaration(ident, mangledIdentType));
}

return mangledVarDecls.build();
}

private CelAbstractSyntaxTree optimizeUsingCelBind(CelNavigableAst navigableAst) {
CelAbstractSyntaxTree astToModify =
mutableAst
Expand Down
16 changes: 8 additions & 8 deletions optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ public void mangleComprehensionVariable_singleMacro() throws Exception {
assertThat(mangledAst.getExpr().toString())
.isEqualTo(
"COMPREHENSION [13] {\n"
+ " iter_var: @c0\n"
+ " iter_var: @c0:0\n"
+ " iter_range: {\n"
+ " CREATE_LIST [1] {\n"
+ " elements: {\n"
Expand Down Expand Up @@ -722,7 +722,7 @@ public void mangleComprehensionVariable_singleMacro() throws Exception {
+ " name: __result__\n"
+ " }\n"
+ " IDENT [5] {\n"
+ " name: @c0\n"
+ " name: @c0:0\n"
+ " }\n"
+ " }\n"
+ " }\n"
Expand All @@ -733,7 +733,7 @@ public void mangleComprehensionVariable_singleMacro() throws Exception {
+ " }\n"
+ " }\n"
+ "}");
assertThat(CEL_UNPARSER.unparse(mangledAst)).isEqualTo("[false].exists(@c0, @c0)");
assertThat(CEL_UNPARSER.unparse(mangledAst)).isEqualTo("[false].exists(@c0:0, @c0:0)");
assertThat(CEL.createProgram(CEL.check(mangledAst).getAst()).eval()).isEqualTo(false);
assertConsistentMacroCalls(ast);
}
Expand All @@ -748,7 +748,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw
assertThat(mangledAst.getExpr().toString())
.isEqualTo(
"COMPREHENSION [27] {\n"
+ " iter_var: @c0\n"
+ " iter_var: @c0:0\n"
+ " iter_range: {\n"
+ " CREATE_LIST [1] {\n"
+ " elements: {\n"
Expand Down Expand Up @@ -785,12 +785,12 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw
+ " name: __result__\n"
+ " }\n"
+ " COMPREHENSION [19] {\n"
+ " iter_var: @c1\n"
+ " iter_var: @c1:0\n"
+ " iter_range: {\n"
+ " CREATE_LIST [5] {\n"
+ " elements: {\n"
+ " IDENT [6] {\n"
+ " name: @c0\n"
+ " name: @c0:0\n"
+ " }\n"
+ " }\n"
+ " }\n"
Expand Down Expand Up @@ -825,7 +825,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw
+ " function: _==_\n"
+ " args: {\n"
+ " IDENT [9] {\n"
+ " name: @c1\n"
+ " name: @c1:0\n"
+ " }\n"
+ " CONSTANT [11] { value: 1 }\n"
+ " }\n"
Expand All @@ -850,7 +850,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw
+ "}");

assertThat(CEL_UNPARSER.unparse(mangledAst))
.isEqualTo("[x].exists(@c0, [@c0].exists(@c1, @c1 == 1))");
.isEqualTo("[x].exists(@c0:0, [@c0:0].exists(@c1:0, @c1:0 == 1))");
assertThat(CEL.createProgram(CEL.check(mangledAst).getAst()).eval(ImmutableMap.of("x", 1)))
.isEqualTo(true);
assertConsistentMacroCalls(ast);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ java_library(
testonly = 1,
srcs = glob(["*.java"]),
deps = [
"//:java_truth",
# "//java/com/google/testing/testsize:annotations",
"//bundle:cel",
"//common",
"//common:compiler_common",
Expand All @@ -28,16 +28,18 @@ java_library(
"//parser:operator",
"//parser:unparser",
"//runtime",
"@maven//:com_google_guava_guava",
"@maven//:com_google_testparameterinjector_test_parameter_injector",
"@maven//:junit_junit",
"@maven//:com_google_testparameterinjector_test_parameter_injector",
"//:java_truth",
"@maven//:com_google_guava_guava",
],
)

junit4_test_suites(
name = "test_suites",
sizes = [
"small",
"medium",
],
src_dir = "src/test/java",
deps = [":tests"],
Expand Down
Loading
Loading