Skip to content

Commit

Permalink
Improve canEliminate logic by excluding nodes from ineligible compreh…
Browse files Browse the repository at this point in the history
…ension branches

PiperOrigin-RevId: 627768195
  • Loading branch information
l46kok authored and copybara-github committed Apr 24, 2024
1 parent ec41e7e commit 91b20d5
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 31 deletions.
5 changes: 5 additions & 0 deletions common/src/main/java/dev/cel/common/ast/CelMutableExpr.java
Original file line number Diff line number Diff line change
Expand Up @@ -1148,6 +1148,11 @@ public boolean equals(Object obj) {
return false;
}

@Override
public String toString() {
return CelMutableExprConverter.fromMutableExpr(this).toString();
}

@Override
public int hashCode() {
int h = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Arrays.stream;
import static java.util.stream.Collectors.toCollection;

import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
Expand Down Expand Up @@ -65,6 +66,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;

/**
* Performs Common Subexpression Elimination.
Expand Down Expand Up @@ -469,11 +471,12 @@ private List<CelMutableExpr> getCseCandidates(CelNavigableMutableAst navAst) {
private List<CelMutableExpr> getCseCandidatesWithRecursionDepth(
CelNavigableMutableAst navAst, int recursionLimit) {
Preconditions.checkArgument(recursionLimit > 0);
Set<CelMutableExpr> ineligibleExprs = getIneligibleExprsFromComprehensionBranches(navAst);
ImmutableList<CelNavigableMutableExpr> descendants =
navAst
.getRoot()
.descendants(TraversalOrder.PRE_ORDER)
.filter(this::canEliminate)
.filter(node -> canEliminate(node, ineligibleExprs))
.filter(node -> node.height() <= recursionLimit)
.sorted(Comparator.comparingInt(CelNavigableMutableExpr::height).reversed())
.collect(toImmutableList());
Expand All @@ -494,7 +497,7 @@ private List<CelMutableExpr> getCseCandidatesWithRecursionDepth(
.getRoot()
.allNodes(TraversalOrder.POST_ORDER)
.filter(node -> node.height() > recursionLimit)
.anyMatch(this::canEliminate);
.anyMatch(node -> canEliminate(node, ineligibleExprs));
if (astHasMoreExtractableSubexprs) {
cseCandidates.add(descendants.get(0).expr());
return cseCandidates;
Expand All @@ -506,11 +509,12 @@ private List<CelMutableExpr> getCseCandidatesWithRecursionDepth(
}

private List<CelMutableExpr> getCseCandidatesWithCommonSubexpr(CelNavigableMutableAst navAst) {
Set<CelMutableExpr> ineligibleExprs = getIneligibleExprsFromComprehensionBranches(navAst);
ImmutableList<CelNavigableMutableExpr> allNodes =
navAst
.getRoot()
.allNodes(TraversalOrder.PRE_ORDER)
.filter(this::canEliminate)
.filter(node -> canEliminate(node, ineligibleExprs))
.collect(toImmutableList());

return getCseCandidatesWithCommonSubexpr(allNodes);
Expand Down Expand Up @@ -547,42 +551,48 @@ private List<CelMutableExpr> getCseCandidatesWithCommonSubexpr(
return cseCandidates;
}

private boolean canEliminate(CelNavigableMutableExpr navigableExpr) {
private boolean canEliminate(
CelNavigableMutableExpr navigableExpr, Set<CelMutableExpr> ineligibleExprs) {
return !navigableExpr.getKind().equals(Kind.CONSTANT)
&& !navigableExpr.getKind().equals(Kind.IDENT)
&& !(navigableExpr.getKind().equals(Kind.IDENT)
&& navigableExpr.expr().ident().name().startsWith(BIND_IDENTIFIER_PREFIX))
// Exclude empty lists (cel.bind sets this for iterRange).
&& !(navigableExpr.getKind().equals(Kind.CREATE_LIST)
&& navigableExpr.expr().createList().elements().isEmpty())
&& containsEliminableFunctionOnly(navigableExpr)
&& isWithinInlineableComprehension(navigableExpr);
&& !ineligibleExprs.contains(navigableExpr.expr());
}

private static boolean isWithinInlineableComprehension(CelNavigableMutableExpr expr) {
Optional<CelNavigableMutableExpr> maybeParent = expr.parent();
while (maybeParent.isPresent()) {
CelNavigableMutableExpr parent = maybeParent.get();
if (parent.getKind().equals(Kind.COMPREHENSION)) {
return Streams.concat(
// If the expression is within a comprehension, it is eligible for CSE iff is in
// result, loopStep or iterRange. While result is not human authored, it needs to be
// included to extract subexpressions that are already in cel.bind macro.
CelNavigableMutableExpr.fromExpr(parent.expr().comprehension().result())
.descendants(),
CelNavigableMutableExpr.fromExpr(parent.expr().comprehension().loopStep())
.allNodes(),
CelNavigableMutableExpr.fromExpr(parent.expr().comprehension().iterRange())
.allNodes())
.filter(
node ->
// Exclude empty lists (cel.bind sets this for iterRange).
!node.getKind().equals(Kind.CREATE_LIST)
|| !node.expr().createList().elements().isEmpty())
.map(CelNavigableMutableExpr::expr)
.anyMatch(node -> node.equals(expr.expr()));
}
maybeParent = parent.parent();
}
/**
* Collects a set of nodes that are not eligible to be optimized from comprehension branches.
*
* <p>All nodes from accumulator initializer and loop condition are not eligible to be optimized
* as that can interfere with scoping of shadowed variables.
*/
private static Set<CelMutableExpr> getIneligibleExprsFromComprehensionBranches(
CelNavigableMutableAst navAst) {
HashSet<CelMutableExpr> ineligibleExprs = new HashSet<>();
navAst
.getRoot()
.allNodes()
.filter(node -> node.getKind().equals(Kind.COMPREHENSION))
.forEach(
node -> {
Set<CelMutableExpr> nodes =
Streams.concat(
CelNavigableMutableExpr.fromExpr(node.expr().comprehension().accuInit())
.allNodes(),
CelNavigableMutableExpr.fromExpr(
node.expr().comprehension().loopCondition())
.allNodes())
.map(CelNavigableMutableExpr::expr)
.collect(toCollection(HashSet::new));

ineligibleExprs.addAll(nodes);
});

return true;
return ineligibleExprs;
}

private boolean containsEliminableFunctionOnly(CelNavigableMutableExpr navigableExpr) {
Expand Down

0 comments on commit 91b20d5

Please sign in to comment.