diff --git a/src/query/api/v1/handler/prom/common.go b/src/query/api/v1/handler/prom/common.go index 7c4a66a0d6..7ef1d039a1 100644 --- a/src/query/api/v1/handler/prom/common.go +++ b/src/query/api/v1/handler/prom/common.go @@ -89,19 +89,3 @@ func Respond(w http.ResponseWriter, data interface{}, warnings promstorage.Warni w.WriteHeader(http.StatusOK) w.Write(b) } - -// Responds with error status code and writes error JSON to response body. -func RespondError(w http.ResponseWriter, err error) { - json := jsoniter.ConfigCompatibleWithStandardLibrary - b, marshalErr := json.Marshal(&response{ - Status: statusError, - Error: err.Error(), - }) - if marshalErr != nil { - xhttp.WriteError(w, marshalErr) - return - } - - w.Header().Set(xhttp.HeaderContentType, xhttp.ContentTypeJSON) - xhttp.WriteError(w, err, xhttp.WithErrorResponse(b)) -} diff --git a/src/query/api/v1/handler/prom/read.go b/src/query/api/v1/handler/prom/read.go index b69cdecc39..7b89fcbf67 100644 --- a/src/query/api/v1/handler/prom/read.go +++ b/src/query/api/v1/handler/prom/read.go @@ -34,6 +34,7 @@ import ( "github.com/m3db/m3/src/query/models" "github.com/m3db/m3/src/query/storage/prometheus" xerrors "github.com/m3db/m3/src/x/errors" + xhttp "github.com/m3db/m3/src/x/net/http" "github.com/prometheus/prometheus/promql" promstorage "github.com/prometheus/prometheus/storage" @@ -99,13 +100,13 @@ func (h *readHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { fetchOptions, err := h.hOpts.FetchOptionsBuilder().NewFetchOptions(r) if err != nil { - RespondError(w, err) + xhttp.WriteError(w, err) return } request, err := native.ParseRequest(ctx, r, h.opts.instant, h.hOpts) if err != nil { - RespondError(w, err) + xhttp.WriteError(w, err) return } @@ -129,7 +130,7 @@ func (h *readHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.logger.Error("error creating query", zap.Error(err), zap.String("query", params.Query), zap.Bool("instant", h.opts.instant)) - RespondError(w, xerrors.NewInvalidParamsError(err)) + xhttp.WriteError(w, xerrors.NewInvalidParamsError(err)) return } defer qry.Close() @@ -139,7 +140,7 @@ func (h *readHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.logger.Error("error executing query", zap.Error(res.Err), zap.String("query", params.Query), zap.Bool("instant", h.opts.instant)) - RespondError(w, res.Err) + xhttp.WriteError(w, res.Err) return } diff --git a/src/x/net/http/errors.go b/src/x/net/http/errors.go index 082fce1b5d..5e10e54e5f 100644 --- a/src/x/net/http/errors.go +++ b/src/x/net/http/errors.go @@ -25,10 +25,19 @@ import ( "encoding/json" "errors" "net/http" + "sync" xerrors "github.com/m3db/m3/src/x/errors" ) +// ErrorRewriteFn is a function for rewriting response error. +type ErrorRewriteFn func(error) error + +var ( + errorRewriteFn ErrorRewriteFn = func(err error) error { return err } + errorRewriteFnLock sync.RWMutex +) + // Error is an HTTP JSON error that also sets a return status code. type Error interface { // Fulfill error interface. @@ -93,6 +102,10 @@ func WriteError(w http.ResponseWriter, err error, opts ...WriteErrorOption) { fn(&o) } + errorRewriteFnLock.RLock() + err = errorRewriteFn(err) + errorRewriteFnLock.RUnlock() + statusCode := getStatusCode(err) if o.response == nil { w.Header().Set(HeaderContentType, ContentTypeJSON) @@ -104,6 +117,16 @@ func WriteError(w http.ResponseWriter, err error, opts ...WriteErrorOption) { } } +// SetErrorRewriteFn sets error rewrite function. +func SetErrorRewriteFn(f ErrorRewriteFn) ErrorRewriteFn { + errorRewriteFnLock.Lock() + defer errorRewriteFnLock.Unlock() + + res := errorRewriteFn + errorRewriteFn = f + return res +} + func getStatusCode(err error) int { switch v := err.(type) { case Error: @@ -118,7 +141,7 @@ func getStatusCode(err error) int { return http.StatusInternalServerError } -// IsClientError returns true if this error would result in 4xx status code +// IsClientError returns true if this error would result in 4xx status code. func IsClientError(err error) bool { code := getStatusCode(err) return code >= 400 && code < 500 diff --git a/src/x/net/http/errors_test.go b/src/x/net/http/errors_test.go index 4bd5ecf46d..f5c0e6fabe 100644 --- a/src/x/net/http/errors_test.go +++ b/src/x/net/http/errors_test.go @@ -21,14 +21,55 @@ package xhttp import ( + "errors" "fmt" + "net/http/httptest" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" xerrors "github.com/m3db/m3/src/x/errors" ) +func TestErrorRewrite(t *testing.T) { + tests := []struct { + name string + err error + expectedStatus int + expectedBody string + }{ + { + name: "error that should not be rewritten", + err: errors.New("random error"), + expectedStatus: 500, + expectedBody: `{"status":"error","error":"random error"}`, + }, + { + name: "error that should be rewritten", + err: xerrors.NewInvalidParamsError(errors.New("to be rewritten")), + expectedStatus: 500, + expectedBody: `{"status":"error","error":"rewritten error"}`, + }, + } + + SetErrorRewriteFn(func(err error) error { + if xerrors.IsInvalidParams(err) { + return errors.New("rewritten error") + } + return err + }) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + WriteError(recorder, tt.err) + assert.Equal(t, tt.expectedStatus, recorder.Code) + assert.JSONEq(t, tt.expectedBody, recorder.Body.String()) + }) + } +} + func TestIsClientError(t *testing.T) { tests := []struct { err error