From 0b6e8fecbb070f0e1a833cd8034baf328cbe4fa3 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 30 Jul 2024 15:45:22 -0700 Subject: [PATCH] Include identity functions in the standard definitions PiperOrigin-RevId: 657756541 --- .../src/main/java/dev/cel/checker/Env.java | 4 ++ .../main/java/dev/cel/checker/Standard.java | 41 ++++++++++++++++++- .../test/resources/standardEnvDump.baseline | 8 ++++ .../dev/cel/runtime/StandardFunctions.java | 22 ++++++++++ .../test/resources/boolConversions.baseline | 4 ++ .../test/resources/bytesConversions.baseline | 5 +++ .../test/resources/doubleConversions.baseline | 5 +++ .../test/resources/stringConversions.baseline | 5 +++ .../test/resources/timeConversions.baseline | 17 ++++++++ .../test/resources/uint64Conversions.baseline | 10 +++++ .../dev/cel/testing/BaseInterpreterTest.java | 27 ++++++++++++ 11 files changed, 147 insertions(+), 1 deletion(-) create mode 100644 runtime/src/test/resources/boolConversions.baseline diff --git a/checker/src/main/java/dev/cel/checker/Env.java b/checker/src/main/java/dev/cel/checker/Env.java index 8cc354fb2..613bef166 100644 --- a/checker/src/main/java/dev/cel/checker/Env.java +++ b/checker/src/main/java/dev/cel/checker/Env.java @@ -540,6 +540,10 @@ boolean enableTimestampEpoch() { return celOptions.enableTimestampEpoch(); } + boolean enableUnsignedLongs() { + return celOptions.enableUnsignedLongs(); + } + /** Add an identifier {@code decl} to the environment. */ @CanIgnoreReturnValue private Env addIdent(CelIdentDecl celIdentDecl) { diff --git a/checker/src/main/java/dev/cel/checker/Standard.java b/checker/src/main/java/dev/cel/checker/Standard.java index 1781fe559..eaac8c300 100644 --- a/checker/src/main/java/dev/cel/checker/Standard.java +++ b/checker/src/main/java/dev/cel/checker/Standard.java @@ -99,6 +99,14 @@ public static Env add(Env env) { timestampConversionDeclarations(env.enableTimestampEpoch()).forEach(env::add); numericComparisonDeclarations(env.enableHeterogeneousNumericComparisons()).forEach(env::add); + if (env.enableUnsignedLongs()) { + env.add( + CelFunctionDecl.newFunctionDeclaration( + Function.INT.getFunction(), + CelOverloadDecl.newGlobalOverload( + "int64_to_int64", "type conversion (identity)", SimpleType.INT, SimpleType.INT))); + } + return env; } @@ -384,6 +392,8 @@ private static ImmutableList coreFunctionDeclarations() { celFunctionDeclBuilder.add( CelFunctionDecl.newFunctionDeclaration( Function.UINT.getFunction(), + CelOverloadDecl.newGlobalOverload( + "uint64_to_uint64", "type conversion (identity)", SimpleType.UINT, SimpleType.UINT), CelOverloadDecl.newGlobalOverload( "int64_to_uint64", "type conversion", SimpleType.UINT, SimpleType.INT), CelOverloadDecl.newGlobalOverload( @@ -395,6 +405,11 @@ private static ImmutableList coreFunctionDeclarations() { celFunctionDeclBuilder.add( CelFunctionDecl.newFunctionDeclaration( Function.DOUBLE.getFunction(), + CelOverloadDecl.newGlobalOverload( + "double_to_double", + "type conversion (identity)", + SimpleType.DOUBLE, + SimpleType.DOUBLE), CelOverloadDecl.newGlobalOverload( "int64_to_double", "type conversion", SimpleType.DOUBLE, SimpleType.INT), CelOverloadDecl.newGlobalOverload( @@ -406,6 +421,11 @@ private static ImmutableList coreFunctionDeclarations() { celFunctionDeclBuilder.add( CelFunctionDecl.newFunctionDeclaration( Function.STRING.getFunction(), + CelOverloadDecl.newGlobalOverload( + "string_to_string", + "type conversion (identity)", + SimpleType.STRING, + SimpleType.STRING), CelOverloadDecl.newGlobalOverload( "int64_to_string", "type conversion", SimpleType.STRING, SimpleType.INT), CelOverloadDecl.newGlobalOverload( @@ -423,6 +443,8 @@ private static ImmutableList coreFunctionDeclarations() { celFunctionDeclBuilder.add( CelFunctionDecl.newFunctionDeclaration( Function.BYTES.getFunction(), + CelOverloadDecl.newGlobalOverload( + "bytes_to_bytes", "type conversion (identity)", SimpleType.BYTES, SimpleType.BYTES), CelOverloadDecl.newGlobalOverload( "string_to_bytes", "type conversion", SimpleType.BYTES, SimpleType.STRING))); @@ -437,12 +459,24 @@ private static ImmutableList coreFunctionDeclarations() { celFunctionDeclBuilder.add( CelFunctionDecl.newFunctionDeclaration( Function.DURATION.getFunction(), + CelOverloadDecl.newGlobalOverload( + "duration_to_duration", + "type conversion (identity)", + SimpleType.DURATION, + SimpleType.DURATION), CelOverloadDecl.newGlobalOverload( "string_to_duration", "type conversion, duration should be end with \"s\", which stands for seconds", SimpleType.DURATION, SimpleType.STRING))); + // Conversions to boolean + celFunctionDeclBuilder.add( + CelFunctionDecl.newFunctionDeclaration( + Function.BOOL.getFunction(), + CelOverloadDecl.newGlobalOverload( + "bool_to_bool", "type conversion (identity)", SimpleType.BOOL, SimpleType.BOOL))); + // String functions celFunctionDeclBuilder.add( CelFunctionDecl.newFunctionDeclaration( @@ -674,7 +708,12 @@ private static ImmutableList timestampConversionDeclarations(bo "Type conversion of strings to timestamps according to RFC3339. Example:" + " \"1972-01-01T10:00:20.021-05:00\".", SimpleType.TIMESTAMP, - SimpleType.STRING)); + SimpleType.STRING), + CelOverloadDecl.newGlobalOverload( + "timestamp_to_timestamp", + "type conversion (identity)", + SimpleType.TIMESTAMP, + SimpleType.TIMESTAMP)); if (withEpoch) { timestampBuilder.addOverloads( CelOverloadDecl.newGlobalOverload( diff --git a/checker/src/test/resources/standardEnvDump.baseline b/checker/src/test/resources/standardEnvDump.baseline index a49852d84..5f88f7f50 100644 --- a/checker/src/test/resources/standardEnvDump.baseline +++ b/checker/src/test/resources/standardEnvDump.baseline @@ -143,9 +143,11 @@ declare _||_ { } declare bool { value type(bool) + function bool_to_bool (bool) -> bool } declare bytes { value type(bytes) + function bytes_to_bytes (bytes) -> bytes function string_to_bytes (string) -> bytes } declare contains { @@ -153,11 +155,13 @@ declare contains { } declare double { value type(double) + function double_to_double (double) -> double function int64_to_double (int) -> double function uint64_to_double (uint) -> double function string_to_double (string) -> double } declare duration { + function duration_to_duration (google.protobuf.Duration) -> google.protobuf.Duration function string_to_duration (string) -> google.protobuf.Duration } declare dyn { @@ -217,6 +221,7 @@ declare int { function double_to_int64 (double) -> int function string_to_int64 (string) -> int function timestamp_to_int64 (google.protobuf.Timestamp) -> int + function int64_to_int64 (int) -> int } declare list { value type(list(dyn)) @@ -246,6 +251,7 @@ declare startsWith { } declare string { value type(string) + function string_to_string (string) -> string function int64_to_string (int) -> string function uint64_to_string (uint) -> string function double_to_string (double) -> string @@ -255,6 +261,7 @@ declare string { } declare timestamp { function string_to_timestamp (string) -> google.protobuf.Timestamp + function timestamp_to_timestamp (google.protobuf.Timestamp) -> google.protobuf.Timestamp function int64_to_timestamp (int) -> google.protobuf.Timestamp } declare type { @@ -263,6 +270,7 @@ declare type { } declare uint { value type(uint) + function uint64_to_uint64 (uint) -> uint function int64_to_uint64 (int) -> uint function double_to_uint64 (double) -> uint function string_to_uint64 (string) -> uint diff --git a/runtime/src/main/java/dev/cel/runtime/StandardFunctions.java b/runtime/src/main/java/dev/cel/runtime/StandardFunctions.java index d7d00fdaa..23a77c0c7 100644 --- a/runtime/src/main/java/dev/cel/runtime/StandardFunctions.java +++ b/runtime/src/main/java/dev/cel/runtime/StandardFunctions.java @@ -154,6 +154,8 @@ public static void addNonInlined( } private static void addBoolFunctions(Registrar registrar) { + // Identity + registrar.add("bool_to_bool", Boolean.class, (Boolean x) -> x); // The conditional, logical_or, logical_and, and not_strictly_false functions are special-cased. registrar.add("logical_not", Boolean.class, (Boolean x) -> !x); @@ -167,6 +169,8 @@ private static void addBoolFunctions(Registrar registrar) { } private static void addBytesFunctions(Registrar registrar) { + // Identity + registrar.add("bytes_to_bytes", ByteString.class, (ByteString x) -> x); // Bytes ordering functions: <, <=, >=, > registrar.add( "less_bytes", @@ -205,6 +209,8 @@ private static void addBytesFunctions(Registrar registrar) { } private static void addDoubleFunctions(Registrar registrar, CelOptions celOptions) { + // Identity + registrar.add("double_to_double", Double.class, (Double x) -> x); // Double ordering functions. registrar.add("less_double", Double.class, Double.class, (Double x, Double y) -> x < y); registrar.add("less_equals_double", Double.class, Double.class, (Double x, Double y) -> x <= y); @@ -245,6 +251,8 @@ private static void addDoubleFunctions(Registrar registrar, CelOptions celOption } private static void addDurationFunctions(Registrar registrar) { + // Identity + registrar.add("duration_to_duration", Duration.class, (Duration x) -> x); // Duration ordering functions: <, <=, >=, > registrar.add( "less_duration", @@ -307,6 +315,12 @@ private static void addDurationFunctions(Registrar registrar) { } private static void addIntFunctions(Registrar registrar, CelOptions celOptions) { + // Identity + if (celOptions.enableUnsignedLongs()) { + // Note that we require UnsignedLong flag here to avoid ambiguous overloads against + // "uint64_to_int64", because they both use the same Java Long class. + registrar.add("int64_to_int64", Long.class, (Long x) -> x); + } // Comparison functions. registrar.add("less_int64", Long.class, Long.class, (Long x, Long y) -> x < y); registrar.add("less_equals_int64", Long.class, Long.class, (Long x, Long y) -> x <= y); @@ -499,6 +513,8 @@ private static void addMapFunctions( } private static void addStringFunctions(Registrar registrar, CelOptions celOptions) { + // Identity + registrar.add("string_to_string", String.class, (String x) -> x); // String ordering functions: <, <=, >=, >. registrar.add( "less_string", String.class, String.class, (String x, String y) -> x.compareTo(y) < 0); @@ -547,6 +563,8 @@ private static void addStringFunctions(Registrar registrar, CelOptions celOption // timestamp_to_milliseconds overload @SuppressWarnings("JavaLocalDateTimeGetNano") private static void addTimestampFunctions(Registrar registrar) { + // Identity + registrar.add("timestamp_to_timestamp", Timestamp.class, (Timestamp x) -> x); // Timestamp relation operators: <, <=, >=, > registrar.add( "less_timestamp", @@ -714,6 +732,8 @@ private static void addTimestampFunctions(Registrar registrar) { } private static void addSignedUintFunctions(Registrar registrar, CelOptions celOptions) { + // Identity + registrar.add("uint64_to_uint64", Long.class, (Long x) -> x); // Uint relation operators: <, <=, >=, > registrar.add( "less_uint64", @@ -833,6 +853,8 @@ private static void addSignedUintFunctions(Registrar registrar, CelOptions celOp } private static void addUintFunctions(Registrar registrar, CelOptions celOptions) { + // Identity + registrar.add("uint64_to_uint64", UnsignedLong.class, (UnsignedLong x) -> x); registrar.add( "less_uint64", UnsignedLong.class, diff --git a/runtime/src/test/resources/boolConversions.baseline b/runtime/src/test/resources/boolConversions.baseline new file mode 100644 index 000000000..3d7415aba --- /dev/null +++ b/runtime/src/test/resources/boolConversions.baseline @@ -0,0 +1,4 @@ +Source: bool(true) +=====> +bindings: {} +result: true \ No newline at end of file diff --git a/runtime/src/test/resources/bytesConversions.baseline b/runtime/src/test/resources/bytesConversions.baseline index 4c1458268..6324929ce 100644 --- a/runtime/src/test/resources/bytesConversions.baseline +++ b/runtime/src/test/resources/bytesConversions.baseline @@ -2,3 +2,8 @@ Source: bytes('abc\303') =====> bindings: {} result: abcà + +Source: bytes(bytes('abc\303')) +=====> +bindings: {} +result: abcà \ No newline at end of file diff --git a/runtime/src/test/resources/doubleConversions.baseline b/runtime/src/test/resources/doubleConversions.baseline index d15dfe090..682e56edc 100644 --- a/runtime/src/test/resources/doubleConversions.baseline +++ b/runtime/src/test/resources/doubleConversions.baseline @@ -18,3 +18,8 @@ Source: double('bad') bindings: {} error: evaluation error: For input string: "bad" error_code: BAD_FORMAT + +Source: double(1.5) +=====> +bindings: {} +result: 1.5 \ No newline at end of file diff --git a/runtime/src/test/resources/stringConversions.baseline b/runtime/src/test/resources/stringConversions.baseline index 7e20cff19..59406fc85 100644 --- a/runtime/src/test/resources/stringConversions.baseline +++ b/runtime/src/test/resources/stringConversions.baseline @@ -32,3 +32,8 @@ Source: string(duration('1000000s')) =====> bindings: {} result: 1000000s + +Source: string('hello') +=====> +bindings: {} +result: hello \ No newline at end of file diff --git a/runtime/src/test/resources/timeConversions.baseline b/runtime/src/test/resources/timeConversions.baseline index 18b8fdcb0..cb1c4e886 100644 --- a/runtime/src/test/resources/timeConversions.baseline +++ b/runtime/src/test/resources/timeConversions.baseline @@ -54,3 +54,20 @@ declare t1 { bindings: {} error: evaluation error: invalid duration format error_code: BAD_FORMAT + +Source: duration(duration('15.0s')) +declare t1 { + value google.protobuf.Timestamp +} +=====> +bindings: {} +result: seconds: 15 + + +Source: timestamp(timestamp(123)) +declare t1 { + value google.protobuf.Timestamp +} +=====> +bindings: {} +result: seconds: 123 \ No newline at end of file diff --git a/runtime/src/test/resources/uint64Conversions.baseline b/runtime/src/test/resources/uint64Conversions.baseline index cfdc61974..fac9bfc52 100644 --- a/runtime/src/test/resources/uint64Conversions.baseline +++ b/runtime/src/test/resources/uint64Conversions.baseline @@ -35,3 +35,13 @@ Source: uint('f1') bindings: {} error: evaluation error: f1 error_code: BAD_FORMAT + +Source: uint(1u) +=====> +bindings: {} +result: 1 + +Source: uint(dyn(1u)) +=====> +bindings: {} +result: 1 diff --git a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java index 795145508..3b9d63f11 100644 --- a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java +++ b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java @@ -1261,6 +1261,12 @@ public void timeConversions() throws Exception { // Not supported. source = "duration('inf')"; runTest(Activation.EMPTY); + + source = "duration(duration('15.0s'))"; // Identity + runTest(Activation.EMPTY); + + source = "timestamp(timestamp(123))"; // Identity + runTest(Activation.EMPTY); } @Test @@ -1491,6 +1497,12 @@ public void uint64Conversions() throws Exception { source = "uint('f1')"; // should error runTest(Activation.EMPTY); + + source = "uint(1u)"; // identity + runTest(Activation.EMPTY); + + source = "uint(dyn(1u))"; // identity, check dynamic dispatch + runTest(Activation.EMPTY); } @Test @@ -1506,6 +1518,9 @@ public void doubleConversions() throws Exception { source = "double('bad')"; runTest(Activation.EMPTY); + + source = "double(1.5)"; // Identity + runTest(Activation.EMPTY); } @Test @@ -1536,6 +1551,9 @@ public void stringConversions() throws Exception { source = "string(duration('1000000s'))"; runTest(Activation.EMPTY); + + source = "string('hello')"; // Identity + runTest(Activation.EMPTY); } @Test @@ -1546,10 +1564,19 @@ public void bytes() throws Exception { runTest(Activation.EMPTY); } + @Test + public void boolConversions() throws Exception { + source = "bool(true)"; + runTest(Activation.EMPTY); // Identity + } + @Test public void bytesConversions() throws Exception { source = "bytes('abc\\303')"; runTest(Activation.EMPTY); // string converts to abcà in bytes form. + + source = "bytes(bytes('abc\\303'))"; // Identity + runTest(Activation.EMPTY); } @Test