Skip to content

Commit

Permalink
DefaultTypeAdapter: Add support for missing custom scalars (#893)
Browse files Browse the repository at this point in the history
The default type adapter already supports some custom scalar types but
not all, this change adds the missing ones. The most only used one is
likely string.
  • Loading branch information
alvaroaleman authored Feb 2, 2024
1 parent 5883379 commit 19b2ad1
Show file tree
Hide file tree
Showing 9 changed files with 193 additions and 5 deletions.
12 changes: 12 additions & 0 deletions common/types/int.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,18 @@ func (i Int) ConvertToNative(typeDesc reflect.Type) (any, error) {
return nil, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Int8:
v, err := int64ToInt8Checked(int64(i))
if err != nil {
return nil, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Int16:
v, err := int64ToInt16Checked(int64(i))
if err != nil {
return nil, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Int64:
return reflect.ValueOf(i).Convert(typeDesc).Interface(), nil
case reflect.Ptr:
Expand Down
30 changes: 30 additions & 0 deletions common/types/int_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,36 @@ func TestIntConvertToNative_Error(t *testing.T) {
}
}

func TestIntConvertToNative_Int8(t *testing.T) {
val, err := Int(127).ConvertToNative(reflect.TypeOf(int8(0)))
if err != nil {
t.Fatalf("Int.ConvertToNative(int8) failed: %v", err)
}
if val.(int8) != 127 {
t.Errorf("Got '%v', expected 20050", val)
}
val, err = Int(math.MaxInt8 + 1).ConvertToNative(reflect.TypeOf(int8(0)))
if err == nil {
t.Errorf("(MaxInt+1).ConvertToNative(int8) did not error, got: %v", val)
} else if !strings.Contains(err.Error(), "integer overflow") {
t.Errorf("ConvertToNative(int8) returned unexpected error: %v, wanted integer overflow", err)
}
}
func TestIntConvertToNative_Int16(t *testing.T) {
val, err := Int(20050).ConvertToNative(reflect.TypeOf(int16(0)))
if err != nil {
t.Fatalf("Int.ConvertToNative(int16) failed: %v", err)
}
if val.(int16) != 20050 {
t.Errorf("Got '%v', expected 20050", val)
}
val, err = Int(math.MaxInt16 + 1).ConvertToNative(reflect.TypeOf(int16(0)))
if err == nil {
t.Errorf("(MaxInt+1).ConvertToNative(int16) did not error, got: %v", val)
} else if !strings.Contains(err.Error(), "integer overflow") {
t.Errorf("ConvertToNative(int32) returned unexpected error: %v, wanted integer overflow", err)
}
}
func TestIntConvertToNative_Int32(t *testing.T) {
val, err := Int(20050).ConvertToNative(reflect.TypeOf(int32(0)))
if err != nil {
Expand Down
40 changes: 40 additions & 0 deletions common/types/overflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,26 @@ func int64ToUint64Checked(v int64) (uint64, error) {
return uint64(v), nil
}

// int64ToInt8Checked converts an int64 to an int8 value.
//
// If the conversion fails due to overflow the error return value will be non-nil.
func int64ToInt8Checked(v int64) (int8, error) {
if v < math.MinInt8 || v > math.MaxInt8 {
return 0, errIntOverflow
}
return int8(v), nil
}

// int64ToInt16Checked converts an int64 to an int16 value.
//
// If the conversion fails due to overflow the error return value will be non-nil.
func int64ToInt16Checked(v int64) (int16, error) {
if v < math.MinInt16 || v > math.MaxInt16 {
return 0, errIntOverflow
}
return int16(v), nil
}

// int64ToInt32Checked converts an int64 to an int32 value.
//
// If the conversion fails due to overflow the error return value will be non-nil.
Expand All @@ -336,6 +356,26 @@ func int64ToInt32Checked(v int64) (int32, error) {
return int32(v), nil
}

// uint64ToUint8Checked converts a uint64 to a uint8 value.
//
// If the conversion fails due to overflow the error return value will be non-nil.
func uint64ToUint8Checked(v uint64) (uint8, error) {
if v > math.MaxUint8 {
return 0, errUintOverflow
}
return uint8(v), nil
}

// uint64ToUint16Checked converts a uint64 to a uint16 value.
//
// If the conversion fails due to overflow the error return value will be non-nil.
func uint64ToUint16Checked(v uint64) (uint16, error) {
if v > math.MaxUint16 {
return 0, errUintOverflow
}
return uint16(v), nil
}

// uint64ToUint32Checked converts a uint64 to a uint32 value.
//
// If the conversion fails due to overflow the error return value will be non-nil.
Expand Down
24 changes: 24 additions & 0 deletions common/types/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,12 +590,33 @@ func nativeToValue(a Adapter, value any) (ref.Val, bool) {
return NewDynamicMap(a, v), true
// type aliases of primitive types cannot be asserted as that type, but rather need
// to be downcast to int32 before being converted to a CEL representation.
case reflect.Bool:
boolTupe := reflect.TypeOf(false)
return Bool(refValue.Convert(boolTupe).Interface().(bool)), true
case reflect.Int:
intType := reflect.TypeOf(int(0))
return Int(refValue.Convert(intType).Interface().(int)), true
case reflect.Int8:
intType := reflect.TypeOf(int8(0))
return Int(refValue.Convert(intType).Interface().(int8)), true
case reflect.Int16:
intType := reflect.TypeOf(int16(0))
return Int(refValue.Convert(intType).Interface().(int16)), true
case reflect.Int32:
intType := reflect.TypeOf(int32(0))
return Int(refValue.Convert(intType).Interface().(int32)), true
case reflect.Int64:
intType := reflect.TypeOf(int64(0))
return Int(refValue.Convert(intType).Interface().(int64)), true
case reflect.Uint:
uintType := reflect.TypeOf(uint(0))
return Uint(refValue.Convert(uintType).Interface().(uint)), true
case reflect.Uint8:
uintType := reflect.TypeOf(uint8(0))
return Uint(refValue.Convert(uintType).Interface().(uint8)), true
case reflect.Uint16:
uintType := reflect.TypeOf(uint16(0))
return Uint(refValue.Convert(uintType).Interface().(uint16)), true
case reflect.Uint32:
uintType := reflect.TypeOf(uint32(0))
return Uint(refValue.Convert(uintType).Interface().(uint32)), true
Expand All @@ -608,6 +629,9 @@ func nativeToValue(a Adapter, value any) (ref.Val, bool) {
case reflect.Float64:
doubleType := reflect.TypeOf(float64(0))
return Double(refValue.Convert(doubleType).Interface().(float64)), true
case reflect.String:
stringType := reflect.TypeOf("")
return String(refValue.Convert(stringType).Interface().(string)), true
}
}
return nil, false
Expand Down
32 changes: 31 additions & 1 deletion common/types/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,21 @@ func TestConvertToNative(t *testing.T) {
// Proto conversion tests.
parsedExpr := &exprpb.ParsedExpr{}
expectValueToNative(t, reg.NativeToValue(parsedExpr), parsedExpr)

// Custom scalars
expectValueToNative(t, Int(1), testInt(1))
expectValueToNative(t, Int(1), testInt8(1))
expectValueToNative(t, Int(1), testInt16(1))
expectValueToNative(t, Int(1), testInt32(1))
expectValueToNative(t, Int(1), testInt64(1))
expectValueToNative(t, Uint(1), testUint(1))
expectValueToNative(t, Uint(1), testUint8(1))
expectValueToNative(t, Uint(1), testUint16(1))
expectValueToNative(t, Uint(1), testUint32(1))
expectValueToNative(t, Uint(1), testUint64(1))
expectValueToNative(t, Double(4.5), testFloat32(4.5))
expectValueToNative(t, Double(-5.1), testFloat64(-5.1))
expectValueToNative(t, String("foo"), testString("foo"))
}

func TestNativeToValue_Any(t *testing.T) {
Expand Down Expand Up @@ -758,12 +773,19 @@ func TestNativeToValue_Primitive(t *testing.T) {
expectNativeToValue(t, &rBytes, rBytes)

// Extensions to core types.
expectNativeToValue(t, testInt(1), Int(1))
expectNativeToValue(t, testInt8(1), Int(1))
expectNativeToValue(t, testInt16(1), Int(1))
expectNativeToValue(t, testInt32(1), Int(1))
expectNativeToValue(t, testInt64(-100), Int(-100))
expectNativeToValue(t, testUint(1), Uint(1))
expectNativeToValue(t, testUint8(1), Uint(1))
expectNativeToValue(t, testUint16(1), Uint(1))
expectNativeToValue(t, testUint32(2), Uint(2))
expectNativeToValue(t, testUint64(3), Uint(3))
expectNativeToValue(t, testFloat32(4.5), Double(4.5))
expectNativeToValue(t, testFloat64(-5.1), Double(-5.1))
expectNativeToValue(t, testString("foo"), String("foo"))

// Null conversion test.
expectNativeToValue(t, nil, NullValue)
Expand Down Expand Up @@ -795,7 +817,7 @@ func expectValueToNative(t *testing.T, in ref.Val, out any) {
}
if !equals {
t.Errorf("Unexpected conversion from expr to proto.\n"+
"expected: %T, actual: %T", val, out)
"expected: %T, actual: %T", out, val)
}
}
}
Expand Down Expand Up @@ -870,12 +892,20 @@ func BenchmarkTypeProviderCopy(b *testing.B) {
type nonConvertible struct {
Field string
}
type testBool bool
type testInt int
type testInt8 int8
type testInt16 int16
type testInt32 int32
type testInt64 int64
type testUint uint
type testUint8 uint8
type testUint16 uint16
type testUint32 uint32
type testUint64 uint64
type testFloat32 float32
type testFloat64 float64
type testString string

func newTestRegistry(t *testing.T, types ...proto.Message) *Registry {
t.Helper()
Expand Down
5 changes: 1 addition & 4 deletions common/types/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,7 @@ func (s String) Compare(other ref.Val) ref.Val {
func (s String) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.String:
if reflect.TypeOf(s).AssignableTo(typeDesc) {
return s, nil
}
return s.Value(), nil
return reflect.ValueOf(s).Convert(typeDesc).Interface(), nil
case reflect.Ptr:
switch typeDesc {
case anyValueType:
Expand Down
11 changes: 11 additions & 0 deletions common/types/string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,17 @@ func TestStringConvertToNative_String(t *testing.T) {
}
}

type customString string

func TestStringConvertToNative_CustomString(t *testing.T) {
val, err := String("hello").ConvertToNative(reflect.TypeOf(customString("")))
if err != nil {
t.Error(err)
} else if v, ok := val.(customString); !ok || v != "hello" {
t.Errorf("Got %T with val '%v', expected %T with val 'hello'", val, v, customString(""))
}
}

func TestStringConvertToNative_Wrapper(t *testing.T) {
val, err := String("hello").ConvertToNative(stringWrapperType)
if err != nil {
Expand Down
12 changes: 12 additions & 0 deletions common/types/uint.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,18 @@ func (i Uint) ConvertToNative(typeDesc reflect.Type) (any, error) {
return 0, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Uint8:
v, err := uint64ToUint8Checked(uint64(i))
if err != nil {
return 0, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Uint16:
v, err := uint64ToUint16Checked(uint64(i))
if err != nil {
return 0, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Uint64:
return reflect.ValueOf(i).Convert(typeDesc).Interface(), nil
case reflect.Ptr:
Expand Down
32 changes: 32 additions & 0 deletions common/types/uint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,38 @@ func TestUintConvertToNative_Json(t *testing.T) {
}
}

func TestUintConvertToNative_Uint8(t *testing.T) {
val, err := Uint(128).ConvertToNative(reflect.TypeOf(uint8(0)))
if err != nil {
t.Fatalf("Uint.ConvertToNative(uint8) failed: %v", err)
}
if val.(uint8) != 128 {
t.Errorf("Got '%v', expected 128", val)
}
val, err = Uint(math.MaxUint8 + 1).ConvertToNative(reflect.TypeOf(uint8(0)))
if err == nil {
t.Errorf("(MaxUint+1).ConvertToNative(uint8) did not error, got: %v", val)
} else if !strings.Contains(err.Error(), "unsigned integer overflow") {
t.Errorf("ConvertToNative(uint8) returned unexpected error: %v, wanted unsigned integer overflow", err)
}
}

func TestUintConvertToNative_Uint16(t *testing.T) {
val, err := Uint(20050).ConvertToNative(reflect.TypeOf(uint16(0)))
if err != nil {
t.Fatalf("Uint.ConvertToNative(uint16) failed: %v", err)
}
if val.(uint16) != 20050 {
t.Errorf("Got '%v', expected 20050", val)
}
val, err = Uint(math.MaxUint16 + 1).ConvertToNative(reflect.TypeOf(uint16(0)))
if err == nil {
t.Errorf("(MaxUint+1).ConvertToNative(uint16) did not error, got: %v", val)
} else if !strings.Contains(err.Error(), "unsigned integer overflow") {
t.Errorf("ConvertToNative(uint16) returned unexpected error: %v, wanted unsigned integer overflow", err)
}
}

func TestUintConvertToNative_Uint32(t *testing.T) {
val, err := Uint(20050).ConvertToNative(reflect.TypeOf(uint32(0)))
if err != nil {
Expand Down

0 comments on commit 19b2ad1

Please sign in to comment.