Skip to content

Commit

Permalink
Augment CSE to produce optimized ASTs using cel.block
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 607486802
  • Loading branch information
l46kok authored and copybara-github committed Feb 16, 2024
1 parent bda6026 commit 70ef6f9
Show file tree
Hide file tree
Showing 5 changed files with 642 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@

package dev.cel.optimizer;

import dev.cel.bundle.Cel;
import dev.cel.bundle.CelBuilder;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.navigation.CelNavigableAst;

/** Public interface for performing a single, custom optimization on an AST. */
public interface CelAstOptimizer {

/** Optimizes a single AST. */
CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel)
CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, CelBuilder cel)
throws CelOptimizationException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.google.common.collect.ImmutableSet;
import dev.cel.bundle.Cel;
import dev.cel.bundle.CelBuilder;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelValidationException;
import dev.cel.common.navigation.CelNavigableAst;
Expand All @@ -39,11 +40,12 @@ public CelAbstractSyntaxTree optimize(CelAbstractSyntaxTree ast) throws CelOptim
}

CelAbstractSyntaxTree optimizedAst = ast;
CelBuilder celBuilder = cel.toCelBuilder();
try {
for (CelAstOptimizer optimizer : astOptimizers) {
CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast);
optimizedAst = optimizer.optimize(navigableAst, cel);
optimizedAst = cel.check(optimizedAst).getAst();
CelNavigableAst navigableAst = CelNavigableAst.fromAst(optimizedAst);
optimizedAst = optimizer.optimize(navigableAst, celBuilder);
optimizedAst = celBuilder.build().check(optimizedAst).getAst();
}
} catch (CelValidationException e) {
throw new CelOptimizationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.google.auto.value.AutoValue;
import com.google.common.collect.ImmutableList;
import dev.cel.bundle.Cel;
import dev.cel.bundle.CelBuilder;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelValidationException;
import dev.cel.common.ast.CelConstant;
Expand Down Expand Up @@ -76,8 +77,9 @@ public static ConstantFoldingOptimizer newInstance(
private final MutableAst mutableAst;

@Override
public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel)
public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, CelBuilder celBuilder)
throws CelOptimizationException {
Cel cel = celBuilder.build();
Set<CelExpr> visitedExprs = new HashSet<>();
int iterCount = 0;
while (true) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import dev.cel.bundle.Cel;
import dev.cel.bundle.CelBuilder;
import dev.cel.checker.Standard;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelFunctionDecl;
import dev.cel.common.CelOverloadDecl;
import dev.cel.common.CelSource;
import dev.cel.common.CelValidationException;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.ast.CelExpr.CelCall;
import dev.cel.common.ast.CelExpr.CelIdent;
import dev.cel.common.ast.CelExpr.ExprKind.Kind;
import dev.cel.common.navigation.CelNavigableAst;
Expand All @@ -41,8 +43,10 @@
import dev.cel.optimizer.CelAstOptimizer;
import dev.cel.optimizer.MutableAst;
import dev.cel.parser.Operator;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.stream.Stream;
Expand All @@ -63,6 +67,12 @@
* cel.bind(@r0, message.child.text_map[x],
* @r0.startsWith("hello") && @r0.endsWith("world"))
* }
*
* Or, using the equivalent form of cel.@block (requires special runtime support):
* {@code
* cel.block([message.child.text_map[x]],
* @index0.startsWith("hello") && @index1.endsWith("world"))
* }
* </pre>
*/
public class SubexpressionOptimizer implements CelAstOptimizer {
Expand All @@ -71,6 +81,7 @@ public class SubexpressionOptimizer implements CelAstOptimizer {
private static final String BIND_IDENTIFIER_PREFIX = "@r";
private static final String MANGLED_COMPREHENSION_IDENTIFIER_PREFIX = "@c";
private static final String CEL_BLOCK_FUNCTION = "cel.@block";
private static final String BLOCK_INDEX_PREFIX = "@index";
private static final ImmutableSet<String> CSE_ALLOWED_FUNCTIONS =
Streams.concat(
stream(Operator.values()).map(Operator::getFunction),
Expand All @@ -96,7 +107,138 @@ public static SubexpressionOptimizer newInstance(SubexpressionOptimizerOptions c
}

@Override
public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) {
public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, CelBuilder celBuilder) {
return cseOptions.enableCelBlock()
? optimizeUsingCelBlock(navigableAst, celBuilder)
: optimizeUsingCelBind(navigableAst);
}

private CelAbstractSyntaxTree optimizeUsingCelBlock(
CelNavigableAst navigableAst, CelBuilder celBuilder) {
// Retain the original expected result type, so that it can be reset in celBuilder at the end of
// the optimization pass.
CelType resultType = navigableAst.getAst().getResultType();
CelAbstractSyntaxTree astToModify =
mutableAst.mangleComprehensionIdentifierNames(
navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX);
CelSource sourceToModify = astToModify.getSource();

int blockIdentifierIndex = 0;
int iterCount;
ArrayList<CelExpr> subexpressions = new ArrayList<>();
for (iterCount = 0; iterCount < cseOptions.iterationLimit(); iterCount++) {
CelExpr cseCandidate = findCseCandidate(astToModify).map(CelNavigableExpr::expr).orElse(null);
if (cseCandidate == null) {
break;
}
subexpressions.add(cseCandidate);

String blockIdentifier = BLOCK_INDEX_PREFIX + blockIdentifierIndex++;

// Using the CSE candidate, fetch all semantically equivalent subexpressions ahead of time.
ImmutableList<CelExpr> allCseCandidates =
getAllCseCandidatesStream(astToModify, cseCandidate).collect(toImmutableList());

// Replace all CSE candidates with new block index identifier
for (CelExpr semanticallyEqualNode : allCseCandidates) {
iterCount++;
// Refetch the candidate expr as mutating the AST could have renumbered its IDs.
CelExpr exprToReplace =
getAllCseCandidatesStream(astToModify, semanticallyEqualNode)
.findAny()
.orElseThrow(
() ->
new NoSuchElementException(
"No value present for expr ID: " + semanticallyEqualNode.id()));

astToModify =
mutableAst.replaceSubtree(
astToModify,
CelExpr.newBuilder()
.setIdent(CelIdent.newBuilder().setName(blockIdentifier).build())
.build(),
exprToReplace.id());
}

sourceToModify =
sourceToModify.toBuilder()
.addAllMacroCalls(astToModify.getSource().getMacroCalls())
.build();
astToModify = CelAbstractSyntaxTree.newParsedAst(astToModify.getExpr(), sourceToModify);

// Retain the existing macro calls in case if the block identifiers are replacing a subtree
// that contains a comprehension.
sourceToModify = astToModify.getSource();
}

if (iterCount >= cseOptions.iterationLimit()) {
throw new IllegalStateException("Max iteration count reached.");
}

if (iterCount == 0) {
// No modification has been made.
return astToModify;
}

// Type-check all sub-expressions then add them as block identifiers to the CEL environment
addBlockIdentsToEnv(celBuilder, subexpressions);

// Wrap the optimized expression in cel.block
celBuilder.addFunctionDeclarations(newCelBlockFunctionDecl(resultType));
int newId = 0;
CelExpr blockExpr =
CelExpr.newBuilder()
.setId(++newId)
.setCall(
CelCall.newBuilder()
.setFunction(CEL_BLOCK_FUNCTION)
.addArgs(
CelExpr.ofCreateListExpr(
++newId, ImmutableList.copyOf(subexpressions), ImmutableList.of()),
astToModify.getExpr())
.build())
.build();
astToModify =
mutableAst.renumberIdsConsecutively(
CelAbstractSyntaxTree.newParsedAst(blockExpr, astToModify.getSource()));

if (!cseOptions.populateMacroCalls()) {
astToModify =
CelAbstractSyntaxTree.newParsedAst(astToModify.getExpr(), CelSource.newBuilder().build());
}

// Restore the expected result type the environment had prior to optimization.
celBuilder.setResultType(resultType);
return astToModify;
}

/**
* Adds all subexpression as numbered identifiers that acts as an indexer to cel.block
* (ex: @index0, @index1..) Each subexpressions are type-checked, then its result type is used as
* the new identifiers' types.
*/
private static void addBlockIdentsToEnv(CelBuilder celBuilder, List<CelExpr> subexpressions) {
// The resulting type of the subexpressions will likely be different from the
// entire expression's expected result type.
celBuilder.setResultType(SimpleType.DYN);

for (int i = 0; i < subexpressions.size(); i++) {
CelExpr subexpression = subexpressions.get(i);

CelAbstractSyntaxTree subAst =
CelAbstractSyntaxTree.newParsedAst(subexpression, CelSource.newBuilder().build());

try {
subAst = celBuilder.build().check(subAst).getAst();
} catch (CelValidationException e) {
throw new IllegalStateException("Failed to type-check subexpression", e);
}

celBuilder.addVar("@index" + i, subAst.getResultType());
}
}

private CelAbstractSyntaxTree optimizeUsingCelBind(CelNavigableAst navigableAst) {
CelAbstractSyntaxTree astToModify =
mutableAst.mangleComprehensionIdentifierNames(
navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX);
Expand Down Expand Up @@ -166,12 +308,13 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) {
return astToModify;
}

astToModify = mutableAst.renumberIdsConsecutively(astToModify);
if (!cseOptions.populateMacroCalls()) {
astToModify =
CelAbstractSyntaxTree.newParsedAst(astToModify.getExpr(), CelSource.newBuilder().build());
}

return mutableAst.renumberIdsConsecutively(astToModify);
return astToModify;
}

private Stream<CelExpr> getAllCseCandidatesStream(
Expand Down Expand Up @@ -347,6 +490,8 @@ public abstract static class SubexpressionOptimizerOptions {

public abstract boolean populateMacroCalls();

public abstract boolean enableCelBlock();

/** Builder for configuring the {@link SubexpressionOptimizerOptions}. */
@AutoValue.Builder
public abstract static class Builder {
Expand All @@ -363,6 +508,12 @@ public abstract static class Builder {
*/
public abstract Builder populateMacroCalls(boolean value);

/**
* Rewrites the optimized AST using cel.@block call instead of cascaded cel.bind macros, aimed
* to produce a more compact AST.
*/
public abstract Builder enableCelBlock(boolean value);

public abstract SubexpressionOptimizerOptions build();

Builder() {}
Expand All @@ -372,7 +523,8 @@ public abstract static class Builder {
public static Builder newBuilder() {
return new AutoValue_SubexpressionOptimizer_SubexpressionOptimizerOptions.Builder()
.iterationLimit(500)
.populateMacroCalls(false);
.populateMacroCalls(false)
.enableCelBlock(false);
}

SubexpressionOptimizerOptions() {}
Expand Down
Loading

0 comments on commit 70ef6f9

Please sign in to comment.