diff --git a/policy/src/main/java/dev/cel/policy/BUILD.bazel b/policy/src/main/java/dev/cel/policy/BUILD.bazel index ea3f4353..bdb651d0 100644 --- a/policy/src/main/java/dev/cel/policy/BUILD.bazel +++ b/policy/src/main/java/dev/cel/policy/BUILD.bazel @@ -257,6 +257,10 @@ java_library( "//optimizer", "//optimizer:optimization_exception", "//optimizer:optimizer_builder", + "//validator", + "//validator:ast_validator", + "//validator:validator_builder", + "//validator/validators:ast_depth_limit_validator", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", ], diff --git a/policy/src/main/java/dev/cel/policy/CelPolicyCompilerBuilder.java b/policy/src/main/java/dev/cel/policy/CelPolicyCompilerBuilder.java index 13bac288..592a0120 100644 --- a/policy/src/main/java/dev/cel/policy/CelPolicyCompilerBuilder.java +++ b/policy/src/main/java/dev/cel/policy/CelPolicyCompilerBuilder.java @@ -31,6 +31,13 @@ public interface CelPolicyCompilerBuilder { @CanIgnoreReturnValue CelPolicyCompilerBuilder setIterationLimit(int iterationLimit); + /** + * Enforces the composed AST to stay below the configured depth limit. An exception is thrown if + * the depth exceeds the configured limit. Setting a negative value disables this check. + */ + @CanIgnoreReturnValue + CelPolicyCompilerBuilder setAstDepthLimit(int iterationLimit); + @CheckReturnValue CelPolicyCompiler build(); } diff --git a/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java b/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java index b7d1377b..ca1be9f4 100644 --- a/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java +++ b/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java @@ -25,6 +25,7 @@ import dev.cel.common.CelSource; import dev.cel.common.CelSourceLocation; import dev.cel.common.CelValidationException; +import dev.cel.common.CelValidationResult; import dev.cel.common.CelVarDecl; import dev.cel.common.ast.CelConstant; import dev.cel.common.ast.CelExpr; @@ -40,6 +41,10 @@ import dev.cel.policy.CelPolicy.Match; import dev.cel.policy.CelPolicy.Variable; import dev.cel.policy.RuleComposer.RuleCompositionException; +import dev.cel.validator.CelAstValidator; +import dev.cel.validator.CelValidator; +import dev.cel.validator.CelValidatorFactory; +import dev.cel.validator.validators.AstDepthLimitValidator; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -52,6 +57,7 @@ final class CelPolicyCompilerImpl implements CelPolicyCompiler { private final Cel cel; private final String variablesPrefix; private final int iterationLimit; + private final Optional astDepthValidator; @Override public CelCompiledRule compileRule(CelPolicy policy) throws CelPolicyValidationException { @@ -67,8 +73,9 @@ public CelCompiledRule compileRule(CelPolicy policy) throws CelPolicyValidationE @Override public CelAbstractSyntaxTree compose(CelPolicy policy, CelCompiledRule compiledRule) throws CelPolicyValidationException { + Cel cel = compiledRule.cel(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(compiledRule.cel()) + CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers( RuleComposer.newInstance(compiledRule, variablesPrefix, iterationLimit)) .build(); @@ -105,9 +112,26 @@ public CelAbstractSyntaxTree compose(CelPolicy policy, CelCompiledRule compiledR throw new CelPolicyValidationException("Unexpected error while composing rules.", e); } + assertAstDepthIsSafe(ast, cel); + return ast; } + private void assertAstDepthIsSafe(CelAbstractSyntaxTree ast, Cel cel) + throws CelPolicyValidationException { + if (!astDepthValidator.isPresent()) { + return; + } + CelValidator celValidator = + CelValidatorFactory.standardCelValidatorBuilder(cel) + .addAstValidators(astDepthValidator.get()) + .build(); + CelValidationResult result = celValidator.validate(ast); + if (result.hasError()) { + throw new CelPolicyValidationException(result.getErrorString()); + } + } + private CelCompiledRule compileRuleImpl( CelPolicy.Rule rule, Cel ruleCel, CompilerContext compilerContext) { ImmutableList.Builder variableBuilder = ImmutableList.builder(); @@ -262,9 +286,11 @@ static final class Builder implements CelPolicyCompilerBuilder { private final Cel cel; private String variablesPrefix; private int iterationLimit; + private Optional astDepthLimitValidator; private Builder(Cel cel) { this.cel = cel; + this.astDepthLimitValidator = Optional.of(AstDepthLimitValidator.DEFAULT); } @Override @@ -281,9 +307,21 @@ public Builder setIterationLimit(int iterationLimit) { return this; } + @Override + @CanIgnoreReturnValue + public CelPolicyCompilerBuilder setAstDepthLimit(int astDepthLimit) { + if (astDepthLimit < 0) { + astDepthLimitValidator = Optional.empty(); + } else { + astDepthLimitValidator = Optional.of(AstDepthLimitValidator.newInstance(astDepthLimit)); + } + return this; + } + @Override public CelPolicyCompiler build() { - return new CelPolicyCompilerImpl(cel, this.variablesPrefix, this.iterationLimit); + return new CelPolicyCompilerImpl( + cel, this.variablesPrefix, this.iterationLimit, astDepthLimitValidator); } } @@ -293,9 +331,14 @@ static Builder newBuilder(Cel cel) { .setIterationLimit(DEFAULT_ITERATION_LIMIT); } - private CelPolicyCompilerImpl(Cel cel, String variablesPrefix, int iterationLimit) { + private CelPolicyCompilerImpl( + Cel cel, + String variablesPrefix, + int iterationLimit, + Optional astDepthValidator) { this.cel = checkNotNull(cel); this.variablesPrefix = checkNotNull(variablesPrefix); this.iterationLimit = iterationLimit; + this.astDepthValidator = astDepthValidator; } } diff --git a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java index 7d4bb85d..bf435f90 100644 --- a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java +++ b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java @@ -105,6 +105,42 @@ public void compileYamlPolicy_multilineContainsError_throws( assertThat(e).hasMessageThat().isEqualTo(testCase.expected); } + @Test + public void compileYamlPolicy_exceedsDefaultAstDepthLimit_throws() throws Exception { + String longExpr = + "0+1+2+3+4+5+6+7+8+9+10+11+12+13+14+15+16+17+18+19+20+21+22+23+24+25+26+27+28+29+30+31+32+33+34+35+36+37+38+39+40+41+42+43+44+45+46+47+48+49+50"; + String policyContent = + String.format( + "name: deeply_nested_ast\n" + "rule:\n" + " match:\n" + " - output: %s", longExpr); + CelPolicy policy = POLICY_PARSER.parse(policyContent); + + CelPolicyValidationException e = + assertThrows( + CelPolicyValidationException.class, + () -> CelPolicyCompilerFactory.newPolicyCompiler(newCel()).build().compile(policy)); + + assertThat(e) + .hasMessageThat() + .isEqualTo("ERROR: :-1:0: AST's depth exceeds the configured limit: 50."); + } + + @Test + public void compileYamlPolicy_astDepthLimitCheckDisabled_doesNotThrow() throws Exception { + String longExpr = + "0+1+2+3+4+5+6+7+8+9+10+11+12+13+14+15+16+17+18+19+20+21+22+23+24+25+26+27+28+29+30+31+32+33+34+35+36+37+38+39+40+41+42+43+44+45+46+47+48+49+50"; + String policyContent = + String.format( + "name: deeply_nested_ast\n" + "rule:\n" + " match:\n" + " - output: %s", longExpr); + CelPolicy policy = POLICY_PARSER.parse(policyContent); + + CelAbstractSyntaxTree ast = + CelPolicyCompilerFactory.newPolicyCompiler(newCel()) + .setAstDepthLimit(-1) + .build() + .compile(policy); + assertThat(ast).isNotNull(); + } + @Test @SuppressWarnings("unchecked") public void evaluateYamlPolicy_withCanonicalTestData( diff --git a/validator/src/main/java/dev/cel/validator/CelAstValidator.java b/validator/src/main/java/dev/cel/validator/CelAstValidator.java index ca316019..d3b04668 100644 --- a/validator/src/main/java/dev/cel/validator/CelAstValidator.java +++ b/validator/src/main/java/dev/cel/validator/CelAstValidator.java @@ -19,7 +19,9 @@ import dev.cel.common.CelIssue; import dev.cel.common.CelIssue.Severity; import dev.cel.common.CelSource; +import dev.cel.common.CelSourceLocation; import dev.cel.common.navigation.CelNavigableAst; +import java.util.Optional; /** Public interface for performing a single, custom validation on an AST. */ public interface CelAstValidator { @@ -53,12 +55,16 @@ public void addInfo(long exprId, String message) { private void add(long exprId, String message, Severity severity) { CelSource source = navigableAst.getAst().getSource(); - int position = source.getPositionsMap().get(exprId); + int position = Optional.ofNullable(source.getPositionsMap().get(exprId)).orElse(-1); + CelSourceLocation sourceLocation = CelSourceLocation.NONE; + if (position >= 0) { + sourceLocation = source.getOffsetLocation(position).get(); + } issuesBuilder.add( CelIssue.newBuilder() .setSeverity(severity) .setMessage(message) - .setSourceLocation(source.getOffsetLocation(position).get()) + .setSourceLocation(sourceLocation) .build()); }