diff --git a/transport/http/request_response_funcs.go b/transport/http/request_response_funcs.go index 7622375e3..78ad61e41 100644 --- a/transport/http/request_response_funcs.go +++ b/transport/http/request_response_funcs.go @@ -116,4 +116,13 @@ const ( // ContextKeyRequestXRequestID is populated in the context by // PopulateRequestContext. Its value is r.Header.Get("X-Request-Id"). ContextKeyRequestXRequestID + + // ContextKeyResponseHeaders is populated in the context whenever a + // ServerFinalizerFunc is specified. Its value is of type http.Header, and + // is captured only once the entire response has been written. + ContextKeyResponseHeaders + + // ContextKeyResponseSize is populated in the context whenever a + // ServerFinalizerFunc is specified. Its value is of type int64. + ContextKeyResponseSize ) diff --git a/transport/http/server.go b/transport/http/server.go index 524132816..ab2b22e56 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -88,8 +88,12 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := s.ctx if s.finalizer != nil { - iw := &interceptingWriter{w, http.StatusOK} - defer func() { s.finalizer(ctx, iw.code, r) }() + iw := &interceptingWriter{w, http.StatusOK, 0} + defer func() { + ctx = context.WithValue(ctx, ContextKeyResponseHeaders, iw.Header()) + ctx = context.WithValue(ctx, ContextKeyResponseSize, iw.written) + s.finalizer(ctx, iw.code, r) + }() w = iw } @@ -130,7 +134,9 @@ type ErrorEncoder func(ctx context.Context, err error, w http.ResponseWriter) // ServerFinalizerFunc can be used to perform work at the end of an HTTP // request, after the response has been written to the client. The principal -// intended use is for request logging. +// intended use is for request logging. In addition to the response code +// provided in the function signature, additional response parameters are +// provided in the context under keys with the ContextKeyResponse prefix. type ServerFinalizerFunc func(ctx context.Context, code int, r *http.Request) // EncodeJSONResponse is a EncodeResponseFunc that serializes the response as a @@ -200,7 +206,8 @@ type Headerer interface { type interceptingWriter struct { http.ResponseWriter - code int + code int + written int64 } // WriteHeader may not be explicitly called, so care must be taken to @@ -209,3 +216,9 @@ func (w *interceptingWriter) WriteHeader(code int) { w.code = code w.ResponseWriter.WriteHeader(code) } + +func (w *interceptingWriter) Write(p []byte) (int, error) { + n, err := w.ResponseWriter.Write(p) + w.written += int64(n) + return n, err +} diff --git a/transport/http/server_test.go b/transport/http/server_test.go index 8c112da2f..654fbecba 100644 --- a/transport/http/server_test.go +++ b/transport/http/server_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" "github.com/go-kit/kit/endpoint" httptransport "github.com/go-kit/kit/transport/http" @@ -93,7 +94,13 @@ func TestServerHappyPath(t *testing.T) { } func TestServerFinalizer(t *testing.T) { - c := make(chan int) + var ( + headerKey = "X-Henlo-Lizer" + headerVal = "Helllo you stinky lizard" + statusCode = http.StatusTeapot + responseBody = "go eat a fly ugly\n" + done = make(chan struct{}) + ) handler := httptransport.NewServer( context.Background(), endpoint.Nop, @@ -101,11 +108,27 @@ func TestServerFinalizer(t *testing.T) { return struct{}{}, nil }, func(_ context.Context, w http.ResponseWriter, _ interface{}) error { - w.WriteHeader(<-c) + w.Header().Set(headerKey, headerVal) + w.WriteHeader(statusCode) + w.Write([]byte(responseBody)) return nil }, - httptransport.ServerFinalizer(func(_ context.Context, code int, _ *http.Request) { - c <- code + httptransport.ServerFinalizer(func(ctx context.Context, code int, _ *http.Request) { + if want, have := statusCode, code; want != have { + t.Errorf("StatusCode: want %d, have %d", want, have) + } + + responseHeader := ctx.Value(httptransport.ContextKeyResponseHeaders).(http.Header) + if want, have := headerVal, responseHeader.Get(headerKey); want != have { + t.Errorf("%s: want %q, have %q", headerKey, want, have) + } + + responseSize := ctx.Value(httptransport.ContextKeyResponseSize).(int64) + if want, have := int64(len(responseBody)), responseSize; want != have { + t.Errorf("response size: want %d, have %d", want, have) + } + + close(done) }), ) @@ -113,12 +136,10 @@ func TestServerFinalizer(t *testing.T) { defer server.Close() go http.Get(server.URL) - want := http.StatusTeapot - c <- want // give status code to response encoder - have := <-c // take status code from finalizer - - if want != have { - t.Errorf("want %d, have %d", want, have) + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout waiting for finalizer") } }