Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce composed AST stays below the configured depth limit #424

Merged
merged 1 commit into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions policy/src/main/java/dev/cel/policy/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,10 @@ java_library(
"//optimizer",
"//optimizer:optimization_exception",
"//optimizer:optimizer_builder",
"//validator",
"//validator:ast_validator",
"//validator:validator_builder",
"//validator/validators:ast_depth_limit_validator",
"@maven//:com_google_errorprone_error_prone_annotations",
"@maven//:com_google_guava_guava",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ public interface CelPolicyCompilerBuilder {
@CanIgnoreReturnValue
CelPolicyCompilerBuilder setIterationLimit(int iterationLimit);

/**
* Enforces the composed AST to stay below the configured depth limit. An exception is thrown if
* the depth exceeds the configured limit. Setting a negative value disables this check.
*/
@CanIgnoreReturnValue
CelPolicyCompilerBuilder setAstDepthLimit(int iterationLimit);

@CheckReturnValue
CelPolicyCompiler build();
}
49 changes: 46 additions & 3 deletions policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import dev.cel.common.CelSource;
import dev.cel.common.CelSourceLocation;
import dev.cel.common.CelValidationException;
import dev.cel.common.CelValidationResult;
import dev.cel.common.CelVarDecl;
import dev.cel.common.ast.CelConstant;
import dev.cel.common.ast.CelExpr;
Expand All @@ -40,6 +41,10 @@
import dev.cel.policy.CelPolicy.Match;
import dev.cel.policy.CelPolicy.Variable;
import dev.cel.policy.RuleComposer.RuleCompositionException;
import dev.cel.validator.CelAstValidator;
import dev.cel.validator.CelValidator;
import dev.cel.validator.CelValidatorFactory;
import dev.cel.validator.validators.AstDepthLimitValidator;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
Expand All @@ -52,6 +57,7 @@ final class CelPolicyCompilerImpl implements CelPolicyCompiler {
private final Cel cel;
private final String variablesPrefix;
private final int iterationLimit;
private final Optional<CelAstValidator> astDepthValidator;

@Override
public CelCompiledRule compileRule(CelPolicy policy) throws CelPolicyValidationException {
Expand All @@ -67,8 +73,9 @@ public CelCompiledRule compileRule(CelPolicy policy) throws CelPolicyValidationE
@Override
public CelAbstractSyntaxTree compose(CelPolicy policy, CelCompiledRule compiledRule)
throws CelPolicyValidationException {
Cel cel = compiledRule.cel();
CelOptimizer optimizer =
CelOptimizerFactory.standardCelOptimizerBuilder(compiledRule.cel())
CelOptimizerFactory.standardCelOptimizerBuilder(cel)
.addAstOptimizers(
RuleComposer.newInstance(compiledRule, variablesPrefix, iterationLimit))
.build();
Expand Down Expand Up @@ -105,9 +112,26 @@ public CelAbstractSyntaxTree compose(CelPolicy policy, CelCompiledRule compiledR
throw new CelPolicyValidationException("Unexpected error while composing rules.", e);
}

assertAstDepthIsSafe(ast, cel);

return ast;
}

private void assertAstDepthIsSafe(CelAbstractSyntaxTree ast, Cel cel)
throws CelPolicyValidationException {
if (!astDepthValidator.isPresent()) {
return;
}
CelValidator celValidator =
CelValidatorFactory.standardCelValidatorBuilder(cel)
.addAstValidators(astDepthValidator.get())
.build();
CelValidationResult result = celValidator.validate(ast);
if (result.hasError()) {
throw new CelPolicyValidationException(result.getErrorString());
}
}

private CelCompiledRule compileRuleImpl(
CelPolicy.Rule rule, Cel ruleCel, CompilerContext compilerContext) {
ImmutableList.Builder<CelCompiledVariable> variableBuilder = ImmutableList.builder();
Expand Down Expand Up @@ -262,9 +286,11 @@ static final class Builder implements CelPolicyCompilerBuilder {
private final Cel cel;
private String variablesPrefix;
private int iterationLimit;
private Optional<CelAstValidator> astDepthLimitValidator;

private Builder(Cel cel) {
this.cel = cel;
this.astDepthLimitValidator = Optional.of(AstDepthLimitValidator.DEFAULT);
}

@Override
Expand All @@ -281,9 +307,21 @@ public Builder setIterationLimit(int iterationLimit) {
return this;
}

@Override
@CanIgnoreReturnValue
public CelPolicyCompilerBuilder setAstDepthLimit(int astDepthLimit) {
if (astDepthLimit < 0) {
astDepthLimitValidator = Optional.empty();
} else {
astDepthLimitValidator = Optional.of(AstDepthLimitValidator.newInstance(astDepthLimit));
}
return this;
}

@Override
public CelPolicyCompiler build() {
return new CelPolicyCompilerImpl(cel, this.variablesPrefix, this.iterationLimit);
return new CelPolicyCompilerImpl(
cel, this.variablesPrefix, this.iterationLimit, astDepthLimitValidator);
}
}

Expand All @@ -293,9 +331,14 @@ static Builder newBuilder(Cel cel) {
.setIterationLimit(DEFAULT_ITERATION_LIMIT);
}

private CelPolicyCompilerImpl(Cel cel, String variablesPrefix, int iterationLimit) {
private CelPolicyCompilerImpl(
Cel cel,
String variablesPrefix,
int iterationLimit,
Optional<CelAstValidator> astDepthValidator) {
this.cel = checkNotNull(cel);
this.variablesPrefix = checkNotNull(variablesPrefix);
this.iterationLimit = iterationLimit;
this.astDepthValidator = astDepthValidator;
}
}
36 changes: 36 additions & 0 deletions policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,42 @@ public void compileYamlPolicy_multilineContainsError_throws(
assertThat(e).hasMessageThat().isEqualTo(testCase.expected);
}

@Test
public void compileYamlPolicy_exceedsDefaultAstDepthLimit_throws() throws Exception {
String longExpr =
"0+1+2+3+4+5+6+7+8+9+10+11+12+13+14+15+16+17+18+19+20+21+22+23+24+25+26+27+28+29+30+31+32+33+34+35+36+37+38+39+40+41+42+43+44+45+46+47+48+49+50";
String policyContent =
String.format(
"name: deeply_nested_ast\n" + "rule:\n" + " match:\n" + " - output: %s", longExpr);
CelPolicy policy = POLICY_PARSER.parse(policyContent);

CelPolicyValidationException e =
assertThrows(
CelPolicyValidationException.class,
() -> CelPolicyCompilerFactory.newPolicyCompiler(newCel()).build().compile(policy));

assertThat(e)
.hasMessageThat()
.isEqualTo("ERROR: <input>:-1:0: AST's depth exceeds the configured limit: 50.");
}

@Test
public void compileYamlPolicy_astDepthLimitCheckDisabled_doesNotThrow() throws Exception {
String longExpr =
"0+1+2+3+4+5+6+7+8+9+10+11+12+13+14+15+16+17+18+19+20+21+22+23+24+25+26+27+28+29+30+31+32+33+34+35+36+37+38+39+40+41+42+43+44+45+46+47+48+49+50";
String policyContent =
String.format(
"name: deeply_nested_ast\n" + "rule:\n" + " match:\n" + " - output: %s", longExpr);
CelPolicy policy = POLICY_PARSER.parse(policyContent);

CelAbstractSyntaxTree ast =
CelPolicyCompilerFactory.newPolicyCompiler(newCel())
.setAstDepthLimit(-1)
.build()
.compile(policy);
assertThat(ast).isNotNull();
}

@Test
@SuppressWarnings("unchecked")
public void evaluateYamlPolicy_withCanonicalTestData(
Expand Down
10 changes: 8 additions & 2 deletions validator/src/main/java/dev/cel/validator/CelAstValidator.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import dev.cel.common.CelIssue;
import dev.cel.common.CelIssue.Severity;
import dev.cel.common.CelSource;
import dev.cel.common.CelSourceLocation;
import dev.cel.common.navigation.CelNavigableAst;
import java.util.Optional;

/** Public interface for performing a single, custom validation on an AST. */
public interface CelAstValidator {
Expand Down Expand Up @@ -53,12 +55,16 @@ public void addInfo(long exprId, String message) {

private void add(long exprId, String message, Severity severity) {
CelSource source = navigableAst.getAst().getSource();
int position = source.getPositionsMap().get(exprId);
int position = Optional.ofNullable(source.getPositionsMap().get(exprId)).orElse(-1);
CelSourceLocation sourceLocation = CelSourceLocation.NONE;
if (position >= 0) {
sourceLocation = source.getOffsetLocation(position).get();
}
issuesBuilder.add(
CelIssue.newBuilder()
.setSeverity(severity)
.setMessage(message)
.setSourceLocation(source.getOffsetLocation(position).get())
.setSourceLocation(sourceLocation)
.build());
}

Expand Down
Loading