Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: httpx.Parse supports parsing structures that implement the Unmarshaler interface #4143

Merged
merged 1 commit into from
May 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 59 additions & 21 deletions core/mapping/unmarshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@
}

func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.Value,
mapValue any, fullName string) error {
mapValue any, fullName string,
) error {
var slice []any
switch v := mapValue.(type) {
case fmt.Stringer:
Expand Down Expand Up @@ -248,7 +249,8 @@
}

func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
baseKind reflect.Kind, value any, fullName string) error {
baseKind reflect.Kind, value any, fullName string,
) error {
if value == nil {
return errNilSliceElement
}
Expand Down Expand Up @@ -286,7 +288,8 @@
}

func (u *Unmarshaler) fillSliceWithDefault(derefedType reflect.Type, value reflect.Value,
defaultValue, fullName string) error {
defaultValue, fullName string,
) error {
baseFieldType := Deref(derefedType.Elem())
baseFieldKind := baseFieldType.Kind()
defaultCacheLock.Lock()
Expand Down Expand Up @@ -400,7 +403,8 @@
}

func (u *Unmarshaler) parseOptionsWithContext(field reflect.StructField, m Valuer, fullName string) (
string, *fieldOptionsWithContext, error) {
string, *fieldOptionsWithContext, error,
) {
key, options, err := parseKeyAndOptions(u.key, field)
if err != nil {
return "", nil, err
Expand Down Expand Up @@ -441,7 +445,8 @@
}

func (u *Unmarshaler) processAnonymousField(field reflect.StructField, value reflect.Value,
m valuerWithParent, fullName string) error {
m valuerWithParent, fullName string,
) error {
key, options, err := u.parseOptionsWithContext(field, m, fullName)
if err != nil {
return err
Expand All @@ -459,7 +464,8 @@
}

func (u *Unmarshaler) processAnonymousFieldOptional(field reflect.StructField, value reflect.Value,
key string, m valuerWithParent, fullName string) error {
key string, m valuerWithParent, fullName string,
) error {
derefedFieldType := Deref(field.Type)

switch derefedFieldType.Kind() {
Expand All @@ -471,7 +477,8 @@
}

func (u *Unmarshaler) processAnonymousFieldRequired(field reflect.StructField, value reflect.Value,
m valuerWithParent, fullName string) error {
m valuerWithParent, fullName string,
) error {
fieldType := field.Type
maybeNewValue(fieldType, value)
derefedFieldType := Deref(fieldType)
Expand All @@ -495,7 +502,8 @@
}

func (u *Unmarshaler) processAnonymousStructFieldOptional(fieldType reflect.Type,
value reflect.Value, key string, m valuerWithParent, fullName string) error {
value reflect.Value, key string, m valuerWithParent, fullName string,
) error {
var filled bool
var required int
var requiredFilled int
Expand Down Expand Up @@ -536,7 +544,8 @@
}

func (u *Unmarshaler) processField(field reflect.StructField, value reflect.Value,
m valuerWithParent, fullName string) error {
m valuerWithParent, fullName string,
) error {
if usingDifferentKeys(u.key, field) {
return nil
}
Expand All @@ -549,13 +558,16 @@
}

func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value reflect.Value,
vp valueWithParent, opts *fieldOptionsWithContext, fullName string) error {
vp valueWithParent, opts *fieldOptionsWithContext, fullName string,
) error {
derefedFieldType := Deref(fieldType)
typeKind := derefedFieldType.Kind()
mapValue := vp.value
valueKind := reflect.TypeOf(mapValue).Kind()

switch {
case valueKind == reflect.String && typeKind == reflect.Struct && fieldType.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()):
return u.fillCustomUnmarshalerStruct(fieldType, value, mapValue)

Check warning on line 570 in core/mapping/unmarshaler.go

View check run for this annotation

Codecov / codecov/patch

core/mapping/unmarshaler.go#L569-L570

Added lines #L569 - L570 were not covered by tests
case valueKind == reflect.Map && typeKind == reflect.Struct:
mv, ok := mapValue.(map[string]any)
if !ok {
Expand All @@ -581,8 +593,24 @@
}
}

func (u *Unmarshaler) fillCustomUnmarshalerStruct(fieldType reflect.Type, value reflect.Value, mapValue any) error {
if !value.CanSet() {
return errValueNotSettable

Check warning on line 598 in core/mapping/unmarshaler.go

View check run for this annotation

Codecov / codecov/patch

core/mapping/unmarshaler.go#L596-L598

Added lines #L596 - L598 were not covered by tests
}
baseType := Deref(fieldType)
target := reflect.New(baseType)

Check warning on line 601 in core/mapping/unmarshaler.go

View check run for this annotation

Codecov / codecov/patch

core/mapping/unmarshaler.go#L600-L601

Added lines #L600 - L601 were not covered by tests

params := make([]reflect.Value, 1)
params[0] = reflect.ValueOf([]byte(mapValue.(string)))
target.MethodByName("UnmarshalJSON").Call(params)

Check warning on line 605 in core/mapping/unmarshaler.go

View check run for this annotation

Codecov / codecov/patch

core/mapping/unmarshaler.go#L603-L605

Added lines #L603 - L605 were not covered by tests

value.Set(target)
return nil

Check warning on line 608 in core/mapping/unmarshaler.go

View check run for this annotation

Codecov / codecov/patch

core/mapping/unmarshaler.go#L607-L608

Added lines #L607 - L608 were not covered by tests
}

func (u *Unmarshaler) processFieldPrimitive(fieldType reflect.Type, value reflect.Value,
mapValue any, opts *fieldOptionsWithContext, fullName string) error {
mapValue any, opts *fieldOptionsWithContext, fullName string,
) error {
typeKind := Deref(fieldType).Kind()
valueKind := reflect.TypeOf(mapValue).Kind()

Expand All @@ -603,7 +631,8 @@
}

func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type, value reflect.Value,
v json.Number, opts *fieldOptionsWithContext, fullName string) error {
v json.Number, opts *fieldOptionsWithContext, fullName string,
) error {
baseType := Deref(fieldType)
typeKind := baseType.Kind()

Expand Down Expand Up @@ -656,7 +685,8 @@
}

func (u *Unmarshaler) processFieldStruct(fieldType reflect.Type, value reflect.Value,
m valuerWithParent, fullName string) error {
m valuerWithParent, fullName string,
) error {
if fieldType.Kind() == reflect.Ptr {
baseType := Deref(fieldType)
target := reflect.New(baseType).Elem()
Expand All @@ -673,7 +703,8 @@
}

func (u *Unmarshaler) processFieldTextUnmarshaler(fieldType reflect.Type, value reflect.Value,
mapValue any) (bool, error) {
mapValue any,
) (bool, error) {
var tval encoding.TextUnmarshaler
var ok bool

Expand Down Expand Up @@ -701,7 +732,8 @@
}

func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value reflect.Value,
envVal string, opts *fieldOptionsWithContext, fullName string) error {
envVal string, opts *fieldOptionsWithContext, fullName string,
) error {
if err := validateValueInOptions(envVal, opts.options()); err != nil {
return err
}
Expand Down Expand Up @@ -731,7 +763,8 @@
}

func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect.Value,
m valuerWithParent, fullName string) error {
m valuerWithParent, fullName string,
) error {
if !field.IsExported() {
return nil
}
Expand Down Expand Up @@ -778,7 +811,8 @@
}

func (u *Unmarshaler) processNamedFieldWithValue(fieldType reflect.Type, value reflect.Value,
vp valueWithParent, key string, opts *fieldOptionsWithContext, fullName string) error {
vp valueWithParent, key string, opts *fieldOptionsWithContext, fullName string,
) error {
mapValue := vp.value
if mapValue == nil {
if opts.optional() {
Expand Down Expand Up @@ -813,7 +847,8 @@
}

func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Type, value reflect.Value,
mapValue any, key string, opts *fieldOptionsWithContext, fullName string) error {
mapValue any, key string, opts *fieldOptionsWithContext, fullName string,
) error {
valueKind := reflect.TypeOf(mapValue).Kind()
if valueKind != reflect.String {
return fmt.Errorf("the value in map is not string, but %s", valueKind)
Expand Down Expand Up @@ -842,7 +877,8 @@
}

func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, value reflect.Value,
opts *fieldOptionsWithContext, fullName string) error {
opts *fieldOptionsWithContext, fullName string,
) error {
derefedType := Deref(fieldType)
fieldKind := derefedType.Kind()
if defaultValue, ok := opts.getDefault(); ok {
Expand Down Expand Up @@ -984,7 +1020,8 @@
}

func fillPrimitive(fieldType reflect.Type, value reflect.Value, mapValue any,
opts *fieldOptionsWithContext, fullName string) error {
opts *fieldOptionsWithContext, fullName string,
) error {
if !value.CanSet() {
return errValueNotSettable
}
Expand Down Expand Up @@ -1013,7 +1050,8 @@
}

func fillWithSameType(fieldType reflect.Type, value reflect.Value, mapValue any,
opts *fieldOptionsWithContext) error {
opts *fieldOptionsWithContext,
) error {
if !value.CanSet() {
return errValueNotSettable
}
Expand Down
21 changes: 21 additions & 0 deletions rest/httpx/requests_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package httpx

import (
"bytes"
"errors"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -515,3 +516,23 @@ func (m mockRequest) Validate() error {

return nil
}

type mockCustomUnmarshalerStruct struct {
Name string
}

func (m *mockCustomUnmarshalerStruct) UnmarshalJSON(b []byte) error {
m.Name = string(b)
return nil
}

func TestCustomUnmarshalerStructRequest(t *testing.T) {
reqBody := `{"name": "hello"}`
r := httptest.NewRequest(http.MethodPost, "/a", bytes.NewReader([]byte(reqBody)))
r.Header.Set("Content-Type", "application/json")
v := struct {
Foo *mockCustomUnmarshalerStruct `json:"name"`
}{}
assert.Nil(t, Parse(r, &v))
assert.Equal(t, "hello", v.Foo.Name)
}