From a9fac2091011e4c15c2f32835d7a3173ef7b5ce9 Mon Sep 17 00:00:00 2001 From: Viraj Bhartiya Date: Thu, 24 Oct 2024 17:24:49 +0530 Subject: [PATCH] feat: add optional data field to error returns (#123) Co-authored-by: Aryan Tikarya --- client.go | 2 +- errors.go | 5 + handler.go | 88 +++---------- resp_error_test.go | 306 +++++++++++++++++++++++++++++++++++++++++++++ response.go | 89 +++++++++++++ server.go | 4 +- websocket.go | 8 +- 7 files changed, 424 insertions(+), 78 deletions(-) create mode 100644 resp_error_test.go create mode 100644 response.go diff --git a/client.go b/client.go index ba355ee..bc3dac6 100644 --- a/client.go +++ b/client.go @@ -71,7 +71,7 @@ type clientResponse struct { Jsonrpc string `json:"jsonrpc"` Result json.RawMessage `json:"result"` ID interface{} `json:"id"` - Error *respError `json:"error,omitempty"` + Error *JSONRPCError `json:"error,omitempty"` } type makeChanSink func() (context.Context, func([]byte, bool)) diff --git a/errors.go b/errors.go index 35212b2..cf054da 100644 --- a/errors.go +++ b/errors.go @@ -58,3 +58,8 @@ type marshalable interface { json.Marshaler json.Unmarshaler } + +type RPCErrorCodec interface { + FromJSONRPCError(JSONRPCError) error + ToJSONRPCError() (JSONRPCError, error) +} diff --git a/handler.go b/handler.go index 0ea9056..a154666 100644 --- a/handler.go +++ b/handler.go @@ -65,71 +65,6 @@ type request struct { // Configured by WithMaxRequestSize. const DEFAULT_MAX_REQUEST_SIZE = 100 << 20 // 100 MiB -type respError struct { - Code ErrorCode `json:"code"` - Message string `json:"message"` - Meta json.RawMessage `json:"meta,omitempty"` -} - -func (e *respError) Error() string { - if e.Code >= -32768 && e.Code <= -32000 { - return fmt.Sprintf("RPC error (%d): %s", e.Code, e.Message) - } - return e.Message -} - -var marshalableRT = reflect.TypeOf(new(marshalable)).Elem() - -func (e *respError) val(errors *Errors) reflect.Value { - if errors != nil { - t, ok := errors.byCode[e.Code] - if ok { - var v reflect.Value - if t.Kind() == reflect.Ptr { - v = reflect.New(t.Elem()) - } else { - v = reflect.New(t) - } - if len(e.Meta) > 0 && v.Type().Implements(marshalableRT) { - _ = v.Interface().(marshalable).UnmarshalJSON(e.Meta) - } - if t.Kind() != reflect.Ptr { - v = v.Elem() - } - return v - } - } - - return reflect.ValueOf(e) -} - -type response struct { - Jsonrpc string - Result interface{} - ID interface{} - Error *respError -} - -func (r response) MarshalJSON() ([]byte, error) { - // Custom marshal logic as per JSON-RPC 2.0 spec: - // > `result`: - // > This member is REQUIRED on success. - // > This member MUST NOT exist if there was an error invoking the method. - // - // > `error`: - // > This member is REQUIRED on error. - // > This member MUST NOT exist if there was no error triggered during invocation. - data := make(map[string]interface{}) - data["jsonrpc"] = r.Jsonrpc - data["id"] = r.ID - if r.Error != nil { - data["error"] = r.Error - } else { - data["result"] = r.Result - } - return json.Marshal(data) -} - type handler struct { methods map[string]methodHandler errors *Errors @@ -334,7 +269,7 @@ func (s *handler) getSpan(ctx context.Context, req request) (context.Context, *t return ctx, span } -func (s *handler) createError(err error) *respError { +func (s *handler) createError(err error) *JSONRPCError { var code ErrorCode = 1 if s.errors != nil { c, ok := s.errors.byType[reflect.TypeOf(err)] @@ -343,15 +278,25 @@ func (s *handler) createError(err error) *respError { } } - out := &respError{ + out := &JSONRPCError{ Code: code, Message: err.Error(), } - if m, ok := err.(marshalable); ok { - meta, err := m.MarshalJSON() - if err == nil { + switch m := err.(type) { + case RPCErrorCodec: + o, err := m.ToJSONRPCError() + if err != nil { + log.Errorf("Failed to convert error to JSONRPCError: %w", err) + } else { + out = &o + } + case marshalable: + meta, marshalErr := m.MarshalJSON() + if marshalErr == nil { out.Meta = meta + } else { + log.Errorf("Failed to marshal error metadata: %w", marshalErr) } } @@ -504,7 +449,8 @@ func (s *handler) handle(ctx context.Context, req request, w func(func(io.Writer log.Warnf("failed to setup channel in RPC call to '%s': %+v", req.Method, err) stats.Record(ctx, metrics.RPCResponseError.M(1)) - resp.Error = &respError{ + + resp.Error = &JSONRPCError{ Code: 1, Message: err.Error(), } diff --git a/resp_error_test.go b/resp_error_test.go new file mode 100644 index 0000000..e5b2bec --- /dev/null +++ b/resp_error_test.go @@ -0,0 +1,306 @@ +package jsonrpc + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +type ComplexData struct { + Foo string `json:"foo"` + Bar int `json:"bar"` +} + +type StaticError struct{} + +func (e *StaticError) Error() string { return "static error" } + +// Define the error types +type SimpleError struct { + Message string +} + +func (e *SimpleError) Error() string { + return e.Message +} + +func (e *SimpleError) FromJSONRPCError(jerr JSONRPCError) error { + e.Message = jerr.Message + return nil +} + +func (e *SimpleError) ToJSONRPCError() (JSONRPCError, error) { + return JSONRPCError{Message: e.Message}, nil +} + +var _ RPCErrorCodec = (*SimpleError)(nil) + +type DataStringError struct { + Message string `json:"message"` + Data string `json:"data"` +} + +func (e *DataStringError) Error() string { + return e.Message +} + +func (e *DataStringError) FromJSONRPCError(jerr JSONRPCError) error { + e.Message = jerr.Message + data, ok := jerr.Data.(string) + if !ok { + return fmt.Errorf("expected string data, got %T", jerr.Data) + } + + e.Data = data + + return nil +} + +func (e *DataStringError) ToJSONRPCError() (JSONRPCError, error) { + return JSONRPCError{Message: e.Message, Data: e.Data}, nil +} + +var _ RPCErrorCodec = (*DataStringError)(nil) + +type DataComplexError struct { + Message string + internalData ComplexData +} + +func (e *DataComplexError) Error() string { + return e.Message +} + +func (e *DataComplexError) FromJSONRPCError(jerr JSONRPCError) error { + e.Message = jerr.Message + data, ok := jerr.Data.(json.RawMessage) + if !ok { + return fmt.Errorf("expected string data, got %T", jerr.Data) + } + + if err := json.Unmarshal(data, &e.internalData); err != nil { + return err + } + return nil +} + +func (e *DataComplexError) ToJSONRPCError() (JSONRPCError, error) { + data, err := json.Marshal(e.internalData) + if err != nil { + return JSONRPCError{}, err + } + return JSONRPCError{Message: e.Message, Data: data}, nil +} + +var _ RPCErrorCodec = (*DataComplexError)(nil) + +type MetaError struct { + Message string + Details string +} + +func (e *MetaError) Error() string { + return e.Message +} + +func (e *MetaError) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Message string `json:"message"` + Details string `json:"details"` + }{ + Message: e.Message, + Details: e.Details, + }) +} + +func (e *MetaError) UnmarshalJSON(data []byte) error { + var temp struct { + Message string `json:"message"` + Details string `json:"details"` + } + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + e.Message = temp.Message + e.Details = temp.Details + return nil +} + +type ComplexError struct { + Message string + Data ComplexData + Details string +} + +func (e *ComplexError) Error() string { + return e.Message +} + +func (e *ComplexError) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Message string `json:"message"` + Details string `json:"details"` + Data any `json:"data"` + }{ + Details: e.Details, + Message: e.Message, + Data: e.Data, + }) +} + +func (e *ComplexError) UnmarshalJSON(data []byte) error { + var temp struct { + Message string `json:"message"` + Details string `json:"details"` + Data ComplexData `json:"data"` + } + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + e.Details = temp.Details + e.Message = temp.Message + e.Data = temp.Data + return nil +} + +func TestRespErrorVal(t *testing.T) { + // Initialize the Errors struct and register error types + errorsMap := NewErrors() + errorsMap.Register(1000, new(*StaticError)) + errorsMap.Register(1001, new(*SimpleError)) + errorsMap.Register(1002, new(*DataStringError)) + errorsMap.Register(1003, new(*DataComplexError)) + errorsMap.Register(1004, new(*MetaError)) + errorsMap.Register(1005, new(*ComplexError)) + + // Define test cases + testCases := []struct { + name string + respError *JSONRPCError + expectedType interface{} + expectedMessage string + verify func(t *testing.T, err error) + }{ + { + name: "StaticError", + respError: &JSONRPCError{ + Code: 1000, + Message: "this is ignored", + }, + expectedType: &StaticError{}, + expectedMessage: "static error", + }, + { + name: "SimpleError", + respError: &JSONRPCError{ + Code: 1001, + Message: "simple error occurred", + }, + expectedType: &SimpleError{}, + expectedMessage: "simple error occurred", + }, + { + name: "DataStringError", + respError: &JSONRPCError{ + Code: 1002, + Message: "data error occurred", + Data: "additional data", + }, + expectedType: &DataStringError{}, + expectedMessage: "data error occurred", + verify: func(t *testing.T, err error) { + require.IsType(t, &DataStringError{}, err) + require.Equal(t, "data error occurred", err.Error()) + require.Equal(t, "additional data", err.(*DataStringError).Data) + }, + }, + { + name: "DataComplexError", + respError: &JSONRPCError{ + Code: 1003, + Message: "data error occurred", + Data: json.RawMessage(`{"foo":"boop","bar":101}`), + }, + expectedType: &DataComplexError{}, + expectedMessage: "data error occurred", + verify: func(t *testing.T, err error) { + require.Equal(t, ComplexData{Foo: "boop", Bar: 101}, err.(*DataComplexError).internalData) + }, + }, + { + name: "MetaError", + respError: &JSONRPCError{ + Code: 1004, + Message: "meta error occurred", + Meta: func() json.RawMessage { + me := &MetaError{ + Message: "meta error occurred", + Details: "meta details", + } + metaData, _ := me.MarshalJSON() + return metaData + }(), + }, + expectedType: &MetaError{}, + expectedMessage: "meta error occurred", + verify: func(t *testing.T, err error) { + // details will also be included in the error message since it implements the marshable interface + require.Equal(t, "meta details", err.(*MetaError).Details) + }, + }, + { + name: "ComplexError", + respError: &JSONRPCError{ + Code: 1005, + Message: "complex error occurred", + Data: json.RawMessage(`"complex data"`), + Meta: func() json.RawMessage { + ce := &ComplexError{ + Message: "complex error occurred", + Details: "complex details", + Data: ComplexData{Foo: "foo", Bar: 42}, + } + metaData, _ := ce.MarshalJSON() + return metaData + }(), + }, + expectedType: &ComplexError{}, + expectedMessage: "complex error occurred", + verify: func(t *testing.T, err error) { + require.Equal(t, ComplexData{Foo: "foo", Bar: 42}, err.(*ComplexError).Data) + require.Equal(t, "complex details", err.(*ComplexError).Details) + }, + }, + { + name: "UnregisteredError", + respError: &JSONRPCError{ + Code: 9999, + Message: "unregistered error occurred", + Data: json.RawMessage(`"some data"`), + }, + expectedType: &JSONRPCError{}, + expectedMessage: "unregistered error occurred", + verify: func(t *testing.T, err error) { + require.Equal(t, json.RawMessage(`"some data"`), err.(*JSONRPCError).Data) + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + errValue := tc.respError.val(&errorsMap) + errInterface := errValue.Interface() + err, ok := errInterface.(error) + require.True(t, ok, "returned value does not implement error interface") + require.IsType(t, tc.expectedType, err) + require.Equal(t, tc.expectedMessage, err.Error()) + if tc.verify != nil { + tc.verify(t, err) + } + }) + } +} diff --git a/response.go b/response.go new file mode 100644 index 0000000..7bf866d --- /dev/null +++ b/response.go @@ -0,0 +1,89 @@ +package jsonrpc + +import ( + "encoding/json" + "fmt" + "reflect" +) + +type response struct { + Jsonrpc string `json:"jsonrpc"` + Result interface{} `json:"result,omitempty"` + ID interface{} `json:"id"` + Error *JSONRPCError `json:"error,omitempty"` +} + +func (r response) MarshalJSON() ([]byte, error) { + // Custom marshal logic as per JSON-RPC 2.0 spec: + // > `result`: + // > This member is REQUIRED on success. + // > This member MUST NOT exist if there was an error invoking the method. + // + // > `error`: + // > This member is REQUIRED on error. + // > This member MUST NOT exist if there was no error triggered during invocation. + data := map[string]interface{}{ + "jsonrpc": r.Jsonrpc, + "id": r.ID, + } + + if r.Error != nil { + data["error"] = r.Error + } else { + data["result"] = r.Result + } + return json.Marshal(data) +} + +type JSONRPCError struct { + Code ErrorCode `json:"code"` + Message string `json:"message"` + Meta json.RawMessage `json:"meta,omitempty"` + Data interface{} `json:"data,omitempty"` +} + +func (e *JSONRPCError) Error() string { + if e.Code >= -32768 && e.Code <= -32000 { + return fmt.Sprintf("RPC error (%d): %s", e.Code, e.Message) + } + return e.Message +} + +var ( + _ error = (*JSONRPCError)(nil) + marshalableRT = reflect.TypeOf(new(marshalable)).Elem() + errorCodecRT = reflect.TypeOf(new(RPCErrorCodec)).Elem() +) + +func (e *JSONRPCError) val(errors *Errors) reflect.Value { + if errors != nil { + t, ok := errors.byCode[e.Code] + if ok { + var v reflect.Value + if t.Kind() == reflect.Ptr { + v = reflect.New(t.Elem()) + } else { + v = reflect.New(t) + } + + if v.Type().Implements(errorCodecRT) { + if err := v.Interface().(RPCErrorCodec).FromJSONRPCError(*e); err != nil { + log.Errorf("Error converting JSONRPCError to custom error type '%s' (code %d): %w", t.String(), e.Code, err) + return reflect.ValueOf(e) + } + } else if len(e.Meta) > 0 && v.Type().Implements(marshalableRT) { + if err := v.Interface().(marshalable).UnmarshalJSON(e.Meta); err != nil { + log.Errorf("Error unmarshalling error metadata to custom error type '%s' (code %d): %w", t.String(), e.Code, err) + return reflect.ValueOf(e) + } + } + + if t.Kind() != reflect.Ptr { + v = v.Elem() + } + return v + } + } + + return reflect.ValueOf(e) +} diff --git a/server.go b/server.go index cc2586a..4454c85 100644 --- a/server.go +++ b/server.go @@ -155,7 +155,7 @@ func rpcError(wf func(func(io.Writer)), req *request, code ErrorCode, err error) resp := response{ Jsonrpc: "2.0", ID: req.ID, - Error: &respError{ + Error: &JSONRPCError{ Code: code, Message: err.Error(), }, @@ -180,4 +180,4 @@ func (s *RPCServer) AliasMethod(alias, original string) { s.aliasedMethods[alias] = original } -var _ error = &respError{} +var _ error = &JSONRPCError{} diff --git a/websocket.go b/websocket.go index 60a0451..05755d3 100644 --- a/websocket.go +++ b/websocket.go @@ -34,7 +34,7 @@ type frame struct { // response Result json.RawMessage `json:"result,omitempty"` - Error *respError `json:"error,omitempty"` + Error *JSONRPCError `json:"error,omitempty"` } type outChanReg struct { @@ -529,7 +529,7 @@ func (c *wsConn) closeInFlight() { req.ready <- clientResponse{ Jsonrpc: "2.0", ID: id, - Error: &respError{ + Error: &JSONRPCError{ Message: "handler: websocket connection closed", Code: eTempWSError, }, @@ -802,7 +802,7 @@ func (c *wsConn) handleWsConn(ctx context.Context) { req.ready <- clientResponse{ Jsonrpc: "2.0", ID: req.req.ID, - Error: &respError{ + Error: &JSONRPCError{ Message: "handler: websocket connection closed", Code: eTempWSError, }, @@ -824,7 +824,7 @@ func (c *wsConn) handleWsConn(ctx context.Context) { Jsonrpc: "2.0", } if serr != nil { - resp.Error = &respError{ + resp.Error = &JSONRPCError{ Code: eTempWSError, Message: fmt.Sprintf("sendRequest: %s", serr), }