Skip to content

Commit

Permalink
EVM-437 Batch calls over websockets not working properly (#1588)
Browse files Browse the repository at this point in the history
  • Loading branch information
igorcrevar authored Jun 7, 2023
1 parent 5318ff4 commit c1d3a79
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 52 deletions.
2 changes: 2 additions & 0 deletions jsonrpc/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ type Request struct {
Params json.RawMessage `json:"params,omitempty"`
}

type BatchRequest []Request

// Response is a jsonrpc response interface
type Response interface {
GetID() interface{}
Expand Down
118 changes: 80 additions & 38 deletions jsonrpc/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"math"
"reflect"
"strconv"
"strings"
"time"
"unicode"
Expand Down Expand Up @@ -60,6 +61,10 @@ type dispatcherParams struct {
blockRangeLimit uint64
}

func (dp dispatcherParams) isExceedingBatchLengthLimit(value uint64) bool {
return dp.jsonRPCBatchLengthLimit != 0 && value > dp.jsonRPCBatchLengthLimit
}

func newDispatcher(
logger hclog.Logger,
store JSONRPCStore,
Expand Down Expand Up @@ -163,22 +168,23 @@ type wsConn interface {

// as per https://www.jsonrpc.org/specification, the `id` in JSON-RPC 2.0
// can only be a string or a non-decimal integer
func formatFilterResponse(id interface{}, resp string) (string, Error) {
func formatID(id interface{}) (interface{}, Error) {
switch t := id.(type) {
case string:
return fmt.Sprintf(`{"jsonrpc":"2.0","id":"%s","result":"%s"}`, t, resp), nil
return t, nil
case float64:
if t == math.Trunc(t) {
return fmt.Sprintf(`{"jsonrpc":"2.0","id":%d,"result":"%s"}`, int(t), resp), nil
return int(t), nil
} else {
return "", NewInvalidRequestError("Invalid json request")
}
case nil:
return fmt.Sprintf(`{"jsonrpc":"2.0","id":null,"result":"%s"}`, resp), nil
return nil, nil
default:
return "", NewInvalidRequestError("Invalid json request")
}
}

func (d *Dispatcher) handleSubscribe(req Request, conn wsConn) (string, Error) {
var params []interface{}
if err := json.Unmarshal(req.Params, &params); err != nil {
Expand Down Expand Up @@ -233,54 +239,90 @@ func (d *Dispatcher) RemoveFilterByWs(conn wsConn) {
}

func (d *Dispatcher) HandleWs(reqBody []byte, conn wsConn) ([]byte, error) {
var req Request
if err := json.Unmarshal(reqBody, &req); err != nil {
return NewRPCResponse(req.ID, "2.0", nil, NewInvalidRequestError("Invalid json request")).Bytes()
}
const (
openSquareBracket byte = '['
closeSquareBracket byte = ']'
comma byte = ','
)

// if the request method is eth_subscribe we need to create a
// new filter with ws connection
if req.Method == "eth_subscribe" {
filterID, err := d.handleSubscribe(req, conn)
if err != nil {
return NewRPCResponse(req.ID, "2.0", nil, err).Bytes()
}
reqBody = bytes.TrimLeft(reqBody, " \t\r\n")

resp, err := formatFilterResponse(req.ID, filterID)
// if body begins with [ consider it as a batch request
if len(reqBody) > 0 && reqBody[0] == openSquareBracket {
var batchReq BatchRequest

err := json.Unmarshal(reqBody, &batchReq)
if err != nil {
return NewRPCResponse(req.ID, "2.0", nil, err).Bytes()
return NewRPCResponse(nil, "2.0", nil,
NewInvalidRequestError("Invalid json batch request")).Bytes()
}

return []byte(resp), nil
}

if req.Method == "eth_unsubscribe" {
ok, err := d.handleUnsubscribe(req)
if err != nil {
return nil, err
// if not disabled, avoid handling long batch requests
if d.params.isExceedingBatchLengthLimit(uint64(len(batchReq))) {
return NewRPCResponse(
nil,
"2.0",
nil,
NewInvalidRequestError("Batch request length too long"),
).Bytes()
}

res := "false"
if ok {
res = "true"
}
responses := make([][]byte, len(batchReq))

resp, err := formatFilterResponse(req.ID, res)
if err != nil {
return NewRPCResponse(req.ID, "2.0", nil, err).Bytes()
for i, req := range batchReq {
responses[i], err = d.handleSingleWs(req, conn).Bytes()
if err != nil {
return nil, err
}
}

return []byte(resp), nil
var buf bytes.Buffer

// batch output should look like:
// [ { "requestId": "1", "status": 200 }, { "requestId": "2", "status": 200 } ]
buf.WriteByte(openSquareBracket) // [
buf.Write(bytes.Join(responses, []byte{comma})) // join responses with the comma separator
buf.WriteByte(closeSquareBracket) // ]

return buf.Bytes(), nil
}

// its a normal query that we handle with the dispatcher
resp, err := d.handleReq(req)
var req Request
if err := json.Unmarshal(reqBody, &req); err != nil {
return NewRPCResponse(req.ID, "2.0", nil, NewInvalidRequestError("Invalid json request")).Bytes()
}

return d.handleSingleWs(req, conn).Bytes()
}

func (d *Dispatcher) handleSingleWs(req Request, conn wsConn) Response {
id, err := formatID(req.ID)
if err != nil {
return nil, err
return NewRPCResponse(nil, "2.0", nil, err)
}

var response []byte

switch req.Method {
case "eth_subscribe":
var filterID string

// if the request method is eth_subscribe we need to create a new filter with ws connection
if filterID, err = d.handleSubscribe(req, conn); err == nil {
response = []byte(fmt.Sprintf("\"%s\"", filterID))
}
case "eth_unsubscribe":
var ok bool

if ok, err = d.handleUnsubscribe(req); err == nil {
response = []byte(strconv.FormatBool(ok))
}
default:
// its a normal query that we handle with the dispatcher
response, err = d.handleReq(req)
}

return NewRPCResponse(req.ID, "2.0", resp, err).Bytes()
return NewRPCResponse(id, "2.0", response, err)
}

func (d *Dispatcher) Handle(reqBody []byte) ([]byte, error) {
Expand All @@ -305,7 +347,7 @@ func (d *Dispatcher) Handle(reqBody []byte) ([]byte, error) {
}

// handle batch requests
var requests []Request
var requests BatchRequest
if err := json.Unmarshal(reqBody, &requests); err != nil {
return NewRPCResponse(
nil,
Expand All @@ -316,7 +358,7 @@ func (d *Dispatcher) Handle(reqBody []byte) ([]byte, error) {
}

// if not disabled, avoid handling long batch requests
if d.params.jsonRPCBatchLengthLimit != 0 && len(requests) > int(d.params.jsonRPCBatchLengthLimit) {
if d.params.isExceedingBatchLengthLimit(uint64(len(requests))) {
return NewRPCResponse(
nil,
"2.0",
Expand Down
106 changes: 92 additions & 14 deletions jsonrpc/dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package jsonrpc

import (
"encoding/json"
"fmt"
"math/big"
"reflect"
"testing"
Expand Down Expand Up @@ -102,6 +103,8 @@ func TestDispatcher_HandleWebsocketConnection_EthSubscribe(t *testing.T) {
}

func TestDispatcher_WebsocketConnection_RequestFormats(t *testing.T) {
t.Parallel()

store := newMockStore()
dispatcher := newTestDispatcher(t,
hclog.NewNullLogger(),
Expand Down Expand Up @@ -212,6 +215,8 @@ func (m *mockService) Filter(f LogQuery) (interface{}, error) {
}

func TestDispatcherFuncDecode(t *testing.T) {
t.Parallel()

srv := &mockService{msgCh: make(chan interface{}, 10)}

dispatcher := newTestDispatcher(t,
Expand Down Expand Up @@ -290,20 +295,29 @@ func TestDispatcherFuncDecode(t *testing.T) {
}

func TestDispatcherBatchRequest(t *testing.T) {
handle := func(dispatcher *Dispatcher, reqBody []byte) []byte {
res, _ := dispatcher.Handle(reqBody)

return res
}
t.Parallel()

cases := []struct {
type caseData struct {
name string
desc string
dispatcher *Dispatcher
reqBody []byte
err *ObjectError
batchResponse []*SuccessResponse
}{
}

mock := &mockWsConn{
SetFilterIDFn: func(s string) {
},
GetFilterIDFn: func() string {
return ""
},
WriteMessageFn: func(i int, b []byte) error {
return nil
},
}

cases := []caseData{
{
"leading-whitespace",
"test with leading whitespace (\" \\t\\n\\n\\r\\)",
Expand Down Expand Up @@ -425,36 +439,100 @@ func TestDispatcherBatchRequest(t *testing.T) {
},
}

for _, c := range cases {
res := handle(c.dispatcher, c.reqBody)

check := func(c caseData, res []byte) {
if c.err != nil {
var resp ErrorResponse

assert.NoError(t, expectBatchJSONResult(res, &resp))
assert.Equal(t, resp.Error, c.err)
assert.Equal(t, c.err, resp.Error)
} else {
var batchResp []SuccessResponse
assert.NoError(t, expectBatchJSONResult(res, &batchResp))

if c.name == "leading-whitespace" {
assert.Len(t, batchResp, 4)
for index, resp := range batchResp {
assert.Equal(t, resp.Error, c.batchResponse[index].Error)
assert.Equal(t, c.batchResponse[index].Error, resp.Error)
}
} else if c.name == "valid-batch-req" {
assert.Len(t, batchResp, 6)
for index, resp := range batchResp {
assert.Equal(t, resp.Error, c.batchResponse[index].Error)
assert.Equal(t, c.batchResponse[index].Error, resp.Error)
}
} else if c.name == "no-limits" {
assert.Len(t, batchResp, 12)
for index, resp := range batchResp {
assert.Equal(t, resp.Error, c.batchResponse[index].Error)
assert.Equal(t, c.batchResponse[index].Error, resp.Error)
}
}
}
}

for _, c := range cases {
c := c

t.Run(c.name, func(t *testing.T) {
t.Parallel()

res, _ := c.dispatcher.HandleWs(c.reqBody, mock)

check(c, res)

res, _ = c.dispatcher.Handle(c.reqBody)

check(c, res)
})
}
}

func TestDispatcher_WebsocketConnection_Unsubscribe(t *testing.T) {
t.Parallel()

store := newMockStore()
dispatcher := newTestDispatcher(t,
hclog.NewNullLogger(),
store,
&dispatcherParams{
chainID: 0,
priceLimit: 0,
jsonRPCBatchLengthLimit: 20,
blockRangeLimit: 1000,
},
)
mockConn := &mockWsConn{
SetFilterIDFn: func(s string) {
},
GetFilterIDFn: func() string {
return ""
},
WriteMessageFn: func(i int, b []byte) error {
return nil
},
}

resp := SuccessResponse{}
reqUnsub := func(n string) []byte {
return []byte(fmt.Sprintf(`{"method": "eth_unsubscribe", "params": [%s]}`, n))
}

// non existing subscription
r, err := dispatcher.HandleWs(reqUnsub("\"787832\""), mockConn)
require.NoError(t, err)

require.NoError(t, json.Unmarshal(r, &resp))
assert.Equal(t, "false", string(resp.Result))

r, err = dispatcher.HandleWs([]byte(`{"method": "eth_subscribe", "params": ["newHeads"]}`), mockConn)
require.NoError(t, err)

require.NoError(t, json.Unmarshal(r, &resp))

// existing subscription
r, err = dispatcher.HandleWs(reqUnsub(string(resp.Result)), mockConn)
require.NoError(t, err)

require.NoError(t, json.Unmarshal(r, &resp))
assert.Equal(t, "true", string(resp.Result))
}

func newTestDispatcher(t *testing.T, logger hclog.Logger, store JSONRPCStore, params *dispatcherParams) *Dispatcher {
Expand Down

0 comments on commit c1d3a79

Please sign in to comment.