Skip to content

Commit

Permalink
Enforce composed AST stays below the configured depth limit
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 663089056
  • Loading branch information
l46kok authored and copybara-github committed Aug 15, 2024
1 parent 12d777f commit c3a7528
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 5 deletions.
4 changes: 4 additions & 0 deletions policy/src/main/java/dev/cel/policy/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,10 @@ java_library(
"//optimizer",
"//optimizer:optimization_exception",
"//optimizer:optimizer_builder",
"//validator",
"//validator:ast_validator",
"//validator:validator_builder",
"//validator/validators:ast_depth_limit_validator",
"@maven//:com_google_errorprone_error_prone_annotations",
"@maven//:com_google_guava_guava",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ public interface CelPolicyCompilerBuilder {
@CanIgnoreReturnValue
CelPolicyCompilerBuilder setIterationLimit(int iterationLimit);

/**
* Enforces the composed AST to stay below the configured depth limit. An exception is thrown if
* the depth exceeds the configured limit. Setting a negative value disables this check.
*/
@CanIgnoreReturnValue
CelPolicyCompilerBuilder setAstDepthLimit(int iterationLimit);

@CheckReturnValue
CelPolicyCompiler build();
}
49 changes: 46 additions & 3 deletions policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import dev.cel.common.CelSource;
import dev.cel.common.CelSourceLocation;
import dev.cel.common.CelValidationException;
import dev.cel.common.CelValidationResult;
import dev.cel.common.CelVarDecl;
import dev.cel.common.ast.CelConstant;
import dev.cel.common.ast.CelExpr;
Expand All @@ -40,6 +41,10 @@
import dev.cel.policy.CelPolicy.Match;
import dev.cel.policy.CelPolicy.Variable;
import dev.cel.policy.RuleComposer.RuleCompositionException;
import dev.cel.validator.CelAstValidator;
import dev.cel.validator.CelValidator;
import dev.cel.validator.CelValidatorFactory;
import dev.cel.validator.validators.AstDepthLimitValidator;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
Expand All @@ -52,6 +57,7 @@ final class CelPolicyCompilerImpl implements CelPolicyCompiler {
private final Cel cel;
private final String variablesPrefix;
private final int iterationLimit;
private final Optional<CelAstValidator> astDepthValidator;

@Override
public CelCompiledRule compileRule(CelPolicy policy) throws CelPolicyValidationException {
Expand All @@ -67,8 +73,9 @@ public CelCompiledRule compileRule(CelPolicy policy) throws CelPolicyValidationE
@Override
public CelAbstractSyntaxTree compose(CelPolicy policy, CelCompiledRule compiledRule)
throws CelPolicyValidationException {
Cel cel = compiledRule.cel();
CelOptimizer optimizer =
CelOptimizerFactory.standardCelOptimizerBuilder(compiledRule.cel())
CelOptimizerFactory.standardCelOptimizerBuilder(cel)
.addAstOptimizers(
RuleComposer.newInstance(compiledRule, variablesPrefix, iterationLimit))
.build();
Expand Down Expand Up @@ -105,9 +112,26 @@ public CelAbstractSyntaxTree compose(CelPolicy policy, CelCompiledRule compiledR
throw new CelPolicyValidationException("Unexpected error while composing rules.", e);
}

assertAstDepthIsSafe(ast, cel);

return ast;
}

private void assertAstDepthIsSafe(CelAbstractSyntaxTree ast, Cel cel)
throws CelPolicyValidationException {
if (!astDepthValidator.isPresent()) {
return;
}
CelValidator celValidator =
CelValidatorFactory.standardCelValidatorBuilder(cel)
.addAstValidators(astDepthValidator.get())
.build();
CelValidationResult result = celValidator.validate(ast);
if (result.hasError()) {
throw new CelPolicyValidationException(result.getErrorString());
}
}

private CelCompiledRule compileRuleImpl(
CelPolicy.Rule rule, Cel ruleCel, CompilerContext compilerContext) {
ImmutableList.Builder<CelCompiledVariable> variableBuilder = ImmutableList.builder();
Expand Down Expand Up @@ -262,9 +286,11 @@ static final class Builder implements CelPolicyCompilerBuilder {
private final Cel cel;
private String variablesPrefix;
private int iterationLimit;
private Optional<CelAstValidator> astDepthLimitValidator;

private Builder(Cel cel) {
this.cel = cel;
this.astDepthLimitValidator = Optional.of(AstDepthLimitValidator.DEFAULT);
}

@Override
Expand All @@ -281,9 +307,21 @@ public Builder setIterationLimit(int iterationLimit) {
return this;
}

@Override
@CanIgnoreReturnValue
public CelPolicyCompilerBuilder setAstDepthLimit(int astDepthLimit) {
if (astDepthLimit < 0) {
astDepthLimitValidator = Optional.empty();
} else {
astDepthLimitValidator = Optional.of(AstDepthLimitValidator.newInstance(astDepthLimit));
}
return this;
}

@Override
public CelPolicyCompiler build() {
return new CelPolicyCompilerImpl(cel, this.variablesPrefix, this.iterationLimit);
return new CelPolicyCompilerImpl(
cel, this.variablesPrefix, this.iterationLimit, astDepthLimitValidator);
}
}

Expand All @@ -293,9 +331,14 @@ static Builder newBuilder(Cel cel) {
.setIterationLimit(DEFAULT_ITERATION_LIMIT);
}

private CelPolicyCompilerImpl(Cel cel, String variablesPrefix, int iterationLimit) {
private CelPolicyCompilerImpl(
Cel cel,
String variablesPrefix,
int iterationLimit,
Optional<CelAstValidator> astDepthValidator) {
this.cel = checkNotNull(cel);
this.variablesPrefix = checkNotNull(variablesPrefix);
this.iterationLimit = iterationLimit;
this.astDepthValidator = astDepthValidator;
}
}
36 changes: 36 additions & 0 deletions policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,42 @@ public void compileYamlPolicy_multilineContainsError_throws(
assertThat(e).hasMessageThat().isEqualTo(testCase.expected);
}

@Test
public void compileYamlPolicy_exceedsDefaultAstDepthLimit_throws() throws Exception {
String longExpr =
"0+1+2+3+4+5+6+7+8+9+10+11+12+13+14+15+16+17+18+19+20+21+22+23+24+25+26+27+28+29+30+31+32+33+34+35+36+37+38+39+40+41+42+43+44+45+46+47+48+49+50";
String policyContent =
String.format(
"name: deeply_nested_ast\n" + "rule:\n" + " match:\n" + " - output: %s", longExpr);
CelPolicy policy = POLICY_PARSER.parse(policyContent);

CelPolicyValidationException e =
assertThrows(
CelPolicyValidationException.class,
() -> CelPolicyCompilerFactory.newPolicyCompiler(newCel()).build().compile(policy));

assertThat(e)
.hasMessageThat()
.isEqualTo("ERROR: <input>:-1:0: AST's depth exceeds the configured limit: 50.");
}

@Test
public void compileYamlPolicy_astDepthLimitCheckDisabled_doesNotThrow() throws Exception {
String longExpr =
"0+1+2+3+4+5+6+7+8+9+10+11+12+13+14+15+16+17+18+19+20+21+22+23+24+25+26+27+28+29+30+31+32+33+34+35+36+37+38+39+40+41+42+43+44+45+46+47+48+49+50";
String policyContent =
String.format(
"name: deeply_nested_ast\n" + "rule:\n" + " match:\n" + " - output: %s", longExpr);
CelPolicy policy = POLICY_PARSER.parse(policyContent);

CelAbstractSyntaxTree ast =
CelPolicyCompilerFactory.newPolicyCompiler(newCel())
.setAstDepthLimit(-1)
.build()
.compile(policy);
assertThat(ast).isNotNull();
}

@Test
@SuppressWarnings("unchecked")
public void evaluateYamlPolicy_withCanonicalTestData(
Expand Down
10 changes: 8 additions & 2 deletions validator/src/main/java/dev/cel/validator/CelAstValidator.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import dev.cel.common.CelIssue;
import dev.cel.common.CelIssue.Severity;
import dev.cel.common.CelSource;
import dev.cel.common.CelSourceLocation;
import dev.cel.common.navigation.CelNavigableAst;
import java.util.Optional;

/** Public interface for performing a single, custom validation on an AST. */
public interface CelAstValidator {
Expand Down Expand Up @@ -53,12 +55,16 @@ public void addInfo(long exprId, String message) {

private void add(long exprId, String message, Severity severity) {
CelSource source = navigableAst.getAst().getSource();
int position = source.getPositionsMap().get(exprId);
int position = Optional.ofNullable(source.getPositionsMap().get(exprId)).orElse(-1);
CelSourceLocation sourceLocation = CelSourceLocation.NONE;
if (position >= 0) {
sourceLocation = source.getOffsetLocation(position).get();
}
issuesBuilder.add(
CelIssue.newBuilder()
.setSeverity(severity)
.setMessage(message)
.setSourceLocation(source.getOffsetLocation(position).get())
.setSourceLocation(sourceLocation)
.build());
}

Expand Down

0 comments on commit c3a7528

Please sign in to comment.