Skip to content

Commit

Permalink
Add option to specify folding designated custom functions only
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 653874748
  • Loading branch information
l46kok authored and copybara-github committed Jul 19, 2024
1 parent 995243c commit 882f631
Show file tree
Hide file tree
Showing 12 changed files with 269 additions and 36 deletions.
1 change: 1 addition & 0 deletions extensions/src/main/java/dev/cel/extensions/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ java_library(
"//compiler:compiler_builder",
"//runtime",
"@maven//:com_google_errorprone_error_prone_annotations",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package dev.cel.extensions;

import com.google.common.collect.ImmutableSet;
import com.google.errorprone.annotations.Immutable;
import com.google.protobuf.ByteString;
import dev.cel.checker.CelCheckerBuilder;
Expand All @@ -35,31 +36,58 @@ public class CelEncoderExtensions implements CelCompilerLibrary, CelRuntimeLibra

private static final Decoder BASE64_DECODER = Base64.getDecoder();

@Override
public void setCheckerOptions(CelCheckerBuilder checkerBuilder) {
checkerBuilder.addFunctionDeclarations(
private final ImmutableSet<Function> functions;

enum Function {
DECODE(
CelFunctionDecl.newFunctionDeclaration(
"base64.decode",
CelOverloadDecl.newGlobalOverload(
"base64_decode_string", SimpleType.BYTES, SimpleType.STRING)),
ImmutableSet.of(
CelRuntime.CelFunctionBinding.from(
"base64_decode_string",
String.class,
str -> ByteString.copyFrom(BASE64_DECODER.decode(str))))),
ENCODE(
CelFunctionDecl.newFunctionDeclaration(
"base64.encode",
CelOverloadDecl.newGlobalOverload(
"base64_encode_bytes", SimpleType.STRING, SimpleType.BYTES)));
"base64_encode_bytes", SimpleType.STRING, SimpleType.BYTES)),
ImmutableSet.of(
CelRuntime.CelFunctionBinding.from(
"base64_encode_bytes",
ByteString.class,
bytes -> BASE64_ENCODER.encodeToString(bytes.toByteArray())))),
;

private final CelFunctionDecl functionDecl;
private final ImmutableSet<CelRuntime.CelFunctionBinding> functionBindings;

String getFunction() {
return functionDecl.name();
}

Function(
CelFunctionDecl functionDecl,
ImmutableSet<CelRuntime.CelFunctionBinding> functionBindings) {
this.functionDecl = functionDecl;
this.functionBindings = functionBindings;
}
}

@Override
public void setCheckerOptions(CelCheckerBuilder checkerBuilder) {
functions.forEach(function -> checkerBuilder.addFunctionDeclarations(function.functionDecl));
}

@SuppressWarnings("Immutable") // Instances of java.util.Base64 are immutable
@Override
public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) {
runtimeBuilder.addFunctionBindings(
CelRuntime.CelFunctionBinding.from(
"base64_decode_string",
String.class,
str -> ByteString.copyFrom(BASE64_DECODER.decode(str))),
CelRuntime.CelFunctionBinding.from(
"base64_encode_bytes",
ByteString.class,
bytes -> BASE64_ENCODER.encodeToString(bytes.toByteArray())));
functions.forEach(function -> runtimeBuilder.addFunctionBindings(function.functionBindings));
}
}

public CelEncoderExtensions() {
this.functions = ImmutableSet.copyOf(Function.values());
}
}
24 changes: 24 additions & 0 deletions extensions/src/main/java/dev/cel/extensions/CelExtensions.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@

package dev.cel.extensions;

import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Arrays.stream;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import dev.cel.common.CelOptions;
import java.util.Set;

Expand Down Expand Up @@ -206,5 +210,25 @@ public static CelSetsExtensions sets(Set<CelSetsExtensions.Function> functions)
return new CelSetsExtensions(functions);
}

/**
* Retrieves all function names used by every extension libraries.
*
* <p>Note: Certain extensions such as {@link CelProtoExtensions} and {@link
* CelBindingsExtensions} are implemented via macros, not functions, and those are not included
* here.
*/
public static ImmutableSet<String> getAllFunctionNames() {
return Streams.concat(
stream(CelMathExtensions.Function.values())
.map(CelMathExtensions.Function::getFunction),
stream(CelStringExtensions.Function.values())
.map(CelStringExtensions.Function::getFunction),
stream(CelSetsExtensions.Function.values())
.map(CelSetsExtensions.Function::getFunction),
stream(CelEncoderExtensions.Function.values())
.map(CelEncoderExtensions.Function::getFunction))
.collect(toImmutableSet());
}

private CelExtensions() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ final class CelMathExtensions implements CelCompilerLibrary, CelRuntimeLibrary {
return builder.buildOrThrow();
}

public enum Function {
enum Function {
MAX(
CelFunctionDecl.newFunctionDeclaration(
MATH_MAX_FUNCTION,
Expand Down Expand Up @@ -341,6 +341,10 @@ public enum Function {
private final ImmutableSet<CelRuntime.CelFunctionBinding> functionBindingsULongSigned;
private final ImmutableSet<CelRuntime.CelFunctionBinding> functionBindingsULongUnsigned;

String getFunction() {
return functionDecl.name();
}

Function(
CelFunctionDecl functionDecl,
ImmutableSet<CelRuntime.CelFunctionBinding> functionBindings,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ public enum Function {
private final CelFunctionDecl functionDecl;
private final ImmutableSet<CelRuntime.CelFunctionBinding> functionBindings;

String getFunction() {
return functionDecl.name();
}

Function(CelFunctionDecl functionDecl, CelRuntime.CelFunctionBinding... functionBindings) {
this.functionDecl = functionDecl;
this.functionBindings = ImmutableSet.copyOf(functionBindings);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,10 @@ public enum Function {
private final CelFunctionDecl functionDecl;
private final ImmutableSet<CelRuntime.CelFunctionBinding> functionBindings;

String getFunction() {
return functionDecl.name();
}

Function(CelFunctionDecl functionDecl, CelRuntime.CelFunctionBinding... functionBindings) {
this.functionDecl = functionDecl;
this.functionBindings = ImmutableSet.copyOf(functionBindings);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,5 +138,27 @@ public void addEncoderExtension_success() throws Exception {

assertThat(evaluatedResult).isTrue();
}
}

@Test
public void getAllFunctionNames() {
assertThat(CelExtensions.getAllFunctionNames())
.containsExactly(
"math.@max",
"math.@min",
"charAt",
"indexOf",
"join",
"lastIndexOf",
"lowerAscii",
"replace",
"split",
"substring",
"trim",
"upperAscii",
"sets.contains",
"sets.equivalent",
"sets.intersects",
"base64.decode",
"base64.encode");
}
}
21 changes: 18 additions & 3 deletions optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ java_library(
tags = [
],
deps = [
":default_optimizer_constants",
"//:auto_value",
"//bundle:cel",
"//common",
Expand All @@ -29,6 +30,7 @@ java_library(
"//optimizer:optimization_exception",
"//parser:operator",
"//runtime",
"@maven//:com_google_errorprone_error_prone_annotations",
"@maven//:com_google_guava_guava",
],
)
Expand All @@ -41,9 +43,9 @@ java_library(
tags = [
],
deps = [
":default_optimizer_constants",
"//:auto_value",
"//bundle:cel",
"//checker:checker_legacy_environment",
"//common",
"//common:compiler_common",
"//common:mutable_ast",
Expand All @@ -55,12 +57,25 @@ java_library(
"//common/navigation:mutable_navigation",
"//common/types",
"//common/types:type_providers",
"//extensions:optional_library",
"//optimizer:ast_optimizer",
"//optimizer:mutable_ast",
"//parser:operator",
"@maven//:com_google_errorprone_error_prone_annotations",
"@maven//:com_google_guava_guava",
"@maven//:org_jspecify_jspecify",
],
)

java_library(
name = "default_optimizer_constants",
srcs = [
"DefaultOptimizerConstants.java",
],
visibility = ["//visibility:private"],
deps = [
"//checker:checker_legacy_environment",
"//extensions",
"//extensions:optional_library",
"//parser:operator",
"@maven//:com_google_guava_guava",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
// limitations under the License.
package dev.cel.optimizer.optimizers;

import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.MoreCollectors.onlyElement;

import com.google.auto.value.AutoValue;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import dev.cel.bundle.Cel;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelMutableAst;
Expand All @@ -40,6 +43,7 @@
import dev.cel.parser.Operator;
import dev.cel.runtime.CelEvaluationException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -71,6 +75,7 @@ public static ConstantFoldingOptimizer newInstance(

private final ConstantFoldingOptions constantFoldingOptions;
private final AstMutator astMutator;
private final ImmutableSet<String> foldableFunctions;

// Use optional.of and optional.none as sentinel function names for folding optional calls.
// TODO: Leverage CelValue representation of Optionals instead when available.
Expand All @@ -95,7 +100,7 @@ public OptimizationResult optimize(CelAbstractSyntaxTree ast, Cel cel)
CelNavigableMutableAst.fromAst(mutableAst)
.getRoot()
.allNodes()
.filter(ConstantFoldingOptimizer::canFold)
.filter(this::canFold)
.collect(toImmutableList());
for (CelNavigableMutableExpr foldableExpr : foldableExprs) {
iterCount++;
Expand Down Expand Up @@ -124,9 +129,13 @@ public OptimizationResult optimize(CelAbstractSyntaxTree ast, Cel cel)
return OptimizationResult.create(astMutator.renumberIdsConsecutively(mutableAst).toParsedAst());
}

private static boolean canFold(CelNavigableMutableExpr navigableExpr) {
private boolean canFold(CelNavigableMutableExpr navigableExpr) {
switch (navigableExpr.getKind()) {
case CALL:
if (!containsFoldableFunctionOnly(navigableExpr)) {
return false;
}

CelMutableCall mutableCall = navigableExpr.expr().call();
String functionName = mutableCall.function();

Expand Down Expand Up @@ -169,6 +178,19 @@ private static boolean canFold(CelNavigableMutableExpr navigableExpr) {
}
}

private boolean containsFoldableFunctionOnly(CelNavigableMutableExpr navigableExpr) {
return navigableExpr
.allNodes()
.allMatch(
node -> {
if (node.getKind().equals(Kind.CALL)) {
return foldableFunctions.contains(node.expr().call().function());
}

return true;
});
}

private static boolean canFoldInOperator(CelNavigableMutableExpr navigableExpr) {
ImmutableList<CelNavigableMutableExpr> allIdents =
navigableExpr
Expand Down Expand Up @@ -574,16 +596,39 @@ private CelMutableAst pruneOptionalStructElements(CelMutableAst ast, CelMutableE
public abstract static class ConstantFoldingOptions {
public abstract int maxIterationLimit();

public abstract ImmutableSet<String> foldableFunctions();

/** Builder for configuring the {@link ConstantFoldingOptions}. */
@AutoValue.Builder
public abstract static class Builder {

abstract ImmutableSet.Builder<String> foldableFunctionsBuilder();

/**
* Limit the number of iteration while performing constant folding. An exception is thrown if
* the iteration count exceeds the set value.
*/
public abstract Builder maxIterationLimit(int value);

/**
* Adds a collection of custom functions that will be a candidate for constant folding. By
* default, standard functions are foldable.
*
* <p>Note that the implementation of custom functions must be free of side effects.
*/
@CanIgnoreReturnValue
public Builder addFoldableFunctions(Iterable<String> functions) {
checkNotNull(functions);
this.foldableFunctionsBuilder().addAll(functions);
return this;
}

/** See {@link #addFoldableFunctions(Iterable)}. */
@CanIgnoreReturnValue
public Builder addFoldableFunctions(String... functions) {
return addFoldableFunctions(Arrays.asList(functions));
}

public abstract ConstantFoldingOptions build();

Builder() {}
Expand All @@ -601,5 +646,10 @@ public static Builder newBuilder() {
private ConstantFoldingOptimizer(ConstantFoldingOptions constantFoldingOptions) {
this.constantFoldingOptions = constantFoldingOptions;
this.astMutator = AstMutator.newInstance(constantFoldingOptions.maxIterationLimit());
this.foldableFunctions =
ImmutableSet.<String>builder()
.addAll(DefaultOptimizerConstants.CEL_CANONICAL_FUNCTIONS)
.addAll(constantFoldingOptions.foldableFunctions())
.build();
}
}
Loading

0 comments on commit 882f631

Please sign in to comment.