diff --git a/handler/handler.go b/handler/handler.go index 066d7f0..5277f8d 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -136,42 +136,42 @@ func (fi *FuncInfo) Wrap() Func { // Construct a function to unpack the parameters from the request message, // based on the signature of the user's callback. - var newInput func(req *jrpc2.Request) ([]reflect.Value, error) + var newInput func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) if fi.Argument == nil { // Case 1: The function does not want any request parameters. // Nothing needs to be decoded, but verify no parameters were passed. - newInput = func(req *jrpc2.Request) ([]reflect.Value, error) { + newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) { if req.HasParams() { return nil, jrpc2.Errorf(code.InvalidParams, "no parameters accepted") } - return nil, nil + return []reflect.Value{ctx}, nil } } else if fi.Argument == reqType { // Case 2: The function wants the underlying *jrpc2.Request value. - newInput = func(req *jrpc2.Request) ([]reflect.Value, error) { - return []reflect.Value{reflect.ValueOf(req)}, nil + newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) { + return []reflect.Value{ctx, reflect.ValueOf(req)}, nil } } else if fi.Argument.Kind() == reflect.Ptr { // Case 3a: The function wants a pointer to its argument value. - newInput = func(req *jrpc2.Request) ([]reflect.Value, error) { + newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) { in := reflect.New(fi.Argument.Elem()) if err := req.UnmarshalParams(in.Interface()); err != nil { return nil, jrpc2.Errorf(code.InvalidParams, "invalid parameters: %v", err) } - return []reflect.Value{in}, nil + return []reflect.Value{ctx, in}, nil } } else { // Case 3b: The function wants a bare argument value. - newInput = func(req *jrpc2.Request) ([]reflect.Value, error) { + newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) { in := reflect.New(fi.Argument) // we still need a pointer to unmarshal if err := req.UnmarshalParams(in.Interface()); err != nil { return nil, jrpc2.Errorf(code.InvalidParams, "invalid parameters: %v", err) } // Indirect the pointer back off for the callee. - return []reflect.Value{in.Elem()}, nil + return []reflect.Value{ctx, in.Elem()}, nil } } @@ -209,11 +209,10 @@ func (fi *FuncInfo) Wrap() Func { } return Func(func(ctx context.Context, req *jrpc2.Request) (interface{}, error) { - rest, ierr := newInput(req) + args, ierr := newInput(reflect.ValueOf(ctx), req) if ierr != nil { return nil, ierr } - args := append([]reflect.Value{reflect.ValueOf(ctx)}, rest...) return decodeOut(call(args)) }) }