Skip to content

Commit

Permalink
[query] Allow injecting function for rewriting error response (#3008)
Browse files Browse the repository at this point in the history
  • Loading branch information
vpranckaitis authored Dec 22, 2020
1 parent 5bfb190 commit 971e7b0
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 21 deletions.
16 changes: 0 additions & 16 deletions src/query/api/v1/handler/prom/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,3 @@ func Respond(w http.ResponseWriter, data interface{}, warnings promstorage.Warni
w.WriteHeader(http.StatusOK)
w.Write(b)
}

// Responds with error status code and writes error JSON to response body.
func RespondError(w http.ResponseWriter, err error) {
json := jsoniter.ConfigCompatibleWithStandardLibrary
b, marshalErr := json.Marshal(&response{
Status: statusError,
Error: err.Error(),
})
if marshalErr != nil {
xhttp.WriteError(w, marshalErr)
return
}

w.Header().Set(xhttp.HeaderContentType, xhttp.ContentTypeJSON)
xhttp.WriteError(w, err, xhttp.WithErrorResponse(b))
}
9 changes: 5 additions & 4 deletions src/query/api/v1/handler/prom/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/m3db/m3/src/query/models"
"github.com/m3db/m3/src/query/storage/prometheus"
xerrors "github.com/m3db/m3/src/x/errors"
xhttp "github.com/m3db/m3/src/x/net/http"

"github.com/prometheus/prometheus/promql"
promstorage "github.com/prometheus/prometheus/storage"
Expand Down Expand Up @@ -99,13 +100,13 @@ func (h *readHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

fetchOptions, err := h.hOpts.FetchOptionsBuilder().NewFetchOptions(r)
if err != nil {
RespondError(w, err)
xhttp.WriteError(w, err)
return
}

request, err := native.ParseRequest(ctx, r, h.opts.instant, h.hOpts)
if err != nil {
RespondError(w, err)
xhttp.WriteError(w, err)
return
}

Expand All @@ -129,7 +130,7 @@ func (h *readHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.logger.Error("error creating query",
zap.Error(err), zap.String("query", params.Query),
zap.Bool("instant", h.opts.instant))
RespondError(w, xerrors.NewInvalidParamsError(err))
xhttp.WriteError(w, xerrors.NewInvalidParamsError(err))
return
}
defer qry.Close()
Expand All @@ -139,7 +140,7 @@ func (h *readHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.logger.Error("error executing query",
zap.Error(res.Err), zap.String("query", params.Query),
zap.Bool("instant", h.opts.instant))
RespondError(w, res.Err)
xhttp.WriteError(w, res.Err)
return
}

Expand Down
25 changes: 24 additions & 1 deletion src/x/net/http/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,19 @@ import (
"encoding/json"
"errors"
"net/http"
"sync"

xerrors "github.com/m3db/m3/src/x/errors"
)

// ErrorRewriteFn is a function for rewriting response error.
type ErrorRewriteFn func(error) error

var (
errorRewriteFn ErrorRewriteFn = func(err error) error { return err }
errorRewriteFnLock sync.RWMutex
)

// Error is an HTTP JSON error that also sets a return status code.
type Error interface {
// Fulfill error interface.
Expand Down Expand Up @@ -93,6 +102,10 @@ func WriteError(w http.ResponseWriter, err error, opts ...WriteErrorOption) {
fn(&o)
}

errorRewriteFnLock.RLock()
err = errorRewriteFn(err)
errorRewriteFnLock.RUnlock()

statusCode := getStatusCode(err)
if o.response == nil {
w.Header().Set(HeaderContentType, ContentTypeJSON)
Expand All @@ -104,6 +117,16 @@ func WriteError(w http.ResponseWriter, err error, opts ...WriteErrorOption) {
}
}

// SetErrorRewriteFn sets error rewrite function.
func SetErrorRewriteFn(f ErrorRewriteFn) ErrorRewriteFn {
errorRewriteFnLock.Lock()
defer errorRewriteFnLock.Unlock()

res := errorRewriteFn
errorRewriteFn = f
return res
}

func getStatusCode(err error) int {
switch v := err.(type) {
case Error:
Expand All @@ -118,7 +141,7 @@ func getStatusCode(err error) int {
return http.StatusInternalServerError
}

// IsClientError returns true if this error would result in 4xx status code
// IsClientError returns true if this error would result in 4xx status code.
func IsClientError(err error) bool {
code := getStatusCode(err)
return code >= 400 && code < 500
Expand Down
41 changes: 41 additions & 0 deletions src/x/net/http/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,55 @@
package xhttp

import (
"errors"
"fmt"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

xerrors "github.com/m3db/m3/src/x/errors"
)

func TestErrorRewrite(t *testing.T) {
tests := []struct {
name string
err error
expectedStatus int
expectedBody string
}{
{
name: "error that should not be rewritten",
err: errors.New("random error"),
expectedStatus: 500,
expectedBody: `{"status":"error","error":"random error"}`,
},
{
name: "error that should be rewritten",
err: xerrors.NewInvalidParamsError(errors.New("to be rewritten")),
expectedStatus: 500,
expectedBody: `{"status":"error","error":"rewritten error"}`,
},
}

SetErrorRewriteFn(func(err error) error {
if xerrors.IsInvalidParams(err) {
return errors.New("rewritten error")
}
return err
})

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
WriteError(recorder, tt.err)
assert.Equal(t, tt.expectedStatus, recorder.Code)
assert.JSONEq(t, tt.expectedBody, recorder.Body.String())
})
}
}

func TestIsClientError(t *testing.T) {
tests := []struct {
err error
Expand Down

0 comments on commit 971e7b0

Please sign in to comment.