Skip to content

Commit

Permalink
Rework the jrpc2.ParseRequests function. (#81)
Browse files Browse the repository at this point in the history
Add a new ParsedRequest type and return that instead of Request.
This is a breaking change to the jrpc2.ParseRequests function.

The Request is meant for consumption by Handler methods where
validation has already been handled. The new type exposes validation errors,
allowing the caller to detect specific problems with each request in a batch.

This change also allows us to simplify checking of version errors, and
to remove most of the special case version checks in the client and server.

- Update the jhttp.Bridge to use the new ParsedRequest type.
- Update test cases where affected. Add a testutil package for internal tests.
- Fix handling of invalid requests in jhttp.Bridge (fixes #80).
  • Loading branch information
creachadair authored Feb 20, 2022
1 parent 90e8e7f commit 2377525
Show file tree
Hide file tree
Showing 13 changed files with 276 additions and 147 deletions.
25 changes: 9 additions & 16 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strconv"
"sync"
Expand Down Expand Up @@ -150,22 +149,18 @@ func (c *Client) deliver(rsp *jmessage) {
}

id := string(fixID(rsp.ID))
if p := c.pending[id]; p == nil {
p := c.pending[id]
if p == nil {
c.log("Discarding response for unknown ID %q", id)
} else if !c.versionOK(rsp.V) {
delete(c.pending, id)
p.ch <- &jmessage{
ID: rsp.ID,
E: &Error{
Code: code.InvalidRequest,
Message: fmt.Sprintf("incorrect version marker %q", rsp.V),
},
}
return
}
// Remove the pending request from the set and deliver its response.
// Determining whether it's an error is the caller's responsibility.
delete(c.pending, id)
if rsp.err != nil {
p.ch <- &jmessage{ID: rsp.ID, E: rsp.err}
c.log("Invalid response for ID %q", id)
} else {
// Remove the pending request from the set and deliver its response.
// Determining whether it's an error is the caller's responsibility.
delete(c.pending, id)
p.ch <- rsp
c.log("Completed request for ID %q", id)
}
Expand Down Expand Up @@ -422,8 +417,6 @@ func (c *Client) stop(err error) {
c.ch = nil
}

func (c *Client) versionOK(v string) bool { return v == Version }

// marshalParams validates and marshals params to JSON for a request. The
// value of params must be either nil or encodable as a JSON object or array.
func (c *Client) marshalParams(ctx context.Context, method string, params interface{}) (json.RawMessage, error) {
Expand Down
3 changes: 3 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ var errEmptyBatch = &Error{Code: code.InvalidRequest, Message: "empty request ba
// errInvalidParams is the error reported for invalid request parameters.
var errInvalidParams = &Error{Code: code.InvalidParams, Message: code.InvalidParams.String()}

// errTaskNotExecuted is the internal sentinel error for an unassigned task.
var errTaskNotExecuted = new(Error)

// ErrConnClosed is returned by a server's push-to-client methods if they are
// called after the client connection is closed.
var ErrConnClosed = errors.New("client connection is closed")
Expand Down
14 changes: 7 additions & 7 deletions examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/creachadair/jrpc2"
"github.com/creachadair/jrpc2/code"
"github.com/creachadair/jrpc2/handler"
"github.com/creachadair/jrpc2/internal/testutil"
"github.com/creachadair/jrpc2/server"
)

Expand Down Expand Up @@ -98,10 +99,9 @@ func ExampleClient_Batch() {

func ExampleRequest_UnmarshalParams() {
const msg = `{"jsonrpc":"2.0", "id":101, "method":"M", "params":{"a":1, "b":2, "c":3}}`

reqs, err := jrpc2.ParseRequests([]byte(msg))
req, err := testutil.ParseRequest(msg)
if err != nil {
log.Fatalf("ParseRequests: %v", err)
log.Fatalf("Parsing %#q failed: %v", msg, err)
}

var t, u struct {
Expand All @@ -110,29 +110,29 @@ func ExampleRequest_UnmarshalParams() {
}

// By default, unmarshaling ignores unknown fields (here, "c").
if err := reqs[0].UnmarshalParams(&t); err != nil {
if err := req.UnmarshalParams(&t); err != nil {
log.Fatalf("UnmarshalParams: %v", err)
}
fmt.Printf("t.A=%d, t.B=%d\n", t.A, t.B)

// To implement strict field checking, there are several options:
//
// Solution 1: Use the jrpc2.StrictFields helper.
err = reqs[0].UnmarshalParams(jrpc2.StrictFields(&t))
err = req.UnmarshalParams(jrpc2.StrictFields(&t))
if code.FromError(err) != code.InvalidParams {
log.Fatalf("UnmarshalParams strict: %v", err)
}

// Solution 2: Implement a DisallowUnknownFields method.
var p strictParams
err = reqs[0].UnmarshalParams(&p)
err = req.UnmarshalParams(&p)
if code.FromError(err) != code.InvalidParams {
log.Fatalf("UnmarshalParams strict: %v", err)
}

// Solution 3: Decode the raw message separately.
var tmp json.RawMessage
reqs[0].UnmarshalParams(&tmp) // cannot fail
req.UnmarshalParams(&tmp) // cannot fail
dec := json.NewDecoder(bytes.NewReader(tmp))
dec.DisallowUnknownFields()
if err := dec.Decode(&u); err == nil {
Expand Down
9 changes: 4 additions & 5 deletions handler/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/creachadair/jrpc2"
"github.com/creachadair/jrpc2/handler"
"github.com/creachadair/jrpc2/internal/testutil"
)

func ExampleCheck() {
Expand Down Expand Up @@ -111,11 +112,9 @@ func ExamplePositional_array() {
}

func mustParseReq(s string) *jrpc2.Request {
reqs, err := jrpc2.ParseRequests([]byte(s))
req, err := testutil.ParseRequest(s)
if err != nil {
log.Fatalf("ParseRequests: %v", err)
} else if len(reqs) == 0 {
log.Fatal("ParseRequests: empty result")
log.Fatalf("ParseRequest: %v", err)
}
return reqs[0]
return req
}
20 changes: 5 additions & 15 deletions handler/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/creachadair/jrpc2"
"github.com/creachadair/jrpc2/code"
"github.com/creachadair/jrpc2/handler"
"github.com/creachadair/jrpc2/internal/testutil"
"github.com/google/go-cmp/cmp"
)

Expand Down Expand Up @@ -105,7 +106,7 @@ func TestFuncInfo_wrapDecode(t *testing.T) {
}
ctx := context.Background()
for _, test := range tests {
req := mustParseRequest(t,
req := testutil.MustParseRequest(t,
fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"x","params":%s}`, test.p))
got, err := test.fn(ctx, req)
if err != nil {
Expand Down Expand Up @@ -165,7 +166,7 @@ func TestNewStrict(t *testing.T) {
}
fn := handler.NewStrict(func(ctx context.Context, arg *arg) error { return nil })

req := mustParseRequest(t, `{
req := testutil.MustParseRequest(t, `{
"jsonrpc": "2.0",
"id": 100,
"method": "f",
Expand Down Expand Up @@ -199,7 +200,7 @@ func TestNew_pointerRegression(t *testing.T) {
t.Logf("Got argument struct: %+v", got)
return nil
})
req := mustParseRequest(t, `{
req := testutil.MustParseRequest(t, `{
"jsonrpc": "2.0",
"id": "foo",
"method": "bar",
Expand Down Expand Up @@ -245,7 +246,7 @@ func TestPositional_decode(t *testing.T) {
{`{"jsonrpc":"2.0","id":15,"method":"add","params":[1,2,3]}`, 0, true}, // too many
}
for _, test := range tests {
req := mustParseRequest(t, test.input)
req := testutil.MustParseRequest(t, test.input)
got, err := call(context.Background(), req)
if !test.bad {
if err != nil {
Expand Down Expand Up @@ -438,17 +439,6 @@ func TestObjUnmarshal(t *testing.T) {
}
}

func mustParseRequest(t *testing.T, text string) *jrpc2.Request {
t.Helper()
req, err := jrpc2.ParseRequests([]byte(text))
if err != nil {
t.Fatalf("ParseRequests: %v", err)
} else if len(req) != 1 {
t.Fatalf("Wrong number of requests: got %d, want 1", len(req))
}
return req[0]
}

// stringByte is a byte with a custom JSON encoding. It expects a string of
// decimal digits 1 and 0, e.g., "10011000" == 0x98.
type stringByte byte
Expand Down
68 changes: 68 additions & 0 deletions internal/testutil/testutil.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (C) 2022 Michael J. Fromberger. All Rights Reserved.

// Package testutil defines internal support code for writing tests.
package testutil

import (
"context"
"testing"

"github.com/creachadair/jrpc2"
"github.com/creachadair/jrpc2/channel"
)

// ParseRequest parses a single JSON request object.
func ParseRequest(s string) (_ *jrpc2.Request, err error) {
// Check syntax.
if _, err := jrpc2.ParseRequests([]byte(s)); err != nil {
return nil, err
}

cch, sch := channel.Direct()
rs := newRequestStub()
srv := jrpc2.NewServer(rs, nil).Start(sch)
defer func() {
cch.Close()
serr := srv.Wait()
if err == nil {
err = serr
}
}()
if err := cch.Send([]byte(s)); err != nil {
return nil, err
}
req := <-rs.reqc
if !rs.isNote {
cch.Recv()
}
return req, nil
}

// MustParseRequest calls ParseRequest and fails t if it reports an error.
func MustParseRequest(t *testing.T, s string) *jrpc2.Request {
t.Helper()

req, err := ParseRequest(s)
if err != nil {
t.Fatalf("Parsing %#q failed: %v", s, err)
}
return req
}

func newRequestStub() *requestStub {
return &requestStub{reqc: make(chan *jrpc2.Request, 1)}
}

type requestStub struct {
reqc chan *jrpc2.Request
isNote bool
}

func (r *requestStub) Assign(context.Context, string) jrpc2.Handler { return r }

func (r *requestStub) Handle(_ context.Context, req *jrpc2.Request) (interface{}, error) {
defer close(r.reqc)
r.isNote = req.IsNotification()
r.reqc <- req
return nil, nil
}
28 changes: 28 additions & 0 deletions internal/testutil/testutil_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (C) 2022 Michael J. Fromberger. All Rights Reserved.

package testutil_test

import (
"testing"

"github.com/creachadair/jrpc2/internal/testutil"
)

func TestParseRequest(t *testing.T) {
t.Run("Invalid", func(t *testing.T) {
req, err := testutil.ParseRequest(`{this is invalid}`)
if err == nil {
t.Errorf("ParseRequest: got %+v, wanted error", req)
} else {
t.Logf("Invalid OK: %v", err)
}
})
t.Run("Call", func(t *testing.T) {
req := testutil.MustParseRequest(t, `{"jsonrpc":"2.0","id":1,"method":"OK"}`)
t.Logf("Call OK: %+v", req)
})
t.Run("Notification", func(t *testing.T) {
req := testutil.MustParseRequest(t, `{"jsonrpc":"2.0","id":null,"method":"OK"}`)
t.Logf("Note OK: %+v", req)
})
}
48 changes: 28 additions & 20 deletions internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,40 +19,47 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
)

var errInvalidVersion = &Error{Code: code.InvalidRequest, Message: "invalid version marker"}

func TestParseRequests(t *testing.T) {
tests := []struct {
input string
want []*Request
want []*ParsedRequest
err error
}{
// An empty batch is valid and produces no results.
{`[]`, nil, nil},

// An empty single request is invalid but returned anyway.
{`{}`, []*Request{{}}, ErrInvalidVersion},
{`{}`, []*ParsedRequest{{Error: errInvalidVersion}}, nil},

// Structurally invalid JSON reports an error.
{`[`, nil, errInvalidRequest},
{`{}}`, nil, errInvalidRequest},
{`[{}, ttt]`, nil, errInvalidRequest},

// A valid notification.
{`{"jsonrpc":"2.0", "method": "foo", "params":[1, 2, 3]}`, []*Request{{
method: "foo",
params: json.RawMessage(`[1, 2, 3]`),
{`{"jsonrpc":"2.0", "method": "foo", "params":[1, 2, 3]}`, []*ParsedRequest{{
Method: "foo",
Params: json.RawMessage(`[1, 2, 3]`),
}}, nil},

// A valid request, with nil parameters.
{`{"jsonrpc":"2.0", "method": "foo", "id":10332, "params":null}`, []*Request{{
id: json.RawMessage("10332"), method: "foo",
{`{"jsonrpc":"2.0", "method": "foo", "id":10332, "params":null}`, []*ParsedRequest{{
ID: "10332", Method: "foo",
}}, nil},

// A valid mixed batch.
{`[ {"jsonrpc": "2.0", "id": 1, "method": "A", "params": {}},
{"jsonrpc": "2.0", "params": [5], "method": "B"} ]`, []*Request{
{method: "A", id: json.RawMessage(`1`), params: json.RawMessage(`{}`)},
{method: "B", params: json.RawMessage(`[5]`)},
{"jsonrpc": "2.0", "params": [5], "method": "B"} ]`, []*ParsedRequest{
{Method: "A", ID: "1", Params: json.RawMessage(`{}`)},
{Method: "B", Params: json.RawMessage(`[5]`)},
}, nil},

// An invalid batch.
{`[{"id": 37, "method": "complain", "params":[]}]`, []*Request{
{method: "complain", id: json.RawMessage(`37`), params: json.RawMessage(`[]`)},
}, ErrInvalidVersion},
{`[{"id": 37, "method": "complain", "params":[]}]`, []*ParsedRequest{
{Method: "complain", ID: "37", Params: json.RawMessage(`[]`), Error: errInvalidVersion},
}, nil},

// A broken request.
{`{`, nil, Errorf(code.ParseError, "invalid request value")},
Expand All @@ -67,7 +74,7 @@ func TestParseRequests(t *testing.T) {
continue
}

diff := cmp.Diff(test.want, got, cmp.AllowUnexported(Request{}), cmpopts.EquateEmpty())
diff := cmp.Diff(test.want, got, cmpopts.EquateEmpty())
if diff != "" {
t.Errorf("ParseRequests(%#q): wrong result (-want, +got):\n%s", test.input, diff)
}
Expand Down Expand Up @@ -113,17 +120,18 @@ func TestRequest_UnmarshalParams(t *testing.T) {
xy{X: 23}, `{"x":23, "z":"wat"}`, code.NoError},
}
for _, test := range tests {
req, err := ParseRequests([]byte(test.input))
if err != nil {
var reqs jmessages
if err := reqs.parseJSON([]byte(test.input)); err != nil {
t.Errorf("Parsing request %#q failed: %v", test.input, err)
} else if len(req) != 1 {
t.Fatalf("Wrong number of requests: got %d, want 1", len(req))
} else if len(reqs) != 1 {
t.Fatalf("Wrong number of requests: got %d, want 1", len(reqs))
}
req := &Request{id: reqs[0].ID, method: reqs[0].M, params: reqs[0].P}

// Allocate a zero of the expected type to unmarshal into.
target := reflect.New(reflect.TypeOf(test.want)).Interface()
{
err := req[0].UnmarshalParams(target)
err := req.UnmarshalParams(target)
if got := code.FromError(err); got != test.code {
t.Errorf("UnmarshalParams error: got code %d, want %d [%v]", got, test.code, err)
}
Expand All @@ -139,7 +147,7 @@ func TestRequest_UnmarshalParams(t *testing.T) {
}

// Check that the parameter string matches.
if got := req[0].ParamString(); got != test.pstring {
if got := req.ParamString(); got != test.pstring {
t.Errorf("ParamString(%#q): got %q, want %q", test.input, got, test.pstring)
}
}
Expand Down
Loading

0 comments on commit 2377525

Please sign in to comment.