Skip to content

Commit

Permalink
Fix runtime equality behavior for sets extension
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 662216130
  • Loading branch information
l46kok authored and copybara-github committed Aug 13, 2024
1 parent 8073b79 commit 81ea911
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 86 deletions.
5 changes: 4 additions & 1 deletion extensions/src/main/java/dev/cel/extensions/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,13 @@ java_library(
deps = [
"//checker:checker_builder",
"//common:compiler_common",
"//common/internal:comparison_functions",
"//common:options",
"//common/internal:default_message_factory",
"//common/internal:dynamic_proto",
"//common/types",
"//compiler:compiler_builder",
"//runtime",
"//runtime:runtime_helper",
"@maven//:com_google_errorprone_error_prone_annotations",
"@maven//:com_google_guava_guava",
],
Expand Down
23 changes: 16 additions & 7 deletions extensions/src/main/java/dev/cel/extensions/CelExtensions.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ public final class CelExtensions {
private static final CelProtoExtensions PROTO_EXTENSIONS = new CelProtoExtensions();
private static final CelBindingsExtensions BINDINGS_EXTENSIONS = new CelBindingsExtensions();
private static final CelEncoderExtensions ENCODER_EXTENSIONS = new CelEncoderExtensions();
private static final CelSetsExtensions SET_EXTENSIONS = new CelSetsExtensions();

/**
* Extended functions for string manipulation.
Expand Down Expand Up @@ -175,6 +174,14 @@ public static CelEncoderExtensions encoders() {
return ENCODER_EXTENSIONS;
}

/**
* @deprecated Use {@link #sets(CelOptions)} instead.
*/
@Deprecated
public static CelSetsExtensions sets() {
return sets(CelOptions.DEFAULT);
}

/**
* Extended functions for Set manipulation.
*
Expand All @@ -184,8 +191,8 @@ public static CelEncoderExtensions encoders() {
* future additions. To expose only a subset of functions, use {@link
* #sets(CelSetExtensions.Function...)} instead.
*/
public static CelSetsExtensions sets() {
return SET_EXTENSIONS;
public static CelSetsExtensions sets(CelOptions celOptions) {
return new CelSetsExtensions(celOptions);
}

/**
Expand All @@ -195,8 +202,9 @@ public static CelSetsExtensions sets() {
*
* <p>This will include only the specific functions denoted by {@link CelSetsExtensions.Function}.
*/
public static CelSetsExtensions sets(CelSetsExtensions.Function... functions) {
return sets(ImmutableSet.copyOf(functions));
public static CelSetsExtensions sets(
CelOptions celOptions, CelSetsExtensions.Function... functions) {
return sets(celOptions, ImmutableSet.copyOf(functions));
}

/**
Expand All @@ -206,8 +214,9 @@ public static CelSetsExtensions sets(CelSetsExtensions.Function... functions) {
*
* <p>This will include only the specific functions denoted by {@link CelSetsExtensions.Function}.
*/
public static CelSetsExtensions sets(Set<CelSetsExtensions.Function> functions) {
return new CelSetsExtensions(functions);
public static CelSetsExtensions sets(
CelOptions celOptions, Set<CelSetsExtensions.Function> functions) {
return new CelSetsExtensions(celOptions, functions);
}

/**
Expand Down
121 changes: 51 additions & 70 deletions extensions/src/main/java/dev/cel/extensions/CelSetsExtensions.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,20 @@
import com.google.errorprone.annotations.Immutable;
import dev.cel.checker.CelCheckerBuilder;
import dev.cel.common.CelFunctionDecl;
import dev.cel.common.CelOptions;
import dev.cel.common.CelOverloadDecl;
import dev.cel.common.internal.ComparisonFunctions;
import dev.cel.common.internal.DefaultMessageFactory;
import dev.cel.common.internal.DynamicProto;
import dev.cel.common.types.ListType;
import dev.cel.common.types.SimpleType;
import dev.cel.common.types.TypeParamType;
import dev.cel.compiler.CelCompilerLibrary;
import dev.cel.runtime.CelRuntime;
import dev.cel.runtime.CelRuntimeBuilder;
import dev.cel.runtime.CelRuntimeLibrary;
import dev.cel.runtime.RuntimeEquality;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/**
Expand Down Expand Up @@ -64,6 +66,9 @@ public final class CelSetsExtensions implements CelCompilerLibrary, CelRuntimeLi
+ " are unique, so size does not factor into the computation. If either list is empty,"
+ " the result will be false.";

private static final RuntimeEquality RUNTIME_EQUALITY =
new RuntimeEquality(DynamicProto.create(DefaultMessageFactory.INSTANCE));

/** Denotes the set extension function. */
public enum Function {
CONTAINS(
Expand All @@ -74,12 +79,7 @@ public enum Function {
SET_CONTAINS_OVERLOAD_DOC,
SimpleType.BOOL,
ListType.create(TypeParamType.create("T")),
ListType.create(TypeParamType.create("T")))),
CelRuntime.CelFunctionBinding.from(
"list_sets_contains_list",
Collection.class,
Collection.class,
CelSetsExtensions::containsAll)),
ListType.create(TypeParamType.create("T"))))),
EQUIVALENT(
CelFunctionDecl.newFunctionDeclaration(
SET_EQUIVALENT_FUNCTION,
Expand All @@ -88,12 +88,7 @@ public enum Function {
SET_EQUIVALENT_OVERLOAD_DOC,
SimpleType.BOOL,
ListType.create(TypeParamType.create("T")),
ListType.create(TypeParamType.create("T")))),
CelRuntime.CelFunctionBinding.from(
"list_sets_equivalent_list",
Collection.class,
Collection.class,
(listA, listB) -> containsAll(listA, listB) && containsAll(listB, listA))),
ListType.create(TypeParamType.create("T"))))),
INTERSECTS(
CelFunctionDecl.newFunctionDeclaration(
SET_INTERSECTS_FUNCTION,
Expand All @@ -102,34 +97,29 @@ public enum Function {
SET_INTERSECTS_OVERLOAD_DOC,
SimpleType.BOOL,
ListType.create(TypeParamType.create("T")),
ListType.create(TypeParamType.create("T")))),
CelRuntime.CelFunctionBinding.from(
"list_sets_intersects_list",
Collection.class,
Collection.class,
CelSetsExtensions::setIntersects));
ListType.create(TypeParamType.create("T")))));

private final CelFunctionDecl functionDecl;
private final ImmutableSet<CelRuntime.CelFunctionBinding> functionBindings;

String getFunction() {
return functionDecl.name();
}

Function(CelFunctionDecl functionDecl, CelRuntime.CelFunctionBinding... functionBindings) {
Function(CelFunctionDecl functionDecl) {
this.functionDecl = functionDecl;
this.functionBindings = ImmutableSet.copyOf(functionBindings);
}
}

private final ImmutableSet<Function> functions;
private final CelOptions celOptions;

CelSetsExtensions() {
this(ImmutableSet.copyOf(Function.values()));
CelSetsExtensions(CelOptions celOptions) {
this(celOptions, ImmutableSet.copyOf(Function.values()));
}

CelSetsExtensions(Set<Function> functions) {
CelSetsExtensions(CelOptions celOptions, Set<Function> functions) {
this.functions = ImmutableSet.copyOf(functions);
this.celOptions = celOptions;
}

@Override
Expand All @@ -139,7 +129,34 @@ public void setCheckerOptions(CelCheckerBuilder checkerBuilder) {

@Override
public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) {
functions.forEach(function -> runtimeBuilder.addFunctionBindings(function.functionBindings));
for (Function function : functions) {
switch (function) {
case CONTAINS:
runtimeBuilder.addFunctionBindings(
CelRuntime.CelFunctionBinding.from(
"list_sets_contains_list",
Collection.class,
Collection.class,
this::containsAll));
break;
case EQUIVALENT:
runtimeBuilder.addFunctionBindings(
CelRuntime.CelFunctionBinding.from(
"list_sets_equivalent_list",
Collection.class,
Collection.class,
(listA, listB) -> containsAll(listA, listB) && containsAll(listB, listA)));
break;
case INTERSECTS:
runtimeBuilder.addFunctionBindings(
CelRuntime.CelFunctionBinding.from(
"list_sets_intersects_list",
Collection.class,
Collection.class,
this::setIntersects));
break;
}
}
}

/**
Expand All @@ -150,9 +167,9 @@ public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) {
* <p>This is picked verbatim as implemented in the Java standard library
* Collections.containsAll() method.
*
* @see #contains(Object)
* @see #contains(Object, Collection)
*/
private static <T> boolean containsAll(Collection<T> list, Collection<T> subList) {
private <T> boolean containsAll(Collection<T> list, Collection<T> subList) {
for (T e : subList) {
if (!contains(e, list)) {
return false;
Expand All @@ -171,7 +188,7 @@ private static <T> boolean containsAll(Collection<T> list, Collection<T> subList
* <p>Source:
* https://hg.openjdk.org/jdk8u/jdk8u-dev/jdk/file/c5d02f908fb2/src/share/classes/java/util/AbstractCollection.java#l98
*/
private static <T> boolean contains(Object o, Collection<T> list) {
private <T> boolean contains(Object o, Collection<T> list) {
Iterator<?> it = list.iterator();
if (o == null) {
while (it.hasNext()) {
Expand All @@ -182,55 +199,19 @@ private static <T> boolean contains(Object o, Collection<T> list) {
} else {
while (it.hasNext()) {
Object item = it.next();
if (objectsEquals(item, o)) { // TODO: Support Maps.
if (objectsEquals(item, o)) {
return true;
}
}
}
return false;
}

private static boolean objectsEquals(Object o1, Object o2) {
if (o1 == o2) {
return true;
}
if (o1 == null || o2 == null) {
return false;
}
if (isNumeric(o1) && isNumeric(o2)) {
if (o1.getClass().equals(o2.getClass())) {
return o1.equals(o2);
}
return ComparisonFunctions.numericEquals((Number) o1, (Number) o2);
}
if (isList(o1) && isList(o2)) {
Collection<?> list1 = (Collection<?>) o1;
Collection<?> list2 = (Collection<?>) o2;
if (list1.size() != list2.size()) {
return false;
}
Iterator<?> iterator1 = list1.iterator();
Iterator<?> iterator2 = list2.iterator();
boolean result = true;
while (iterator1.hasNext() && iterator2.hasNext()) {
Object p1 = iterator1.next();
Object p2 = iterator2.next();
result = result && objectsEquals(p1, p2);
}
return result;
}
return o1.equals(o2);
}

private static boolean isNumeric(Object o) {
return o instanceof Number;
}

private static boolean isList(Object o) {
return o instanceof List;
private boolean objectsEquals(Object o1, Object o2) {
return RUNTIME_EQUALITY.objectEquals(o1, o2, celOptions);
}

private static <T> boolean setIntersects(Collection<T> listA, Collection<T> listB) {
private <T> boolean setIntersects(Collection<T> listA, Collection<T> listB) {
if (listA.isEmpty() || listB.isEmpty()) {
return false;
}
Expand Down
Loading

0 comments on commit 81ea911

Please sign in to comment.