Skip to content

Commit

Permalink
feat: httpx.Parse supports parsing structures that implement the Unma…
Browse files Browse the repository at this point in the history
…rshaler interface
  • Loading branch information
lyuangg committed May 11, 2024
1 parent 7c730b9 commit 8ea6163
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 21 deletions.
80 changes: 59 additions & 21 deletions core/mapping/unmarshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
}

func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.Value,
mapValue any, fullName string) error {
mapValue any, fullName string,
) error {
var slice []any
switch v := mapValue.(type) {
case fmt.Stringer:
Expand Down Expand Up @@ -248,7 +249,8 @@ func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.
}

func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
baseKind reflect.Kind, value any, fullName string) error {
baseKind reflect.Kind, value any, fullName string,
) error {
if value == nil {
return errNilSliceElement
}
Expand Down Expand Up @@ -286,7 +288,8 @@ func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
}

func (u *Unmarshaler) fillSliceWithDefault(derefedType reflect.Type, value reflect.Value,
defaultValue, fullName string) error {
defaultValue, fullName string,
) error {
baseFieldType := Deref(derefedType.Elem())
baseFieldKind := baseFieldType.Kind()
defaultCacheLock.Lock()
Expand Down Expand Up @@ -400,7 +403,8 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any,
}

func (u *Unmarshaler) parseOptionsWithContext(field reflect.StructField, m Valuer, fullName string) (
string, *fieldOptionsWithContext, error) {
string, *fieldOptionsWithContext, error,
) {
key, options, err := parseKeyAndOptions(u.key, field)
if err != nil {
return "", nil, err
Expand Down Expand Up @@ -441,7 +445,8 @@ func (u *Unmarshaler) parseOptionsWithContext(field reflect.StructField, m Value
}

func (u *Unmarshaler) processAnonymousField(field reflect.StructField, value reflect.Value,
m valuerWithParent, fullName string) error {
m valuerWithParent, fullName string,
) error {
key, options, err := u.parseOptionsWithContext(field, m, fullName)
if err != nil {
return err
Expand All @@ -459,7 +464,8 @@ func (u *Unmarshaler) processAnonymousField(field reflect.StructField, value ref
}

func (u *Unmarshaler) processAnonymousFieldOptional(field reflect.StructField, value reflect.Value,
key string, m valuerWithParent, fullName string) error {
key string, m valuerWithParent, fullName string,
) error {
derefedFieldType := Deref(field.Type)

switch derefedFieldType.Kind() {
Expand All @@ -471,7 +477,8 @@ func (u *Unmarshaler) processAnonymousFieldOptional(field reflect.StructField, v
}

func (u *Unmarshaler) processAnonymousFieldRequired(field reflect.StructField, value reflect.Value,
m valuerWithParent, fullName string) error {
m valuerWithParent, fullName string,
) error {
fieldType := field.Type
maybeNewValue(fieldType, value)
derefedFieldType := Deref(fieldType)
Expand All @@ -495,7 +502,8 @@ func (u *Unmarshaler) processAnonymousFieldRequired(field reflect.StructField, v
}

func (u *Unmarshaler) processAnonymousStructFieldOptional(fieldType reflect.Type,
value reflect.Value, key string, m valuerWithParent, fullName string) error {
value reflect.Value, key string, m valuerWithParent, fullName string,
) error {
var filled bool
var required int
var requiredFilled int
Expand Down Expand Up @@ -536,7 +544,8 @@ func (u *Unmarshaler) processAnonymousStructFieldOptional(fieldType reflect.Type
}

func (u *Unmarshaler) processField(field reflect.StructField, value reflect.Value,
m valuerWithParent, fullName string) error {
m valuerWithParent, fullName string,
) error {
if usingDifferentKeys(u.key, field) {
return nil
}
Expand All @@ -549,13 +558,16 @@ func (u *Unmarshaler) processField(field reflect.StructField, value reflect.Valu
}

func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value reflect.Value,
vp valueWithParent, opts *fieldOptionsWithContext, fullName string) error {
vp valueWithParent, opts *fieldOptionsWithContext, fullName string,
) error {
derefedFieldType := Deref(fieldType)
typeKind := derefedFieldType.Kind()
mapValue := vp.value
valueKind := reflect.TypeOf(mapValue).Kind()

switch {
case valueKind == reflect.String && typeKind == reflect.Struct && fieldType.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()):
return u.fillCustomUnmarshalerStruct(fieldType, value, mapValue)
case valueKind == reflect.Map && typeKind == reflect.Struct:
mv, ok := mapValue.(map[string]any)
if !ok {
Expand All @@ -581,8 +593,24 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
}
}

func (u *Unmarshaler) fillCustomUnmarshalerStruct(fieldType reflect.Type, value reflect.Value, mapValue any) error {
if !value.CanSet() {
return errValueNotSettable
}
baseType := Deref(fieldType)
target := reflect.New(baseType)

params := make([]reflect.Value, 1)
params[0] = reflect.ValueOf([]byte(mapValue.(string)))
target.MethodByName("UnmarshalJSON").Call(params)

value.Set(target)
return nil
}

func (u *Unmarshaler) processFieldPrimitive(fieldType reflect.Type, value reflect.Value,
mapValue any, opts *fieldOptionsWithContext, fullName string) error {
mapValue any, opts *fieldOptionsWithContext, fullName string,
) error {
typeKind := Deref(fieldType).Kind()
valueKind := reflect.TypeOf(mapValue).Kind()

Expand All @@ -603,7 +631,8 @@ func (u *Unmarshaler) processFieldPrimitive(fieldType reflect.Type, value reflec
}

func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type, value reflect.Value,
v json.Number, opts *fieldOptionsWithContext, fullName string) error {
v json.Number, opts *fieldOptionsWithContext, fullName string,
) error {
baseType := Deref(fieldType)
typeKind := baseType.Kind()

Expand Down Expand Up @@ -656,7 +685,8 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type
}

func (u *Unmarshaler) processFieldStruct(fieldType reflect.Type, value reflect.Value,
m valuerWithParent, fullName string) error {
m valuerWithParent, fullName string,
) error {
if fieldType.Kind() == reflect.Ptr {
baseType := Deref(fieldType)
target := reflect.New(baseType).Elem()
Expand All @@ -673,7 +703,8 @@ func (u *Unmarshaler) processFieldStruct(fieldType reflect.Type, value reflect.V
}

func (u *Unmarshaler) processFieldTextUnmarshaler(fieldType reflect.Type, value reflect.Value,
mapValue any) (bool, error) {
mapValue any,
) (bool, error) {
var tval encoding.TextUnmarshaler
var ok bool

Expand Down Expand Up @@ -701,7 +732,8 @@ func (u *Unmarshaler) processFieldTextUnmarshaler(fieldType reflect.Type, value
}

func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value reflect.Value,
envVal string, opts *fieldOptionsWithContext, fullName string) error {
envVal string, opts *fieldOptionsWithContext, fullName string,
) error {
if err := validateValueInOptions(envVal, opts.options()); err != nil {
return err
}
Expand Down Expand Up @@ -731,7 +763,8 @@ func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value ref
}

func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect.Value,
m valuerWithParent, fullName string) error {
m valuerWithParent, fullName string,
) error {
if !field.IsExported() {
return nil
}
Expand Down Expand Up @@ -778,7 +811,8 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect
}

func (u *Unmarshaler) processNamedFieldWithValue(fieldType reflect.Type, value reflect.Value,
vp valueWithParent, key string, opts *fieldOptionsWithContext, fullName string) error {
vp valueWithParent, key string, opts *fieldOptionsWithContext, fullName string,
) error {
mapValue := vp.value
if mapValue == nil {
if opts.optional() {
Expand Down Expand Up @@ -813,7 +847,8 @@ func (u *Unmarshaler) processNamedFieldWithValue(fieldType reflect.Type, value r
}

func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Type, value reflect.Value,
mapValue any, key string, opts *fieldOptionsWithContext, fullName string) error {
mapValue any, key string, opts *fieldOptionsWithContext, fullName string,
) error {
valueKind := reflect.TypeOf(mapValue).Kind()
if valueKind != reflect.String {
return fmt.Errorf("the value in map is not string, but %s", valueKind)
Expand Down Expand Up @@ -842,7 +877,8 @@ func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Typ
}

func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, value reflect.Value,
opts *fieldOptionsWithContext, fullName string) error {
opts *fieldOptionsWithContext, fullName string,
) error {
derefedType := Deref(fieldType)
fieldKind := derefedType.Kind()
if defaultValue, ok := opts.getDefault(); ok {
Expand Down Expand Up @@ -984,7 +1020,8 @@ func fillDurationValue(fieldType reflect.Type, value reflect.Value, dur string)
}

func fillPrimitive(fieldType reflect.Type, value reflect.Value, mapValue any,
opts *fieldOptionsWithContext, fullName string) error {
opts *fieldOptionsWithContext, fullName string,
) error {
if !value.CanSet() {
return errValueNotSettable
}
Expand Down Expand Up @@ -1013,7 +1050,8 @@ func fillPrimitive(fieldType reflect.Type, value reflect.Value, mapValue any,
}

func fillWithSameType(fieldType reflect.Type, value reflect.Value, mapValue any,
opts *fieldOptionsWithContext) error {
opts *fieldOptionsWithContext,
) error {
if !value.CanSet() {
return errValueNotSettable
}
Expand Down
21 changes: 21 additions & 0 deletions rest/httpx/requests_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package httpx

import (
"bytes"
"errors"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -515,3 +516,23 @@ func (m mockRequest) Validate() error {

return nil
}

type mockCustomUnmarshalerStruct struct {
Name string
}

func (m *mockCustomUnmarshalerStruct) UnmarshalJSON(b []byte) error {
m.Name = string(b)
return nil
}

func TestCustomUnmarshalerStructRequest(t *testing.T) {
reqBody := `{"name": "hello"}`
r := httptest.NewRequest(http.MethodPost, "/a", bytes.NewReader([]byte(reqBody)))
r.Header.Set("Content-Type", "application/json")
v := struct {
Foo *mockCustomUnmarshalerStruct `json:"name"`
}{}
assert.Nil(t, Parse(r, &v))
assert.Equal(t, "hello", v.Foo.Name)
}

0 comments on commit 8ea6163

Please sign in to comment.