Skip to content

Commit

Permalink
transport/http: provide EncodeJSONResponse
Browse files Browse the repository at this point in the history
With sane defaults.
  • Loading branch information
peterbourgon committed Dec 14, 2016
1 parent 61457e1 commit 520cf55
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 15 deletions.
42 changes: 29 additions & 13 deletions transport/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,34 +134,50 @@ type ErrorEncoder func(ctx context.Context, err error, w http.ResponseWriter)
// intended use is for request logging.
type ServerFinalizerFunc func(ctx context.Context, code int, r *http.Request)

// DefaultErrorEncoder writes the error to the ResponseWriter, by default with
// status code 500, content type of text/plain, and the plain text of the error.
// If the error implements StatusCoder, the provided StatusCode will be used
// instead of 500. If the error implements Headerer, the provided headers will
// be applied to the response writer. If the error implements json.Marshaler,
// and the marshaling succeeds, a content type of application/json and the JSON
// encoded form of the error will be used.
func DefaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) {
code := http.StatusInternalServerError
if sc, ok := err.(StatusCoder); ok {
// EncodeJSONResponse is a EncodeResponseFunc that serializes the response as a
// JSON object to the ResponseWriter. Many JSON-over-HTTP services can use it as
// a sensible default. If the response implements Headerer, the provided headers
// will be applied to the response. If the response implements StatusCoder, the
// provided StatusCode will be used instead of 200.
func EncodeJSONResponse(_ context.Context, w http.ResponseWriter, response interface{}) error {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
if headerer, ok := response.(Headerer); ok {
for k := range headerer.Headers() {
w.Header().Set(k, headerer.Headers().Get(k))
}
}
code := http.StatusOK
if sc, ok := response.(StatusCoder); ok {
code = sc.StatusCode()
}
w.WriteHeader(code)
return json.NewEncoder(w).Encode(response)
}

// DefaultErrorEncoder writes the error to the ResponseWriter, by default a
// content type of text/plain, a body of the plain text of the error, and a
// status code of 500. If the error implements Headerer, the provided headers
// will be applied to the response. If the error implements json.Marshaler, and
// the marshaling succeeds, a content type of application/json and the JSON
// encoded form of the error will be used. If the error implements StatusCoder,
// the provided StatusCode will be used instead of 500.
func DefaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) {
contentType, body := "text/plain; charset=utf-8", []byte(err.Error())
if marshaler, ok := err.(json.Marshaler); ok {
if jsonBody, marshalErr := marshaler.MarshalJSON(); marshalErr == nil {
contentType, body = "application/json; charset=utf-8", jsonBody
}
}

w.Header().Set("Content-Type", contentType)

if headerer, ok := err.(Headerer); ok {
for k := range headerer.Headers() {
w.Header().Set(k, headerer.Headers().Get(k))
}
}

code := http.StatusInternalServerError
if sc, ok := err.(StatusCoder); ok {
code = sc.StatusCode()
}
w.WriteHeader(code)
w.Write(body)
}
Expand Down
39 changes: 37 additions & 2 deletions transport/http/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"

"golang.org/x/net/context"
Expand Down Expand Up @@ -122,14 +123,48 @@ func TestServerFinalizer(t *testing.T) {
}
}

type enhancedResponse struct {
Foo string `json:"foo"`
}

func (e enhancedResponse) StatusCode() int { return http.StatusPaymentRequired }
func (e enhancedResponse) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} }

func TestEncodeJSONResponse(t *testing.T) {
handler := httptransport.NewServer(
context.Background(),
func(context.Context, interface{}) (interface{}, error) { return enhancedResponse{Foo: "bar"}, nil },
func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
httptransport.EncodeJSONResponse,
)

server := httptest.NewServer(handler)
defer server.Close()

resp, err := http.Get(server.URL)
if err != nil {
t.Fatal(err)
}
if want, have := http.StatusPaymentRequired, resp.StatusCode; want != have {
t.Errorf("StatusCode: want %d, have %d", want, have)
}
if want, have := "Snowden", resp.Header.Get("X-Edward"); want != have {
t.Errorf("X-Edward: want %q, have %q", want, have)
}
buf, _ := ioutil.ReadAll(resp.Body)
if want, have := `{"foo":"bar"}`, strings.TrimSpace(string(buf)); want != have {
t.Errorf("Body: want %s, have %s", want, have)
}
}

type enhancedError struct{}

func (e enhancedError) Error() string { return "enhanced error" }
func (e enhancedError) StatusCode() int { return http.StatusTeapot }
func (e enhancedError) MarshalJSON() ([]byte, error) { return []byte(`{"err":"enhanced"}`), nil }
func (e enhancedError) Headers() http.Header { return http.Header{"X-Enhanced": []string{"1"}} }

func TestServerSpecialError(t *testing.T) {
func TestEnhancedError(t *testing.T) {
handler := httptransport.NewServer(
context.Background(),
func(context.Context, interface{}) (interface{}, error) { return nil, enhancedError{} },
Expand All @@ -152,7 +187,7 @@ func TestServerSpecialError(t *testing.T) {
t.Errorf("X-Enhanced: want %q, have %q", want, have)
}
buf, _ := ioutil.ReadAll(resp.Body)
if want, have := `{"err":"enhanced"}`, string(buf); want != have {
if want, have := `{"err":"enhanced"}`, strings.TrimSpace(string(buf)); want != have {
t.Errorf("Body: want %s, have %s", want, have)
}
}
Expand Down

0 comments on commit 520cf55

Please sign in to comment.