Skip to content

Commit

Permalink
Fix bug with fallback on global registered functions
Browse files Browse the repository at this point in the history
  • Loading branch information
roeldev committed Sep 26, 2023
1 parent e88cdef commit a68f7d4
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 13 deletions.
13 changes: 7 additions & 6 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type UnmarshalFunc func(val Value, dest any) error
// RegisterUnmarshalFunc.
func GetUnmarshalFunc(typ reflect.Type) UnmarshalFunc { return unmarshaler.Func(typ) }

// unmarshaler is the global root Unmarshaler.
// unmarshaler is the global Unmarshaler.
var unmarshaler Unmarshaler

// Unmarshaler is a type which can unmarshal a Value to any type that's
Expand All @@ -70,12 +70,13 @@ func (u *Unmarshaler) Register(typ reflect.Type, fn UnmarshalFunc) *Unmarshaler
// Func returns the (globally) registered UnmarshalFunc for reflect.Type typ or
// nil if there is none registered with Register or RegisterUnmarshalFunc.
func (u *Unmarshaler) Func(typ reflect.Type) UnmarshalFunc {
if !u.register.initialized() {
// unmarshaler is always initialized
return unmarshaler.Func(typ)
if u.register.initialized() {
if fn := u.register.find(typ); fn != nil {
return fn
}
}

return u.register.find(typ)
// fallback to global unmarshaler
return unmarshaler.register.find(typ)
}

// Unmarshal tries to unmarshal Value to a supported type which matches the
Expand Down
8 changes: 8 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ func TestUnmarshal(t *testing.T) {
}
}

func TestUnmarshaler_Func(t *testing.T) {
var u Unmarshaler
u.Register(reflect.TypeOf(t), func(Value, any) error {
return nil
})
testRegisterFind(t, 0, func(typ reflect.Type) any { return u.Func(typ) })
}

func TestUnmarshaler_Unmarshal(t *testing.T) {
timeVal, _ := time.Parse(time.RFC3339, "1997-08-29T13:37:00Z")
urlPtr, _ := url.ParseRequestURI("http://localhost/")
Expand Down
12 changes: 7 additions & 5 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type MarshalFunc func(v any) (string, error)
// typ or nil if there is none registered with RegisterMarshalFunc.
func GetMarshalFunc(typ reflect.Type) MarshalFunc { return marshaler.Func(typ) }

// marshaler is the global Marshaler.
var marshaler Marshaler

// Marshaler is a type which can marshal any reflect.Value to its raw string
Expand All @@ -51,12 +52,13 @@ func (m *Marshaler) Register(typ reflect.Type, fn MarshalFunc) *Marshaler {
// Func returns the (globally) registered MarshalFunc for reflect.Type typ or
// nil if there is none registered with Register or RegisterMarshalFunc.
func (m *Marshaler) Func(typ reflect.Type) MarshalFunc {
if !m.register.initialized() {
// marshaler is always initialized
return marshaler.Func(typ)
if m.register.initialized() {
if fn := m.register.find(typ); fn != nil {
return fn
}
}

return m.register.find(typ)
// fallback to global marshaler
return marshaler.register.find(typ)
}

// Marshal returns the string representation of the value.
Expand Down
8 changes: 8 additions & 0 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ func TestMarshal(t *testing.T) {
}
}

func TestMarshaler_Func(t *testing.T) {
var m Marshaler
m.Register(reflect.TypeOf(t), func(any) (string, error) {
return "", nil
})
testRegisterFind(t, 1, func(typ reflect.Type) any { return m.Func(typ) })
}

func TestMarshalFunc_Exec(t *testing.T) {
wantErr := errors.New("some err")
_, haveErr := MarshalFunc(func(v any) (string, error) {
Expand Down
4 changes: 2 additions & 2 deletions register_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestGetMarshalFunc(t *testing.T) {
testRegisterFind(t, 1, func(typ reflect.Type) any { return GetMarshalFunc(typ) })
}

func testRegisterFind(t *testing.T, i int, haveFn func(reflect.Type) any) {
func testRegisterFind(t *testing.T, i int, getFn func(reflect.Type) any) {
tests := []struct {
want [2]uintptr
types []reflect.Type
Expand Down Expand Up @@ -60,7 +60,7 @@ func testRegisterFind(t *testing.T, i int, haveFn func(reflect.Type) any) {
for _, tc := range tests {
for _, typ := range tc.types {
t.Run(typ.String(), func(t *testing.T) {
have := haveFn(typ)
have := getFn(typ)
assert.Equal(t, tc.want[i], reflect.ValueOf(have).Pointer())
})
}
Expand Down

0 comments on commit a68f7d4

Please sign in to comment.