diff --git a/common/types/int.go b/common/types/int.go index 940772ae..0ae9507c 100644 --- a/common/types/int.go +++ b/common/types/int.go @@ -90,6 +90,18 @@ func (i Int) ConvertToNative(typeDesc reflect.Type) (any, error) { return nil, err } return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil + case reflect.Int8: + v, err := int64ToInt8Checked(int64(i)) + if err != nil { + return nil, err + } + return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil + case reflect.Int16: + v, err := int64ToInt16Checked(int64(i)) + if err != nil { + return nil, err + } + return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil case reflect.Int64: return reflect.ValueOf(i).Convert(typeDesc).Interface(), nil case reflect.Ptr: diff --git a/common/types/int_test.go b/common/types/int_test.go index 9f76bb77..6529843c 100644 --- a/common/types/int_test.go +++ b/common/types/int_test.go @@ -165,6 +165,36 @@ func TestIntConvertToNative_Error(t *testing.T) { } } +func TestIntConvertToNative_Int8(t *testing.T) { + val, err := Int(127).ConvertToNative(reflect.TypeOf(int8(0))) + if err != nil { + t.Fatalf("Int.ConvertToNative(int8) failed: %v", err) + } + if val.(int8) != 127 { + t.Errorf("Got '%v', expected 20050", val) + } + val, err = Int(math.MaxInt8 + 1).ConvertToNative(reflect.TypeOf(int8(0))) + if err == nil { + t.Errorf("(MaxInt+1).ConvertToNative(int8) did not error, got: %v", val) + } else if !strings.Contains(err.Error(), "integer overflow") { + t.Errorf("ConvertToNative(int8) returned unexpected error: %v, wanted integer overflow", err) + } +} +func TestIntConvertToNative_Int16(t *testing.T) { + val, err := Int(20050).ConvertToNative(reflect.TypeOf(int16(0))) + if err != nil { + t.Fatalf("Int.ConvertToNative(int16) failed: %v", err) + } + if val.(int16) != 20050 { + t.Errorf("Got '%v', expected 20050", val) + } + val, err = Int(math.MaxInt16 + 1).ConvertToNative(reflect.TypeOf(int16(0))) + if err == nil { + t.Errorf("(MaxInt+1).ConvertToNative(int16) did not error, got: %v", val) + } else if !strings.Contains(err.Error(), "integer overflow") { + t.Errorf("ConvertToNative(int32) returned unexpected error: %v, wanted integer overflow", err) + } +} func TestIntConvertToNative_Int32(t *testing.T) { val, err := Int(20050).ConvertToNative(reflect.TypeOf(int32(0))) if err != nil { diff --git a/common/types/overflow.go b/common/types/overflow.go index c68a9218..dcb66ef5 100644 --- a/common/types/overflow.go +++ b/common/types/overflow.go @@ -326,6 +326,26 @@ func int64ToUint64Checked(v int64) (uint64, error) { return uint64(v), nil } +// int64ToInt8Checked converts an int64 to an int8 value. +// +// If the conversion fails due to overflow the error return value will be non-nil. +func int64ToInt8Checked(v int64) (int8, error) { + if v < math.MinInt8 || v > math.MaxInt8 { + return 0, errIntOverflow + } + return int8(v), nil +} + +// int64ToInt16Checked converts an int64 to an int16 value. +// +// If the conversion fails due to overflow the error return value will be non-nil. +func int64ToInt16Checked(v int64) (int16, error) { + if v < math.MinInt16 || v > math.MaxInt16 { + return 0, errIntOverflow + } + return int16(v), nil +} + // int64ToInt32Checked converts an int64 to an int32 value. // // If the conversion fails due to overflow the error return value will be non-nil. @@ -336,6 +356,26 @@ func int64ToInt32Checked(v int64) (int32, error) { return int32(v), nil } +// uint64ToUint8Checked converts a uint64 to a uint8 value. +// +// If the conversion fails due to overflow the error return value will be non-nil. +func uint64ToUint8Checked(v uint64) (uint8, error) { + if v > math.MaxUint8 { + return 0, errUintOverflow + } + return uint8(v), nil +} + +// uint64ToUint16Checked converts a uint64 to a uint16 value. +// +// If the conversion fails due to overflow the error return value will be non-nil. +func uint64ToUint16Checked(v uint64) (uint16, error) { + if v > math.MaxUint16 { + return 0, errUintOverflow + } + return uint16(v), nil +} + // uint64ToUint32Checked converts a uint64 to a uint32 value. // // If the conversion fails due to overflow the error return value will be non-nil. diff --git a/common/types/provider.go b/common/types/provider.go index 5157cd1f..c5ff05fd 100644 --- a/common/types/provider.go +++ b/common/types/provider.go @@ -590,12 +590,33 @@ func nativeToValue(a Adapter, value any) (ref.Val, bool) { return NewDynamicMap(a, v), true // type aliases of primitive types cannot be asserted as that type, but rather need // to be downcast to int32 before being converted to a CEL representation. + case reflect.Bool: + boolTupe := reflect.TypeOf(false) + return Bool(refValue.Convert(boolTupe).Interface().(bool)), true + case reflect.Int: + intType := reflect.TypeOf(int(0)) + return Int(refValue.Convert(intType).Interface().(int)), true + case reflect.Int8: + intType := reflect.TypeOf(int8(0)) + return Int(refValue.Convert(intType).Interface().(int8)), true + case reflect.Int16: + intType := reflect.TypeOf(int16(0)) + return Int(refValue.Convert(intType).Interface().(int16)), true case reflect.Int32: intType := reflect.TypeOf(int32(0)) return Int(refValue.Convert(intType).Interface().(int32)), true case reflect.Int64: intType := reflect.TypeOf(int64(0)) return Int(refValue.Convert(intType).Interface().(int64)), true + case reflect.Uint: + uintType := reflect.TypeOf(uint(0)) + return Uint(refValue.Convert(uintType).Interface().(uint)), true + case reflect.Uint8: + uintType := reflect.TypeOf(uint8(0)) + return Uint(refValue.Convert(uintType).Interface().(uint8)), true + case reflect.Uint16: + uintType := reflect.TypeOf(uint16(0)) + return Uint(refValue.Convert(uintType).Interface().(uint16)), true case reflect.Uint32: uintType := reflect.TypeOf(uint32(0)) return Uint(refValue.Convert(uintType).Interface().(uint32)), true @@ -608,6 +629,9 @@ func nativeToValue(a Adapter, value any) (ref.Val, bool) { case reflect.Float64: doubleType := reflect.TypeOf(float64(0)) return Double(refValue.Convert(doubleType).Interface().(float64)), true + case reflect.String: + stringType := reflect.TypeOf("") + return String(refValue.Convert(stringType).Interface().(string)), true } } return nil, false diff --git a/common/types/provider_test.go b/common/types/provider_test.go index 56f15290..efe1244a 100644 --- a/common/types/provider_test.go +++ b/common/types/provider_test.go @@ -566,6 +566,21 @@ func TestConvertToNative(t *testing.T) { // Proto conversion tests. parsedExpr := &exprpb.ParsedExpr{} expectValueToNative(t, reg.NativeToValue(parsedExpr), parsedExpr) + + // Custom scalars + expectValueToNative(t, Int(1), testInt(1)) + expectValueToNative(t, Int(1), testInt8(1)) + expectValueToNative(t, Int(1), testInt16(1)) + expectValueToNative(t, Int(1), testInt32(1)) + expectValueToNative(t, Int(1), testInt64(1)) + expectValueToNative(t, Uint(1), testUint(1)) + expectValueToNative(t, Uint(1), testUint8(1)) + expectValueToNative(t, Uint(1), testUint16(1)) + expectValueToNative(t, Uint(1), testUint32(1)) + expectValueToNative(t, Uint(1), testUint64(1)) + expectValueToNative(t, Double(4.5), testFloat32(4.5)) + expectValueToNative(t, Double(-5.1), testFloat64(-5.1)) + expectValueToNative(t, String("foo"), testString("foo")) } func TestNativeToValue_Any(t *testing.T) { @@ -758,12 +773,19 @@ func TestNativeToValue_Primitive(t *testing.T) { expectNativeToValue(t, &rBytes, rBytes) // Extensions to core types. + expectNativeToValue(t, testInt(1), Int(1)) + expectNativeToValue(t, testInt8(1), Int(1)) + expectNativeToValue(t, testInt16(1), Int(1)) expectNativeToValue(t, testInt32(1), Int(1)) expectNativeToValue(t, testInt64(-100), Int(-100)) + expectNativeToValue(t, testUint(1), Uint(1)) + expectNativeToValue(t, testUint8(1), Uint(1)) + expectNativeToValue(t, testUint16(1), Uint(1)) expectNativeToValue(t, testUint32(2), Uint(2)) expectNativeToValue(t, testUint64(3), Uint(3)) expectNativeToValue(t, testFloat32(4.5), Double(4.5)) expectNativeToValue(t, testFloat64(-5.1), Double(-5.1)) + expectNativeToValue(t, testString("foo"), String("foo")) // Null conversion test. expectNativeToValue(t, nil, NullValue) @@ -795,7 +817,7 @@ func expectValueToNative(t *testing.T, in ref.Val, out any) { } if !equals { t.Errorf("Unexpected conversion from expr to proto.\n"+ - "expected: %T, actual: %T", val, out) + "expected: %T, actual: %T", out, val) } } } @@ -870,12 +892,20 @@ func BenchmarkTypeProviderCopy(b *testing.B) { type nonConvertible struct { Field string } +type testBool bool +type testInt int +type testInt8 int8 +type testInt16 int16 type testInt32 int32 type testInt64 int64 +type testUint uint +type testUint8 uint8 +type testUint16 uint16 type testUint32 uint32 type testUint64 uint64 type testFloat32 float32 type testFloat64 float64 +type testString string func newTestRegistry(t *testing.T, types ...proto.Message) *Registry { t.Helper() diff --git a/common/types/string.go b/common/types/string.go index a2990b26..3a93743f 100644 --- a/common/types/string.go +++ b/common/types/string.go @@ -66,10 +66,7 @@ func (s String) Compare(other ref.Val) ref.Val { func (s String) ConvertToNative(typeDesc reflect.Type) (any, error) { switch typeDesc.Kind() { case reflect.String: - if reflect.TypeOf(s).AssignableTo(typeDesc) { - return s, nil - } - return s.Value(), nil + return reflect.ValueOf(s).Convert(typeDesc).Interface(), nil case reflect.Ptr: switch typeDesc { case anyValueType: diff --git a/common/types/string_test.go b/common/types/string_test.go index 226b1932..37958535 100644 --- a/common/types/string_test.go +++ b/common/types/string_test.go @@ -106,6 +106,17 @@ func TestStringConvertToNative_String(t *testing.T) { } } +type customString string + +func TestStringConvertToNative_CustomString(t *testing.T) { + val, err := String("hello").ConvertToNative(reflect.TypeOf(customString(""))) + if err != nil { + t.Error(err) + } else if v, ok := val.(customString); !ok || v != "hello" { + t.Errorf("Got %T with val '%v', expected %T with val 'hello'", val, v, customString("")) + } +} + func TestStringConvertToNative_Wrapper(t *testing.T) { val, err := String("hello").ConvertToNative(stringWrapperType) if err != nil { diff --git a/common/types/uint.go b/common/types/uint.go index 3257f9ad..6d74f30d 100644 --- a/common/types/uint.go +++ b/common/types/uint.go @@ -80,6 +80,18 @@ func (i Uint) ConvertToNative(typeDesc reflect.Type) (any, error) { return 0, err } return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil + case reflect.Uint8: + v, err := uint64ToUint8Checked(uint64(i)) + if err != nil { + return 0, err + } + return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil + case reflect.Uint16: + v, err := uint64ToUint16Checked(uint64(i)) + if err != nil { + return 0, err + } + return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil case reflect.Uint64: return reflect.ValueOf(i).Convert(typeDesc).Interface(), nil case reflect.Ptr: diff --git a/common/types/uint_test.go b/common/types/uint_test.go index 777d7955..f07832ca 100644 --- a/common/types/uint_test.go +++ b/common/types/uint_test.go @@ -172,6 +172,38 @@ func TestUintConvertToNative_Json(t *testing.T) { } } +func TestUintConvertToNative_Uint8(t *testing.T) { + val, err := Uint(128).ConvertToNative(reflect.TypeOf(uint8(0))) + if err != nil { + t.Fatalf("Uint.ConvertToNative(uint8) failed: %v", err) + } + if val.(uint8) != 128 { + t.Errorf("Got '%v', expected 128", val) + } + val, err = Uint(math.MaxUint8 + 1).ConvertToNative(reflect.TypeOf(uint8(0))) + if err == nil { + t.Errorf("(MaxUint+1).ConvertToNative(uint8) did not error, got: %v", val) + } else if !strings.Contains(err.Error(), "unsigned integer overflow") { + t.Errorf("ConvertToNative(uint8) returned unexpected error: %v, wanted unsigned integer overflow", err) + } +} + +func TestUintConvertToNative_Uint16(t *testing.T) { + val, err := Uint(20050).ConvertToNative(reflect.TypeOf(uint16(0))) + if err != nil { + t.Fatalf("Uint.ConvertToNative(uint16) failed: %v", err) + } + if val.(uint16) != 20050 { + t.Errorf("Got '%v', expected 20050", val) + } + val, err = Uint(math.MaxUint16 + 1).ConvertToNative(reflect.TypeOf(uint16(0))) + if err == nil { + t.Errorf("(MaxUint+1).ConvertToNative(uint16) did not error, got: %v", val) + } else if !strings.Contains(err.Error(), "unsigned integer overflow") { + t.Errorf("ConvertToNative(uint16) returned unexpected error: %v, wanted unsigned integer overflow", err) + } +} + func TestUintConvertToNative_Uint32(t *testing.T) { val, err := Uint(20050).ConvertToNative(reflect.TypeOf(uint32(0))) if err != nil {