Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support error codes / typed errors #36

Merged
merged 2 commits into from
Jul 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ func NewClient(ctx context.Context, addr string, namespace string, handler inter
type client struct {
namespace string
paramEncoders map[reflect.Type]ParamEncoder
errors *Errors

doRequest func(context.Context, clientRequest) (clientResponse, error)
exiting <-chan struct{}
Expand Down Expand Up @@ -130,6 +131,7 @@ func httpClient(ctx context.Context, addr string, namespace string, outs []inter
c := client{
namespace: namespace,
paramEncoders: config.paramEncoders,
errors: config.errors,
}

stop := make(chan struct{})
Expand Down Expand Up @@ -212,6 +214,7 @@ func websocketClient(ctx context.Context, addr string, namespace string, outs []
c := client{
namespace: namespace,
paramEncoders: config.paramEncoders,
errors: config.errors,
}

requests := make(chan clientRequest)
Expand Down Expand Up @@ -442,7 +445,8 @@ func (fn *rpcFunc) processResponse(resp clientResponse, rval reflect.Value) []re
if fn.errOut != -1 {
out[fn.errOut] = reflect.New(errorType).Elem()
if resp.Error != nil {
out[fn.errOut].Set(reflect.ValueOf(resp.Error))

out[fn.errOut].Set(resp.Error.val(fn.client.errors))
}
}

Expand Down Expand Up @@ -548,7 +552,7 @@ func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value)

retVal = func() reflect.Value { return val.Elem() }
}
retry := resp.Error != nil && resp.Error.Code == 2 && fn.retry
retry := resp.Error != nil && resp.Error.Code == eTempWSError && fn.retry
if !retry {
break
}
Expand Down
37 changes: 37 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package jsonrpc

import (
"encoding/json"
"reflect"
)

type Errors struct {
byType map[reflect.Type]ErrorCode
byCode map[ErrorCode]reflect.Type
}

type ErrorCode int

const FirstUserCode = 2

func NewErrors() Errors {
return Errors{
byType: map[reflect.Type]ErrorCode{},
byCode: map[ErrorCode]reflect.Type{},
}
}

func (e *Errors) Register(c ErrorCode, typ interface{}) {
rt := reflect.TypeOf(typ).Elem()
if !rt.Implements(errorType) {
panic("can't register non-error types")
}

e.byType[rt] = c
e.byCode[c] = rt
}

type marshalable interface {
json.Marshaler
json.Unmarshaler
}
62 changes: 55 additions & 7 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ type request struct {
const DEFAULT_MAX_REQUEST_SIZE = 100 << 20 // 100 MiB

type respError struct {
Code int `json:"code"`
Message string `json:"message"`
Code ErrorCode `json:"code"`
Message string `json:"message"`
Meta json.RawMessage `json:"meta,omitempty"`
}

func (e *respError) Error() string {
Expand All @@ -61,6 +62,31 @@ func (e *respError) Error() string {
return e.Message
}

var marshalableRT = reflect.TypeOf(new(marshalable)).Elem()

func (e *respError) val(errors *Errors) reflect.Value {
if errors != nil {
t, ok := errors.byCode[e.Code]
if ok {
var v reflect.Value
if t.Kind() == reflect.Ptr {
v = reflect.New(t.Elem())
} else {
v = reflect.New(t)
}
if len(e.Meta) > 0 && v.Type().Implements(marshalableRT) {
_ = v.Interface().(marshalable).UnmarshalJSON(e.Meta)
}
if t.Kind() != reflect.Ptr {
v = v.Elem()
}
return v
}
}

return reflect.ValueOf(e)
}

type response struct {
Jsonrpc string `json:"jsonrpc"`
Result interface{} `json:"result,omitempty"`
Expand Down Expand Up @@ -108,7 +134,7 @@ func (s *RPCServer) register(namespace string, r interface{}) {

// Handle

type rpcErrFunc func(w func(func(io.Writer)), req *request, code int, err error)
type rpcErrFunc func(w func(func(io.Writer)), req *request, code ErrorCode, err error)
type chanOut func(reflect.Value, int64) error

func (s *RPCServer) handleReader(ctx context.Context, r io.Reader, w io.Writer, rpcError rpcErrFunc) {
Expand Down Expand Up @@ -186,6 +212,30 @@ func (s *RPCServer) getSpan(ctx context.Context, req request) (context.Context,
return ctx, nil
}

func (s *RPCServer) createError(err error) *respError {
var code ErrorCode = 1
if s.errors != nil {
c, ok := s.errors.byType[reflect.TypeOf(err)]
if ok {
code = c
}
}

out := &respError{
Code: code,
Message: err.(error).Error(),
}

if m, ok := err.(marshalable); ok {
meta, err := m.MarshalJSON()
if err == nil {
out.Meta = meta
}
}

return out
}

func (s *RPCServer) handle(ctx context.Context, req request, w func(func(io.Writer)), rpcError rpcErrFunc, done func(keepCtx bool), chOut chanOut) {
// Not sure if we need to sanitize the incoming req.Method or not.
ctx, span := s.getSpan(ctx, req)
Expand Down Expand Up @@ -278,10 +328,8 @@ func (s *RPCServer) handle(ctx context.Context, req request, w func(func(io.Writ
if err != nil {
log.Warnf("error in RPC call to '%s': %+v", req.Method, err)
stats.Record(ctx, metrics.RPCResponseError.M(1))
resp.Error = &respError{
Code: 1,
Message: err.(error).Error(),
}

resp.Error = s.createError(err.(error))
}
}

Expand Down
7 changes: 7 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type Config struct {
timeout time.Duration

paramEncoders map[reflect.Type]ParamEncoder
errors *Errors

noReconnect bool
proxyConnFactory func(func() (*websocket.Conn, error)) func() (*websocket.Conn, error) // for testing
Expand Down Expand Up @@ -68,3 +69,9 @@ func WithParamEncoder(t interface{}, encoder ParamEncoder) func(c *Config) {
c.paramEncoders[reflect.TypeOf(t).Elem()] = encoder
}
}

func WithErrors(es Errors) func(c *Config) {
return func(c *Config) {
c.errors = &es
}
}
7 changes: 7 additions & 0 deletions options_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type ParamDecoder func(ctx context.Context, json []byte) (reflect.Value, error)
type ServerConfig struct {
paramDecoders map[reflect.Type]ParamDecoder
maxRequestSize int64
errors *Errors
}

type ServerOption func(c *ServerConfig)
Expand All @@ -32,3 +33,9 @@ func WithMaxRequestSize(max int64) ServerOption {
c.maxRequestSize = max
}
}

func WithServerErrors(es Errors) ServerOption {
return func(c *ServerConfig) {
c.errors = &es
}
}
88 changes: 88 additions & 0 deletions rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
logging "github.com/ipfs/go-log/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
)

func init() {
Expand Down Expand Up @@ -991,3 +992,90 @@ func readerDec(ctx context.Context, rin []byte) (reflect.Value, error) {

return reflect.ValueOf(readerRegistery[id]), nil
}

type ErrSomethingBad struct{}

func (e ErrSomethingBad) Error() string {
return "something bad has happened"
}

type ErrMyErr struct{ str string }

var _ error = ErrSomethingBad{}

func (e *ErrMyErr) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, &e.str)
}

func (e *ErrMyErr) MarshalJSON() ([]byte, error) {
return json.Marshal(e.str)
}

func (e *ErrMyErr) Error() string {
return fmt.Sprintf("this happened: %s", e.str)
}

type ErrHandler struct{}

func (h *ErrHandler) Test() error {
return ErrSomethingBad{}
}

func (h *ErrHandler) TestP() error {
return &ErrSomethingBad{}
}

func (h *ErrHandler) TestMy(s string) error {
return &ErrMyErr{
str: s,
}
}

func TestUserError(t *testing.T) {
// setup server

serverHandler := &ErrHandler{}

const (
EBad = iota + FirstUserCode
EBad2
EMy
)

errs := NewErrors()
errs.Register(EBad, new(ErrSomethingBad))
errs.Register(EBad2, new(*ErrSomethingBad))
errs.Register(EMy, new(*ErrMyErr))

rpcServer := NewServer(WithServerErrors(errs))
rpcServer.Register("ErrHandler", serverHandler)

// httptest stuff
testServ := httptest.NewServer(rpcServer)
defer testServ.Close()

// setup client

var client struct {
Test func() error
TestP func() error
TestMy func(s string) error
}
closer, err := NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "ErrHandler", []interface{}{
&client,
}, nil, WithErrors(errs))
require.NoError(t, err)

e := client.Test()
require.True(t, xerrors.Is(e, ErrSomethingBad{}))

e = client.TestP()
require.True(t, xerrors.Is(e, &ErrSomethingBad{}))

e = client.TestMy("some event")
require.Error(t, e)
require.Equal(t, "this happened: some event", e.Error())
require.Equal(t, "this happened: some event", e.(*ErrMyErr).Error())

closer()
}
4 changes: 3 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const (
// RPCServer provides a jsonrpc 2.0 http server handler
type RPCServer struct {
methods map[string]rpcHandler
errors *Errors

// aliasedMethods contains a map of alias:original method names.
// These are used as fallbacks if a method is not found by the given method name.
Expand All @@ -42,6 +43,7 @@ func NewServer(opts ...ServerOption) *RPCServer {
aliasedMethods: map[string]string{},
paramDecoders: config.paramDecoders,
maxRequestSize: config.maxRequestSize,
errors: config.errors,
}
}

Expand Down Expand Up @@ -91,7 +93,7 @@ func (s *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.handleReader(ctx, r.Body, w, rpcError)
}

func rpcError(wf func(func(io.Writer)), req *request, code int, err error) {
func rpcError(wf func(func(io.Writer)), req *request, code ErrorCode, err error) {
log.Errorf("RPC Error: %s", err)
wf(func(w io.Writer) {
if hw, ok := w.(http.ResponseWriter); ok {
Expand Down
6 changes: 4 additions & 2 deletions websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ const wsCancel = "xrpc.cancel"
const chValue = "xrpc.ch.val"
const chClose = "xrpc.ch.close"

const eTempWSError = -1111111

type frame struct {
// common
Jsonrpc string `json:"jsonrpc"`
Expand Down Expand Up @@ -451,7 +453,7 @@ func (c *wsConn) closeInFlight() {
ID: id,
Error: &respError{
Message: "handler: websocket connection closed",
Code: 2,
Code: eTempWSError,
},
}
}
Expand Down Expand Up @@ -635,7 +637,7 @@ func (c *wsConn) handleWsConn(ctx context.Context) {
ID: *req.req.ID,
Error: &respError{
Message: "handler: websocket connection closed",
Code: 2,
Code: eTempWSError,
},
}
c.writeLk.Unlock()
Expand Down