Skip to content

Commit

Permalink
Add a validator for enforcing AST depth limit
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 663141870
  • Loading branch information
l46kok authored and copybara-github committed Aug 15, 2024
1 parent 89ea79c commit 12d777f
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dev.cel.validator.validators;

import static com.google.common.base.Preconditions.checkArgument;

import dev.cel.bundle.Cel;
import dev.cel.common.navigation.CelNavigableAst;
import dev.cel.validator.CelAstValidator;

/** Enforces a compiled AST to stay below the configured depth limit. */
public final class AstDepthLimitValidator implements CelAstValidator {

// Protobuf imposes a default parse-depth limit of 100. We set it to half here because navigable
// expr does not include operands in the depth calculation.
// As an example, an expression 'x.y' has a depth of 2 in NavigableExpr, but the ParsedExpr has a
// depth of 4 as illustrated below:
//
// expr {
// id: 2
// select_expr {
// operand {
// id: 1
// ident_expr {
// name: "x"
// }
// }
// field: "y"
// }
// }
static final int DEFAULT_DEPTH_LIMIT = 50;

public static final AstDepthLimitValidator DEFAULT = newInstance(DEFAULT_DEPTH_LIMIT);
private final int maxDepth;

/**
* Constructs a new instance of {@link AstDepthLimitValidator} with the configured maxDepth as its
* limit.
*/
public static AstDepthLimitValidator newInstance(int maxDepth) {
checkArgument(maxDepth > 0);
return new AstDepthLimitValidator(maxDepth);
}

@Override
public void validate(CelNavigableAst navigableAst, Cel cel, IssuesFactory issuesFactory) {
if (navigableAst.getRoot().height() >= maxDepth) {
issuesFactory.addError(
navigableAst.getRoot().id(),
String.format("AST's depth exceeds the configured limit: %s.", maxDepth));
}
}

private AstDepthLimitValidator(int maxDepth) {
this.maxDepth = maxDepth;
}
}
15 changes: 15 additions & 0 deletions validator/src/main/java/dev/cel/validator/validators/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ java_library(
],
)

java_library(
name = "ast_depth_limit_validator",
srcs = [
"AstDepthLimitValidator.java",
],
tags = [
],
deps = [
"//bundle:cel",
"//common/navigation",
"//validator:ast_validator",
"@maven//:com_google_guava_guava",
],
)

java_library(
name = "literal_validator",
srcs = [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dev.cel.validator.validators;

import static com.google.common.truth.Truth.assertThat;
import static dev.cel.common.CelFunctionDecl.newFunctionDeclaration;
import static dev.cel.common.CelOverloadDecl.newGlobalOverload;
import static dev.cel.validator.validators.AstDepthLimitValidator.DEFAULT_DEPTH_LIMIT;
import static org.junit.Assert.assertThrows;

import dev.cel.expr.CheckedExpr;
import com.google.protobuf.ByteString;
import com.google.protobuf.ExtensionRegistryLite;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.testing.junit.testparameterinjector.TestParameter;
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
import dev.cel.bundle.Cel;
import dev.cel.bundle.CelFactory;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelIssue.Severity;
import dev.cel.common.CelProtoAbstractSyntaxTree;
import dev.cel.common.CelValidationResult;
import dev.cel.common.types.SimpleType;
import dev.cel.validator.CelValidator;
import dev.cel.validator.CelValidatorFactory;
import org.junit.Test;
import org.junit.runner.RunWith;

@RunWith(TestParameterInjector.class)
public class AstDepthLimitValidatorTest {

private static final Cel CEL =
CelFactory.standardCelBuilder()
.addVar("x", SimpleType.DYN)
.addFunctionDeclarations(
newFunctionDeclaration(
"f", newGlobalOverload("f_int64", SimpleType.INT, SimpleType.INT)))
.build();

private static final CelValidator CEL_VALIDATOR =
CelValidatorFactory.standardCelValidatorBuilder(CEL)
.addAstValidators(AstDepthLimitValidator.DEFAULT)
.build();

private enum DefaultTestCase {
NESTED_SELECTS(
"x.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y.y"),
NESTED_CALCS(
"0+1+2+3+4+5+6+7+8+9+10+11+12+13+14+15+16+17+18+19+20+21+22+23+24+25+26+27+28+29+30+31+32+33+34+35+36+37+38+39+40+41+42+43+44+45+46+47+48+49+50"),
NESTED_FUNCS(
"f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(0)))))))))))))))))))))))))))))))))))))))))))))))))))");

private final String expression;

DefaultTestCase(String expression) {
this.expression = expression;
}
}

@Test
public void astExceedsDefaultDepthLimit_populatesErrors(@TestParameter DefaultTestCase testCase)
throws Exception {
CelAbstractSyntaxTree ast = CEL.compile(testCase.expression).getAst();

CelValidationResult result = CEL_VALIDATOR.validate(ast);

assertThat(result.hasError()).isTrue();
assertThat(result.getAllIssues()).hasSize(1);
assertThat(result.getAllIssues().get(0).getSeverity()).isEqualTo(Severity.ERROR);
assertThat(result.getAllIssues().get(0).toDisplayString(ast.getSource()))
.contains("AST's depth exceeds the configured limit: 50.");
assertThrows(InvalidProtocolBufferException.class, () -> verifyProtoAstRoundTrips(ast));
}

@Test
public void astIsUnderDepthLimit_noErrors() throws Exception {
StringBuilder sb = new StringBuilder().append("x");
for (int i = 0; i < DEFAULT_DEPTH_LIMIT - 1; i++) {
sb.append(".y");
}
// Depth level of 49
CelAbstractSyntaxTree ast = CEL.compile(sb.toString()).getAst();

CelValidationResult result = CEL_VALIDATOR.validate(ast);

assertThat(result.hasError()).isFalse();
assertThat(result.getAllIssues()).isEmpty();
verifyProtoAstRoundTrips(ast);
}

private void verifyProtoAstRoundTrips(CelAbstractSyntaxTree ast) throws Exception {
CheckedExpr checkedExpr = CelProtoAbstractSyntaxTree.fromCelAst(ast).toCheckedExpr();
ByteString serialized = checkedExpr.toByteString();
CheckedExpr deserializedCheckedExpr =
CheckedExpr.parseFrom(serialized, ExtensionRegistryLite.getEmptyRegistry());
if (!checkedExpr.equals(deserializedCheckedExpr)) {
throw new IllegalStateException("Expected checked expressions to round trip!");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ java_library(
"//runtime",
"//validator",
"//validator:validator_builder",
"//validator/validators:ast_depth_limit_validator",
"//validator/validators:duration",
"//validator/validators:homogeneous_literal",
"//validator/validators:regex",
"//validator/validators:timestamp",
"@@protobuf~//java/core",
"@cel_spec//proto/cel/expr:expr_java_proto",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java_util",
"@maven//:com_google_testparameterinjector_test_parameter_injector",
Expand Down
5 changes: 5 additions & 0 deletions validator/validators/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,8 @@ java_library(
name = "homogeneous_literal",
exports = ["//validator/src/main/java/dev/cel/validator/validators:homogeneous_literal"],
)

java_library(
name = "ast_depth_limit_validator",
exports = ["//validator/src/main/java/dev/cel/validator/validators:ast_depth_limit_validator"],
)

0 comments on commit 12d777f

Please sign in to comment.