diff --git a/common/src/main/java/dev/cel/common/CelSource.java b/common/src/main/java/dev/cel/common/CelSource.java index 3f4c0d2d7..62a74873a 100644 --- a/common/src/main/java/dev/cel/common/CelSource.java +++ b/common/src/main/java/dev/cel/common/CelSource.java @@ -184,6 +184,7 @@ public Builder toBuilder() { return new Builder(codePoints, lineOffsets) .setDescription(description) .addPositionsMap(positions) + .addAllExtensions(extensions) .addAllMacroCalls(macroCalls); } @@ -354,7 +355,7 @@ private LineAndOffset(int line, int offset) { */ @AutoValue @Immutable - abstract static class Extension { + public abstract static class Extension { /** Identifier for the extension. Example: constant_folding */ abstract String id(); @@ -371,9 +372,10 @@ abstract static class Extension { */ abstract ImmutableList affectedComponents(); + /** Version of the extension */ @AutoValue @Immutable - abstract static class Version { + public abstract static class Version { /** * Major version changes indicate different required support level from the required @@ -388,13 +390,13 @@ abstract static class Version { abstract long minor(); /** Create a new instance of Version with the provided major and minor values. */ - static Version of(long major, long minor) { + public static Version of(long major, long minor) { return new AutoValue_CelSource_Extension_Version(major, minor); } } /** CEL component specifier. */ - enum Component { + public enum Component { /** Unspecified, default. */ COMPONENT_UNSPECIFIED, /** Parser. Converts a CEL string to an AST. */ @@ -406,14 +408,14 @@ enum Component { } @CheckReturnValue - static Extension create(String id, Version version, Iterable components) { + public static Extension create(String id, Version version, Iterable components) { checkNotNull(version); checkNotNull(components); return new AutoValue_CelSource_Extension(id, version, ImmutableList.copyOf(components)); } @CheckReturnValue - static Extension create(String id, Version version, Component... components) { + public static Extension create(String id, Version version, Component... components) { return create(id, version, Arrays.asList(components)); } } diff --git a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java index 5562428ca..5798006c7 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java +++ b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java @@ -456,11 +456,12 @@ private CelSource normalizeMacroSource( ExprIdGenerator idGenerator) { // Remove the macro metadata that no longer exists in the AST due to being replaced. celSource = celSource.toBuilder().clearMacroCall(exprIdToReplace).build(); + CelSource.Builder sourceBuilder = + CelSource.newBuilder().addAllExtensions(celSource.getExtensions()); if (celSource.getMacroCalls().isEmpty()) { - return CelSource.newBuilder().build(); + return sourceBuilder.build(); } - CelSource.Builder sourceBuilder = CelSource.newBuilder(); ImmutableMap allExprs = CelNavigableExpr.fromExpr(mutatedRoot.build()) .allNodes() diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java index 150a91dc5..b62267bd1 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java @@ -29,6 +29,9 @@ import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOverloadDecl; import dev.cel.common.CelSource; +import dev.cel.common.CelSource.Extension; +import dev.cel.common.CelSource.Extension.Component; +import dev.cel.common.CelSource.Extension.Version; import dev.cel.common.CelValidationException; import dev.cel.common.ast.CelExpr; import dev.cel.common.ast.CelExpr.CelCall; @@ -87,6 +90,10 @@ public class SubexpressionOptimizer implements CelAstOptimizer { stream(Operator.values()).map(Operator::getFunction), stream(Standard.Function.values()).map(Standard.Function::getFunction)) .collect(toImmutableSet()); + + private static final Extension CEL_BLOCK_AST_EXTENSION_TAG = + Extension.create("cel_block", Version.of(1L, 1L), Component.COMPONENT_RUNTIME); + private final SubexpressionOptimizerOptions cseOptions; private final MutableAst mutableAst; @@ -209,7 +216,16 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock( // Restore the expected result type the environment had prior to optimization. celBuilder.setResultType(resultType); - return astToModify; + + return tagAstExtension(astToModify); + } + + private static CelAbstractSyntaxTree tagAstExtension(CelAbstractSyntaxTree ast) { + // Tag the extension + CelSource.Builder celSourceBuilder = + ast.getSource().toBuilder().addAllExtensions(CEL_BLOCK_AST_EXTENSION_TAG); + + return CelAbstractSyntaxTree.newParsedAst(ast.getExpr(), celSourceBuilder.build()); } /** @@ -510,7 +526,9 @@ public abstract static class Builder { /** * Rewrites the optimized AST using cel.@block call instead of cascaded cel.bind macros, aimed - * to produce a more compact AST. + * to produce a more compact AST. {@link com.google.api.expr.SourceInfo.Extension} field will + * be populated in the AST to inform that special runtime support is required to evaluate the + * optimized expression. */ public abstract Builder enableCelBlock(boolean value); diff --git a/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java index d027c6b8e..bd9d122cf 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java @@ -27,6 +27,9 @@ import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; import dev.cel.common.CelOverloadDecl; +import dev.cel.common.CelSource; +import dev.cel.common.CelSource.Extension; +import dev.cel.common.CelSource.Extension.Version; import dev.cel.common.ast.CelConstant; import dev.cel.common.ast.CelExpr; import dev.cel.common.ast.CelExpr.CelCall; @@ -99,6 +102,7 @@ public void mutableAst_nonMacro_sourceCleared() throws Exception { assertThat(mutatedAst.getSource().getDescription()).isEmpty(); assertThat(mutatedAst.getSource().getLineOffsets()).isEmpty(); assertThat(mutatedAst.getSource().getPositionsMap()).isEmpty(); + assertThat(mutatedAst.getSource().getExtensions()).isEmpty(); assertThat(mutatedAst.getSource().getMacroCalls()).isEmpty(); } @@ -113,9 +117,26 @@ public void mutableAst_macro_sourceMacroCallsPopulated() throws Exception { assertThat(mutatedAst.getSource().getDescription()).isEmpty(); assertThat(mutatedAst.getSource().getLineOffsets()).isEmpty(); assertThat(mutatedAst.getSource().getPositionsMap()).isEmpty(); + assertThat(mutatedAst.getSource().getExtensions()).isEmpty(); assertThat(mutatedAst.getSource().getMacroCalls()).isNotEmpty(); } + @Test + public void mutableAst_astContainsTaggedExtension_retained() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("has(TestAllTypes{}.single_int32)").getAst(); + Extension extension = Extension.create("test", Version.of(1, 1)); + CelSource celSource = ast.getSource().toBuilder().addAllExtensions(extension).build(); + ast = + CelAbstractSyntaxTree.newCheckedAst( + ast.getExpr(), celSource, ast.getReferenceMap(), ast.getTypeMap()); + + CelAbstractSyntaxTree mutatedAst = + MUTABLE_AST.replaceSubtree( + ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), 1); + + assertThat(mutatedAst.getSource().getExtensions()).containsExactly(extension); + } + @Test @TestParameters("{source: '[1].exists(x, x > 0)', expectedMacroCallSize: 1}") @TestParameters( diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index 377756fc9..ae8bc7f16 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -30,6 +30,9 @@ import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; import dev.cel.common.CelOverloadDecl; +import dev.cel.common.CelSource.Extension; +import dev.cel.common.CelSource.Extension.Component; +import dev.cel.common.CelSource.Extension.Version; import dev.cel.common.CelValidationException; import dev.cel.common.CelVarDecl; import dev.cel.common.ast.CelConstant; @@ -1109,6 +1112,27 @@ public void iterationLimitReached_throws(boolean enableCelBlock) throws Exceptio assertThat(e).hasMessageThat().isEqualTo("Optimization failure: Max iteration count reached."); } + @Test + public void celBlock_astExtensionTagged() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("size(x) + size(x)").getAst(); + CelOptimizer optimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + .addAstOptimizers( + SubexpressionOptimizer.newInstance( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(true) + .enableCelBlock(true) + .build()), + ConstantFoldingOptimizer.getInstance()) + .build(); + + CelAbstractSyntaxTree optimizedAst = optimizer.optimize(ast); + + assertThat(optimizedAst.getSource().getExtensions()) + .containsExactly( + Extension.create("cel_block", Version.of(1L, 1L), Component.COMPONENT_RUNTIME)); + } + private enum BlockTestCase { BOOL_LITERAL("cel.block([true, false], index0 || index1)"), STRING_CONCAT("cel.block(['a' + 'b', index0 + 'c'], index1 + 'd') == 'abcd'"),