Skip to content

Commit

Permalink
Properly set subtree height for navigable expr's children
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 606736573
  • Loading branch information
l46kok authored and copybara-github committed Feb 13, 2024
1 parent 4a723aa commit bda6026
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,9 @@ private int visit(CelNavigableExpr.Builder navigableExpr) {
if (navigableExpr.depth() > MAX_DESCENDANTS_RECURSION_DEPTH - 1) {
throw new IllegalStateException("Max recursion depth reached.");
}
if (navigableExpr.depth() > maxDepth) {
return -1;
}
if (traversalOrder.equals(TraversalOrder.PRE_ORDER)) {

boolean addToStream = navigableExpr.depth() <= maxDepth;
if (addToStream && traversalOrder.equals(TraversalOrder.PRE_ORDER)) {
streamBuilder.add(navigableExpr);
}

Expand Down Expand Up @@ -129,7 +128,7 @@ private int visit(CelNavigableExpr.Builder navigableExpr) {
}

navigableExpr.setHeight(height);
if (traversalOrder.equals(TraversalOrder.POST_ORDER)) {
if (addToStream && traversalOrder.equals(TraversalOrder.POST_ORDER)) {
streamBuilder.add(navigableExpr);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.UnsignedLong;
import com.google.testing.junit.testparameterinjector.TestParameter;
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
import com.google.testing.junit.testparameterinjector.TestParameters;
import dev.cel.common.CelAbstractSyntaxTree;
Expand Down Expand Up @@ -147,6 +148,51 @@ public void add_postOrder_heightSet() throws Exception {
assertThat(allNodeHeights).containsExactly(0, 0, 1, 0, 2).inOrder(); // 1, a, +, 2, +
}

@Test
public void add_fromLeaf_heightSetForParents() throws Exception {
CelCompiler compiler =
CelCompilerFactory.standardCelCompilerBuilder().addVar("a", SimpleType.INT).build();
// Tree shape:
// +
// + 2
// 1 a
CelAbstractSyntaxTree ast = compiler.compile("1 + a + 2").getAst();
CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast);

CelNavigableExpr oneConst =
navigableAst
.getRoot()
.descendants()
.filter(node -> node.expr().constantOrDefault().int64Value() == 1)
.findAny()
.get();
assertThat(oneConst.height()).isEqualTo(0); // 1
assertThat(oneConst.parent().get().height()).isEqualTo(1); // +
assertThat(oneConst.parent().get().parent().get().height()).isEqualTo(2); // root
}

@Test
public void add_children_heightSet(@TestParameter TraversalOrder traversalOrder)
throws Exception {
CelCompiler compiler =
CelCompilerFactory.standardCelCompilerBuilder().addVar("a", SimpleType.INT).build();
// Tree shape:
// +
// + 2
// + a
// 3
CelAbstractSyntaxTree ast = compiler.compile("1 + a + 2 + 3").getAst();
CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast);

ImmutableList<Integer> allNodeHeights =
navigableAst
.getRoot()
.children(traversalOrder)
.map(CelNavigableExpr::height)
.collect(toImmutableList());
assertThat(allNodeHeights).containsExactly(2, 0).inOrder(); // + (2), 2 (0) regardless of order
}

@Test
public void add_filterConstants_allNodesReturned() throws Exception {
CelCompiler compiler =
Expand Down Expand Up @@ -1022,6 +1068,24 @@ public void callExpr_postOrder_heightSet() throws Exception {
assertThat(allNodes).containsExactly(0, 0, 1, 0, 2, 0, 3, 0, 0, 1, 0, 2, 0, 4).inOrder();
}

@Test
public void createList_children_heightSet(@TestParameter TraversalOrder traversalOrder)
throws Exception {
CelCompiler compiler =
CelCompilerFactory.standardCelCompilerBuilder().addVar("a", SimpleType.INT).build();
CelAbstractSyntaxTree ast = compiler.compile("[1, a, (2 + 2), (3 + 4 + 5)]").getAst();

CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast);

ImmutableList<Integer> allNodeHeights =
navigableAst
.getRoot()
.children(traversalOrder)
.map(CelNavigableExpr::height)
.collect(toImmutableList());
assertThat(allNodeHeights).containsExactly(0, 0, 1, 2).inOrder();
}

@Test
public void maxRecursionLimitReached_throws() throws Exception {
StringBuilder sb = new StringBuilder();
Expand Down

0 comments on commit bda6026

Please sign in to comment.