Skip to content

Commit

Permalink
handler: rework field strictness and array decoding
Browse files Browse the repository at this point in the history
Instead of having these concerns tied together on a single type, separate it
into strictStub (field checking) and arrayStub (decoding).

Pull out the logic for assigning the argument wrapper into its own method, and
properly handle the nesting if both are selected (decoding is outer).
  • Loading branch information
creachadair committed Apr 24, 2022
1 parent e482ade commit c3b11f5
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 20 deletions.
60 changes: 42 additions & 18 deletions handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,20 +153,15 @@ func (fi *FuncInfo) Wrap() Func {
return Func(f)
}

// If strict field checking is desired, ensure arguments are wrapped.
arg := fi.Argument
wrapArg := func(v reflect.Value) interface{} { return v.Interface() }
if fi.strictFields && arg != nil && !arg.Implements(strictType) {
names := fi.posNames
wrapArg = func(v reflect.Value) interface{} {
return &strict{v: v.Interface(), posNames: names}
}
}
// If strict field checking or positional decoding are enabled, ensure
// arguments are wrapped with the appropriate decoder stubs.
wrapArg := fi.argWrapper()

// Construct a function to unpack the parameters from the request message,
// based on the signature of the user's callback.
var newInput func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error)

arg := fi.Argument
if arg == nil {
// Case 1: The function does not want any request parameters.
// Nothing needs to be decoded, but verify no parameters were passed.
Expand Down Expand Up @@ -307,10 +302,9 @@ func Check(fn interface{}) (*FuncInfo, error) {
return info, nil
}

// strict is a wrapper for an arbitrary value that enforces strict field
// checking when unmarshaling from JSON, and handles translation of array
// format into object format.
type strict struct {
// arrayStub is a wrapper for an arbitrary value that handles translation of
// JSON arrays into a corresponding object format.
type arrayStub struct {
v interface{}
posNames []string
}
Expand All @@ -321,9 +315,9 @@ type strict struct {
// If s.posNames is set and data encodes an array, the array is rewritten to an
// equivalent object with field names assigned by the positional names.
// Otherwise, data is returned as-is without error.
func (s *strict) translate(data []byte) ([]byte, error) {
if len(s.posNames) == 0 || firstByte(data) != '[' {
return data, nil // no names, or not an array
func (s *arrayStub) translate(data []byte) ([]byte, error) {
if firstByte(data) != '[' {
return data, nil // not an array
}

// Decode the array wrapper and verify it has the correct length.
Expand All @@ -343,12 +337,42 @@ func (s *strict) translate(data []byte) ([]byte, error) {
return json.Marshal(obj)
}

func (s *strict) UnmarshalJSON(data []byte) error {
func (s *arrayStub) UnmarshalJSON(data []byte) error {
actual, err := s.translate(data)
if err != nil {
return err
}
dec := json.NewDecoder(bytes.NewReader(actual))
return json.Unmarshal(actual, s.v)
}

// strictStub is a wrapper for an arbitrary value that enforces strict field
// checking when unmarshaling from JSON.
type strictStub struct{ v interface{} }

func (s *strictStub) UnmarshalJSON(data []byte) error {
dec := json.NewDecoder(bytes.NewReader(data))
dec.DisallowUnknownFields()
return dec.Decode(s.v)
}

func (fi *FuncInfo) argWrapper() func(reflect.Value) interface{} {
strict := fi.strictFields && fi.Argument != nil && !fi.Argument.Implements(strictType)
names := fi.posNames // capture so the wrapper does not pin fi
array := len(names) != 0
switch {
case strict && array:
return func(v reflect.Value) interface{} {
return &arrayStub{v: &strictStub{v: v.Interface()}, posNames: names}
}
case strict:
return func(v reflect.Value) interface{} {
return &strictStub{v: v.Interface()}
}
case array:
return func(v reflect.Value) interface{} {
return &arrayStub{v: v.Interface(), posNames: names}
}
default:
return reflect.Value.Interface
}
}
4 changes: 2 additions & 2 deletions handler/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ func TestFuncInfo_wrapDecode(t *testing.T) {
fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"x","params":%s}`, test.p))
got, err := test.fn(ctx, req)
if err != nil {
t.Errorf("Call %v failed: %v", test.fn, err)
t.Errorf("Call %+v failed: %v", test.fn, err)
} else if diff := cmp.Diff(test.want, got); diff != "" {
t.Errorf("Call %v: wrong result (-want, +got)\n%s", test.fn, diff)
t.Errorf("Call %+v: wrong result (-want, +got)\n%s", test.fn, diff)
}
}
}
Expand Down

0 comments on commit c3b11f5

Please sign in to comment.