diff --git a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel index da8aba9a..9cf66cf1 100644 --- a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel @@ -102,6 +102,7 @@ java_library( "//compiler:compiler_builder", "//runtime", "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", ], ) diff --git a/extensions/src/main/java/dev/cel/extensions/CelEncoderExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelEncoderExtensions.java index 46ad0826..bc882aa0 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelEncoderExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelEncoderExtensions.java @@ -14,6 +14,7 @@ package dev.cel.extensions; +import com.google.common.collect.ImmutableSet; import com.google.errorprone.annotations.Immutable; import com.google.protobuf.ByteString; import dev.cel.checker.CelCheckerBuilder; @@ -35,31 +36,58 @@ public class CelEncoderExtensions implements CelCompilerLibrary, CelRuntimeLibra private static final Decoder BASE64_DECODER = Base64.getDecoder(); - @Override - public void setCheckerOptions(CelCheckerBuilder checkerBuilder) { - checkerBuilder.addFunctionDeclarations( + private final ImmutableSet functions; + + enum Function { + DECODE( CelFunctionDecl.newFunctionDeclaration( "base64.decode", CelOverloadDecl.newGlobalOverload( "base64_decode_string", SimpleType.BYTES, SimpleType.STRING)), + ImmutableSet.of( + CelRuntime.CelFunctionBinding.from( + "base64_decode_string", + String.class, + str -> ByteString.copyFrom(BASE64_DECODER.decode(str))))), + ENCODE( CelFunctionDecl.newFunctionDeclaration( "base64.encode", CelOverloadDecl.newGlobalOverload( - "base64_encode_bytes", SimpleType.STRING, SimpleType.BYTES))); + "base64_encode_bytes", SimpleType.STRING, SimpleType.BYTES)), + ImmutableSet.of( + CelRuntime.CelFunctionBinding.from( + "base64_encode_bytes", + ByteString.class, + bytes -> BASE64_ENCODER.encodeToString(bytes.toByteArray())))), + ; + + private final CelFunctionDecl functionDecl; + private final ImmutableSet functionBindings; + + String getFunction() { + return functionDecl.name(); + } + + Function( + CelFunctionDecl functionDecl, + ImmutableSet functionBindings) { + this.functionDecl = functionDecl; + this.functionBindings = functionBindings; + } + } + + @Override + public void setCheckerOptions(CelCheckerBuilder checkerBuilder) { + functions.forEach(function -> checkerBuilder.addFunctionDeclarations(function.functionDecl)); } @SuppressWarnings("Immutable") // Instances of java.util.Base64 are immutable @Override public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) { - runtimeBuilder.addFunctionBindings( - CelRuntime.CelFunctionBinding.from( - "base64_decode_string", - String.class, - str -> ByteString.copyFrom(BASE64_DECODER.decode(str))), - CelRuntime.CelFunctionBinding.from( - "base64_encode_bytes", - ByteString.class, - bytes -> BASE64_ENCODER.encodeToString(bytes.toByteArray()))); + functions.forEach(function -> runtimeBuilder.addFunctionBindings(function.functionBindings)); } -} + public CelEncoderExtensions() { + this.functions = ImmutableSet.copyOf(Function.values()); + } +} diff --git a/extensions/src/main/java/dev/cel/extensions/CelExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelExtensions.java index 5515d6a8..66224628 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelExtensions.java @@ -14,7 +14,11 @@ package dev.cel.extensions; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Arrays.stream; + import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Streams; import dev.cel.common.CelOptions; import java.util.Set; @@ -206,5 +210,25 @@ public static CelSetsExtensions sets(Set functions) return new CelSetsExtensions(functions); } + /** + * Retrieves all function names used by every extension libraries. + * + *

Note: Certain extensions such as {@link CelProtoExtensions} and {@link + * CelBindingsExtensions} are implemented via macros, not functions, and those are not included + * here. + */ + public static ImmutableSet getAllFunctionNames() { + return Streams.concat( + stream(CelMathExtensions.Function.values()) + .map(CelMathExtensions.Function::getFunction), + stream(CelStringExtensions.Function.values()) + .map(CelStringExtensions.Function::getFunction), + stream(CelSetsExtensions.Function.values()) + .map(CelSetsExtensions.Function::getFunction), + stream(CelEncoderExtensions.Function.values()) + .map(CelEncoderExtensions.Function::getFunction)) + .collect(toImmutableSet()); + } + private CelExtensions() {} } diff --git a/extensions/src/main/java/dev/cel/extensions/CelMathExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelMathExtensions.java index c6423a05..ef1877b4 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelMathExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelMathExtensions.java @@ -104,7 +104,7 @@ final class CelMathExtensions implements CelCompilerLibrary, CelRuntimeLibrary { return builder.buildOrThrow(); } - public enum Function { + enum Function { MAX( CelFunctionDecl.newFunctionDeclaration( MATH_MAX_FUNCTION, @@ -341,6 +341,10 @@ public enum Function { private final ImmutableSet functionBindingsULongSigned; private final ImmutableSet functionBindingsULongUnsigned; + String getFunction() { + return functionDecl.name(); + } + Function( CelFunctionDecl functionDecl, ImmutableSet functionBindings, diff --git a/extensions/src/main/java/dev/cel/extensions/CelSetsExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelSetsExtensions.java index 10fcbdd0..e410edad 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelSetsExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelSetsExtensions.java @@ -112,6 +112,10 @@ public enum Function { private final CelFunctionDecl functionDecl; private final ImmutableSet functionBindings; + String getFunction() { + return functionDecl.name(); + } + Function(CelFunctionDecl functionDecl, CelRuntime.CelFunctionBinding... functionBindings) { this.functionDecl = functionDecl; this.functionBindings = ImmutableSet.copyOf(functionBindings); diff --git a/extensions/src/main/java/dev/cel/extensions/CelStringExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelStringExtensions.java index 69fbad8d..473722b2 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelStringExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelStringExtensions.java @@ -229,6 +229,10 @@ public enum Function { private final CelFunctionDecl functionDecl; private final ImmutableSet functionBindings; + String getFunction() { + return functionDecl.name(); + } + Function(CelFunctionDecl functionDecl, CelRuntime.CelFunctionBinding... functionBindings) { this.functionDecl = functionDecl; this.functionBindings = ImmutableSet.copyOf(functionBindings); diff --git a/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java index 0b2e7b2c..d7b75f8d 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java @@ -138,5 +138,27 @@ public void addEncoderExtension_success() throws Exception { assertThat(evaluatedResult).isTrue(); } -} + @Test + public void getAllFunctionNames() { + assertThat(CelExtensions.getAllFunctionNames()) + .containsExactly( + "math.@max", + "math.@min", + "charAt", + "indexOf", + "join", + "lastIndexOf", + "lowerAscii", + "replace", + "split", + "substring", + "trim", + "upperAscii", + "sets.contains", + "sets.equivalent", + "sets.intersects", + "base64.decode", + "base64.encode"); + } +} diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel index 2b0cc563..9b679d9b 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -14,6 +14,7 @@ java_library( tags = [ ], deps = [ + ":default_optimizer_constants", "//:auto_value", "//bundle:cel", "//common", @@ -29,6 +30,7 @@ java_library( "//optimizer:optimization_exception", "//parser:operator", "//runtime", + "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", ], ) @@ -41,9 +43,9 @@ java_library( tags = [ ], deps = [ + ":default_optimizer_constants", "//:auto_value", "//bundle:cel", - "//checker:checker_legacy_environment", "//common", "//common:compiler_common", "//common:mutable_ast", @@ -55,12 +57,25 @@ java_library( "//common/navigation:mutable_navigation", "//common/types", "//common/types:type_providers", - "//extensions:optional_library", "//optimizer:ast_optimizer", "//optimizer:mutable_ast", - "//parser:operator", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", "@maven//:org_jspecify_jspecify", ], ) + +java_library( + name = "default_optimizer_constants", + srcs = [ + "DefaultOptimizerConstants.java", + ], + visibility = ["//visibility:private"], + deps = [ + "//checker:checker_legacy_environment", + "//extensions", + "//extensions:optional_library", + "//parser:operator", + "@maven//:com_google_guava_guava", + ], +) diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java index 150f614d..7cf1ce1c 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java @@ -13,11 +13,14 @@ // limitations under the License. package dev.cel.optimizer.optimizers; +import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.MoreCollectors.onlyElement; import com.google.auto.value.AutoValue; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import dev.cel.bundle.Cel; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelMutableAst; @@ -40,6 +43,7 @@ import dev.cel.parser.Operator; import dev.cel.runtime.CelEvaluationException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.HashSet; import java.util.List; @@ -71,6 +75,7 @@ public static ConstantFoldingOptimizer newInstance( private final ConstantFoldingOptions constantFoldingOptions; private final AstMutator astMutator; + private final ImmutableSet foldableFunctions; // Use optional.of and optional.none as sentinel function names for folding optional calls. // TODO: Leverage CelValue representation of Optionals instead when available. @@ -95,7 +100,7 @@ public OptimizationResult optimize(CelAbstractSyntaxTree ast, Cel cel) CelNavigableMutableAst.fromAst(mutableAst) .getRoot() .allNodes() - .filter(ConstantFoldingOptimizer::canFold) + .filter(this::canFold) .collect(toImmutableList()); for (CelNavigableMutableExpr foldableExpr : foldableExprs) { iterCount++; @@ -124,9 +129,13 @@ public OptimizationResult optimize(CelAbstractSyntaxTree ast, Cel cel) return OptimizationResult.create(astMutator.renumberIdsConsecutively(mutableAst).toParsedAst()); } - private static boolean canFold(CelNavigableMutableExpr navigableExpr) { + private boolean canFold(CelNavigableMutableExpr navigableExpr) { switch (navigableExpr.getKind()) { case CALL: + if (!containsFoldableFunctionOnly(navigableExpr)) { + return false; + } + CelMutableCall mutableCall = navigableExpr.expr().call(); String functionName = mutableCall.function(); @@ -169,6 +178,19 @@ private static boolean canFold(CelNavigableMutableExpr navigableExpr) { } } + private boolean containsFoldableFunctionOnly(CelNavigableMutableExpr navigableExpr) { + return navigableExpr + .allNodes() + .allMatch( + node -> { + if (node.getKind().equals(Kind.CALL)) { + return foldableFunctions.contains(node.expr().call().function()); + } + + return true; + }); + } + private static boolean canFoldInOperator(CelNavigableMutableExpr navigableExpr) { ImmutableList allIdents = navigableExpr @@ -574,16 +596,39 @@ private CelMutableAst pruneOptionalStructElements(CelMutableAst ast, CelMutableE public abstract static class ConstantFoldingOptions { public abstract int maxIterationLimit(); + public abstract ImmutableSet foldableFunctions(); + /** Builder for configuring the {@link ConstantFoldingOptions}. */ @AutoValue.Builder public abstract static class Builder { + abstract ImmutableSet.Builder foldableFunctionsBuilder(); + /** * Limit the number of iteration while performing constant folding. An exception is thrown if * the iteration count exceeds the set value. */ public abstract Builder maxIterationLimit(int value); + /** + * Adds a collection of custom functions that will be a candidate for constant folding. By + * default, standard functions are foldable. + * + *

Note that the implementation of custom functions must be free of side effects. + */ + @CanIgnoreReturnValue + public Builder addFoldableFunctions(Iterable functions) { + checkNotNull(functions); + this.foldableFunctionsBuilder().addAll(functions); + return this; + } + + /** See {@link #addFoldableFunctions(Iterable)}. */ + @CanIgnoreReturnValue + public Builder addFoldableFunctions(String... functions) { + return addFoldableFunctions(Arrays.asList(functions)); + } + public abstract ConstantFoldingOptions build(); Builder() {} @@ -601,5 +646,10 @@ public static Builder newBuilder() { private ConstantFoldingOptimizer(ConstantFoldingOptions constantFoldingOptions) { this.constantFoldingOptions = constantFoldingOptions; this.astMutator = AstMutator.newInstance(constantFoldingOptions.maxIterationLimit()); + this.foldableFunctions = + ImmutableSet.builder() + .addAll(DefaultOptimizerConstants.CEL_CANONICAL_FUNCTIONS) + .addAll(constantFoldingOptions.foldableFunctions()) + .build(); } } diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/DefaultOptimizerConstants.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/DefaultOptimizerConstants.java new file mode 100644 index 00000000..07a3f062 --- /dev/null +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/DefaultOptimizerConstants.java @@ -0,0 +1,50 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.optimizer.optimizers; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Arrays.stream; + +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Streams; +import dev.cel.checker.Standard; +import dev.cel.extensions.CelExtensions; +import dev.cel.extensions.CelOptionalLibrary; +import dev.cel.extensions.CelOptionalLibrary.Function; +import dev.cel.parser.Operator; + +/** + * Package-private class that holds constants that's generally applicable across canonical + * optimizers provided from CEL. + */ +final class DefaultOptimizerConstants { + + /** + * List of function names from standard functions and extension libraries. These are free of side + * effects, thus amenable for optimization. + */ + static final ImmutableSet CEL_CANONICAL_FUNCTIONS = + ImmutableSet.builder() + .addAll( + Streams.concat( + stream(Operator.values()).map(Operator::getFunction), + stream(Standard.Function.values()).map(Standard.Function::getFunction), + stream(CelOptionalLibrary.Function.values()).map(Function::getFunction)) + .collect(toImmutableSet())) + .addAll(CelExtensions.getAllFunctionNames()) + .build(); + + private DefaultOptimizerConstants() {} +} diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java index 967c68d3..ecb5b22e 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java @@ -16,8 +16,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static java.util.Arrays.stream; import static java.util.stream.Collectors.toCollection; import com.google.auto.value.AutoValue; @@ -30,7 +28,6 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import dev.cel.bundle.Cel; import dev.cel.bundle.CelBuilder; -import dev.cel.checker.Standard; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelMutableAst; @@ -54,12 +51,9 @@ import dev.cel.common.types.CelType; import dev.cel.common.types.ListType; import dev.cel.common.types.SimpleType; -import dev.cel.extensions.CelOptionalLibrary; -import dev.cel.extensions.CelOptionalLibrary.Function; import dev.cel.optimizer.AstMutator; import dev.cel.optimizer.AstMutator.MangledComprehensionAst; import dev.cel.optimizer.CelAstOptimizer; -import dev.cel.parser.Operator; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; @@ -94,12 +88,7 @@ * */ public class SubexpressionOptimizer implements CelAstOptimizer { - private static final ImmutableSet CSE_DEFAULT_ELIMINABLE_FUNCTIONS = - Streams.concat( - stream(Operator.values()).map(Operator::getFunction), - stream(Standard.Function.values()).map(Standard.Function::getFunction), - stream(CelOptionalLibrary.Function.values()).map(Function::getFunction)) - .collect(toImmutableSet()); + private static final SubexpressionOptimizer INSTANCE = new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build()); private static final String BIND_IDENTIFIER_PREFIX = "@r"; @@ -738,7 +727,7 @@ private SubexpressionOptimizer(SubexpressionOptimizerOptions cseOptions) { this.astMutator = AstMutator.newInstance(cseOptions.iterationLimit()); this.cseEliminableFunctions = ImmutableSet.builder() - .addAll(CSE_DEFAULT_ELIMINABLE_FUNCTIONS) + .addAll(DefaultOptimizerConstants.CEL_CANONICAL_FUNCTIONS) .addAll(cseOptions.eliminableFunctions()) .build(); } diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java index 9813ee26..96040b50 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java @@ -17,12 +17,15 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; +import com.google.common.collect.ImmutableList; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.bundle.Cel; import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; +import dev.cel.common.CelOverloadDecl; import dev.cel.common.types.ListType; import dev.cel.common.types.MapType; import dev.cel.common.types.SimpleType; @@ -35,6 +38,7 @@ import dev.cel.parser.CelStandardMacro; import dev.cel.parser.CelUnparser; import dev.cel.parser.CelUnparserFactory; +import dev.cel.runtime.CelRuntime.CelFunctionBinding; import dev.cel.testing.testdata.proto3.TestAllTypesProto.TestAllTypes; import org.junit.Test; import org.junit.runner.RunWith; @@ -47,10 +51,27 @@ public class ConstantFoldingOptimizerTest { .addVar("y", SimpleType.DYN) .addVar("list_var", ListType.create(SimpleType.STRING)) .addVar("map_var", MapType.create(SimpleType.STRING, SimpleType.STRING)) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "get_true", + CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) + .addFunctionBindings( + CelFunctionBinding.from("get_true_overload", ImmutableList.of(), unused -> true)) .addMessageTypes(TestAllTypes.getDescriptor()) .setContainer("dev.cel.testing.testdata.proto3") - .addCompilerLibraries(CelExtensions.bindings(), CelOptionalLibrary.INSTANCE) - .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) + .addCompilerLibraries( + CelExtensions.bindings(), + CelOptionalLibrary.INSTANCE, + CelExtensions.math(CelOptions.DEFAULT), + CelExtensions.strings(), + CelExtensions.sets(), + CelExtensions.encoders()) + .addRuntimeLibraries( + CelOptionalLibrary.INSTANCE, + CelExtensions.math(CelOptions.DEFAULT), + CelExtensions.strings(), + CelExtensions.sets(), + CelExtensions.encoders()) .build(); private static final CelOptimizer CEL_OPTIMIZER = @@ -161,6 +182,10 @@ public class ConstantFoldingOptimizerTest { @TestParameters("{source: 'map_var[?\"key\"]', expected: 'map_var[?\"key\"]'}") @TestParameters("{source: '\"abc\" in list_var', expected: '\"abc\" in list_var'}") @TestParameters("{source: '[?optional.none(), [?optional.none()]]', expected: '[[]]'}") + @TestParameters("{source: 'math.greatest(1.0, 2, 3.0)', expected: '3.0'}") + @TestParameters("{source: '\"world\".charAt(1)', expected: '\"o\"'}") + @TestParameters("{source: 'base64.encode(b\"hello\")', expected: '\"aGVsbG8=\"'}") + @TestParameters("{source: 'sets.contains([1], [1])', expected: 'true'}") @TestParameters( "{source: 'cel.bind(r0, [1, 2, 3], cel.bind(r1, 1 in r0, r1))', expected: 'true'}") // TODO: Support folding lists with mixed types. This requires mutable lists. @@ -291,6 +316,8 @@ public void constantFold_macros_withoutMacroCallMetadata(String source) throws E @TestParameters("{source: '[optional.none()]'}") @TestParameters("{source: '[?x.?y]'}") @TestParameters("{source: 'TestAllTypes{single_int32: x, repeated_int32: [1, 2, 3]}'}") + @TestParameters("{source: 'get_true() == get_true()'}") + @TestParameters("{source: 'get_true() == true'}") public void constantFold_noOp(String source) throws Exception { CelAbstractSyntaxTree ast = CEL.compile(source).getAst(); @@ -299,6 +326,21 @@ public void constantFold_noOp(String source) throws Exception { assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(source); } + @Test + public void constantFold_addFoldableFunction_success() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("get_true() == get_true()").getAst(); + ConstantFoldingOptions options = + ConstantFoldingOptions.newBuilder().addFoldableFunctions("get_true").build(); + CelOptimizer optimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + .addAstOptimizers(ConstantFoldingOptimizer.newInstance(options)) + .build(); + + CelAbstractSyntaxTree optimizedAst = optimizer.optimize(ast); + + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo("true"); + } + @Test public void constantFold_withMacroCallPopulated_comprehensionsAreReplacedWithNotSet() throws Exception {