From 0838d2b35d6c0873b29cc30ed74424a6f6954bd0 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 22 Mar 2024 11:17:28 -0700 Subject: [PATCH] Accept a list as an argument to setParameterTypes in function overload decl PiperOrigin-RevId: 618235574 --- .../test/java/dev/cel/bundle/CelImplTest.java | 40 +++++++++++++++++-- .../src/main/java/dev/cel/checker/Env.java | 2 +- .../dev/cel/checker/CelOverloadDeclTest.java | 18 +++++++++ .../java/dev/cel/common/CelOverloadDecl.java | 2 +- 4 files changed, 56 insertions(+), 6 deletions(-) diff --git a/bundle/src/test/java/dev/cel/bundle/CelImplTest.java b/bundle/src/test/java/dev/cel/bundle/CelImplTest.java index 2b3cf30bb..625d2d1a7 100644 --- a/bundle/src/test/java/dev/cel/bundle/CelImplTest.java +++ b/bundle/src/test/java/dev/cel/bundle/CelImplTest.java @@ -16,6 +16,9 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; +import static dev.cel.common.CelFunctionDecl.newFunctionDeclaration; +import static dev.cel.common.CelOverloadDecl.newGlobalOverload; +import static dev.cel.common.CelOverloadDecl.newMemberOverload; import static org.junit.Assert.assertThrows; import dev.cel.expr.CheckedExpr; @@ -55,10 +58,8 @@ import dev.cel.checker.ProtoTypeMask; import dev.cel.checker.TypeProvider; import dev.cel.common.CelAbstractSyntaxTree; -import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelIssue; import dev.cel.common.CelOptions; -import dev.cel.common.CelOverloadDecl; import dev.cel.common.CelProtoAbstractSyntaxTree; import dev.cel.common.CelValidationException; import dev.cel.common.CelValidationResult; @@ -97,6 +98,7 @@ import dev.cel.testing.testdata.proto3.TestAllTypesProto.TestAllTypes; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -1800,9 +1802,9 @@ public boolean isAssignableFrom(CelType other) { CelFactory.standardCelBuilder() .addVar("x", SimpleType.INT) .addFunctionDeclarations( - CelFunctionDecl.newFunctionDeclaration( + newFunctionDeclaration( "print", - CelOverloadDecl.newGlobalOverload( + newGlobalOverload( "print_overload", SimpleType.STRING, customType))) // The overload would accept either Int or CustomType @@ -1816,6 +1818,36 @@ public boolean isAssignableFrom(CelType other) { assertThat(result).isEqualTo("5"); } + @Test + @SuppressWarnings("unchecked") // test only + public void program_functionParamWithWellKnownType() throws Exception { + Cel cel = + CelFactory.standardCelBuilder() + .addFunctionDeclarations( + newFunctionDeclaration( + "hasStringValue", + newMemberOverload( + "struct_hasStringValue_string_string", + SimpleType.BOOL, + StructTypeReference.create("google.protobuf.Struct"), + SimpleType.STRING, + SimpleType.STRING))) + .addFunctionBindings( + CelFunctionBinding.from( + "struct_hasStringValue_string_string", + ImmutableList.of(Map.class, String.class, String.class), + args -> { + Map map = (Map) args[0]; + return map.containsKey(args[1]) && map.containsValue(args[2]); + })) + .build(); + CelAbstractSyntaxTree ast = cel.compile("{'a': 'b'}.hasStringValue('a', 'b')").getAst(); + + boolean result = (boolean) cel.createProgram(ast).eval(); + + assertThat(result).isTrue(); + } + @Test public void toBuilder_isImmutable() { CelBuilder celBuilder = CelFactory.standardCelBuilder(); diff --git a/checker/src/main/java/dev/cel/checker/Env.java b/checker/src/main/java/dev/cel/checker/Env.java index 8ce80f488..e1846f1b8 100644 --- a/checker/src/main/java/dev/cel/checker/Env.java +++ b/checker/src/main/java/dev/cel/checker/Env.java @@ -966,7 +966,7 @@ private static CelFunctionDecl sanitizeFunction(CelFunctionDecl func) { overloadBuilder.setResultType(getWellKnownType(resultType)); } - ImmutableSet.Builder parameterTypeBuilder = ImmutableSet.builder(); + ImmutableList.Builder parameterTypeBuilder = ImmutableList.builder(); for (CelType paramType : overloadBuilder.parameterTypes()) { if (isWellKnownType(paramType)) { parameterTypeBuilder.add(getWellKnownType(paramType)); diff --git a/checker/src/test/java/dev/cel/checker/CelOverloadDeclTest.java b/checker/src/test/java/dev/cel/checker/CelOverloadDeclTest.java index 41a11d045..f78a4fa62 100644 --- a/checker/src/test/java/dev/cel/checker/CelOverloadDeclTest.java +++ b/checker/src/test/java/dev/cel/checker/CelOverloadDeclTest.java @@ -20,6 +20,7 @@ import static dev.cel.common.CelOverloadDecl.newMemberOverload; import dev.cel.expr.Decl.FunctionDecl.Overload; +import com.google.common.collect.ImmutableList; import dev.cel.common.CelOverloadDecl; import dev.cel.common.types.CelTypes; import dev.cel.common.types.SimpleType; @@ -85,4 +86,21 @@ public void toProtoOverload_withTypeParams() { .containsExactly(CelTypes.STRING, CelTypes.DOUBLE, CelTypes.createTypeParam("B")); assertThat(protoOverload.getTypeParamsList()).containsExactly("A", "B"); } + + @Test + public void setParameterTypes_doesNotDedupe() { + CelOverloadDecl overloadDecl = + CelOverloadDecl.newBuilder() + .setParameterTypes( + ImmutableList.of( + SimpleType.STRING, SimpleType.STRING, SimpleType.STRING, SimpleType.INT)) + .setOverloadId("overload_id") + .setIsInstanceFunction(true) + .setResultType(SimpleType.DYN) + .build(); + + assertThat(overloadDecl.parameterTypes()) + .containsExactly(SimpleType.STRING, SimpleType.STRING, SimpleType.STRING, SimpleType.INT) + .inOrder(); + } } diff --git a/common/src/main/java/dev/cel/common/CelOverloadDecl.java b/common/src/main/java/dev/cel/common/CelOverloadDecl.java index 30bcac5a8..247f86bc1 100644 --- a/common/src/main/java/dev/cel/common/CelOverloadDecl.java +++ b/common/src/main/java/dev/cel/common/CelOverloadDecl.java @@ -93,7 +93,7 @@ public abstract static class Builder { * Sets the parameter types {@link #parameterTypes()}. Note that this will override any * parameter types added via the accumulator methods {@link #addParameterTypes}. */ - public abstract Builder setParameterTypes(ImmutableSet value); + public abstract Builder setParameterTypes(ImmutableList value); public abstract CelType resultType();