diff --git a/go.sum b/go.sum index 4b7caef..5b9bc83 100644 --- a/go.sum +++ b/go.sum @@ -45,6 +45,7 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793 h1:u+LnwYTOOW7Ukr/fppxEb1Nwz0AtPflrblfvUudpo+I= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/interceptor/copy_read_closer.go b/interceptor/copy_read_closer.go new file mode 100644 index 0000000..5c2243d --- /dev/null +++ b/interceptor/copy_read_closer.go @@ -0,0 +1,71 @@ +package interceptor + +import ( + "bytes" + "io" + "io/ioutil" +) + +// io.Reader with Read method reset offset when EOF +type bufReader struct { + buf []byte + off int +} + +func (r *bufReader) Read(p []byte) (n int, err error) { + if r.off == len(r.buf) { + if len(p) == 0 { + return 0, nil + } + r.off = 0 + return 0, io.EOF + } + + n = copy(p, r.buf[r.off:]) + r.off += n + + return n, nil +} + +type copyReadCloser struct { + io.ReadCloser + // write in bytes.Buffer + copyTemp *bytes.Buffer + // read in copy + copy *bufReader +} + +// First read with io.TeeReader +// -> copyBuffered +// / +// src --> output +// Second read after EOF +// copyBuffered --> copy BufReader simple buffer with fix size +// when BufReader is EOF offset is reset to read again +func NewCopyReadCloser(src io.ReadCloser) *copyReadCloser { + buf := &bytes.Buffer{} + tr := ©ReadCloser{ + copyTemp: buf, + } + + tr.ReadCloser = &struct { + io.Reader + io.Closer + }{io.TeeReader(src, buf), src} + + return tr +} + +func (tr *copyReadCloser)Read(p []byte) (n int, err error) { + n, err = tr.ReadCloser.Read(p) + if err == io.EOF { + if tr.copy == nil { + tr.ReadCloser.Close() + tr.copy = &bufReader{buf: tr.copyTemp.Bytes()} + tr.copyTemp.Reset() + tr.ReadCloser = ioutil.NopCloser(tr.copy) + } + } + + return n, err +} diff --git a/middleware/interceptor.go b/middleware/interceptor.go new file mode 100644 index 0000000..407fe70 --- /dev/null +++ b/middleware/interceptor.go @@ -0,0 +1,64 @@ +package middleware + +import ( + "net/http" + + "github.com/gol4ng/httpware/v2" + "github.com/gol4ng/httpware/v2/interceptor" +) + +// Interceptor middleware allow multiple req.Body read and allow to set callback before and after roundtrip +func Interceptor(options ...Option) httpware.Middleware { + config := NewConfig(options...) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) { + writerInterceptor := NewResponseWriterInterceptor(writer) + + req.Body = interceptor.NewCopyReadCloser(req.Body) + config.CallbackBefore(writerInterceptor, req) + defer func() { + config.CallbackAfter(writerInterceptor, req) + }() + + next.ServeHTTP(writerInterceptor, req) + }) + } +} + +type Config struct { + CallbackBefore func(*ResponseWriterInterceptor, *http.Request) + CallbackAfter func(*ResponseWriterInterceptor, *http.Request) +} + +func (c *Config) apply(options ...Option) *Config { + for _, option := range options { + option(c) + } + return c +} + +// NewConfig returns a new interceptor middleware configuration with all options applied +func NewConfig(options ...Option) *Config { + config := &Config{ + CallbackBefore: func(_ *ResponseWriterInterceptor, _ *http.Request) {}, + CallbackAfter: func(_ *ResponseWriterInterceptor, _ *http.Request) {}, + } + return config.apply(options...) +} + +// Option defines a interceptor middleware configuration option +type Option func(*Config) + +// WithBefore will configure CallbackBefore interceptor option +func WithBefore(callbackBefore func(*ResponseWriterInterceptor, *http.Request)) Option { + return func(config *Config) { + config.CallbackBefore = callbackBefore + } +} + +// WithAfter will configure CallbackAfter interceptor option +func WithAfter(callbackAfter func(*ResponseWriterInterceptor, *http.Request)) Option { + return func(config *Config) { + config.CallbackAfter = callbackAfter + } +} diff --git a/middleware/interceptor_test.go b/middleware/interceptor_test.go new file mode 100644 index 0000000..84d86c2 --- /dev/null +++ b/middleware/interceptor_test.go @@ -0,0 +1,66 @@ +package middleware_test + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gol4ng/httpware/v2" + "github.com/gol4ng/httpware/v2/middleware" + "github.com/stretchr/testify/assert" +) + +func TestInterceptor(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewReader([]byte("bar"))) + req.Header.Add("X-Interceptor-Request-Header", "interceptor") + + responseWriter := &httptest.ResponseRecorder{} + stack := httpware.MiddlewareStack( + middleware.Interceptor( + middleware.WithBefore(func(responseWriterInterceptor *middleware.ResponseWriterInterceptor, req *http.Request) { + buf := new(bytes.Buffer) + _, err := buf.ReadFrom(req.Body) + assert.NoError(t, err) + assert.Equal(t, "bar", buf.String()) + + assert.Equal(t, http.MethodGet, req.Method) + assert.Equal(t, "/foo", req.URL.String()) + + req.Header.Add("X-Interceptor-Request-Header", "interceptor") + responseWriterInterceptor.Header().Add("X-Interceptor-Response-Header1", "interceptor1") + }), + middleware.WithAfter(func(responseWriterInterceptor *middleware.ResponseWriterInterceptor, req *http.Request) { + assert.Equal(t, http.MethodGet, req.Method) + assert.Equal(t, "/foo", req.URL.String()) + assert.Equal(t, "interceptor", req.Header.Get("X-Interceptor-Request-Header")) + + assert.Equal(t, http.StatusAlreadyReported, responseWriterInterceptor.StatusCode) + assert.Equal(t, "foo bar", string(responseWriterInterceptor.Body)) + + assert.Equal(t, "interceptor1", responseWriterInterceptor.Header().Get("X-Interceptor-Response-Header1")) + assert.Equal(t, "interceptor2", responseWriterInterceptor.Header().Get("X-Interceptor-Response-Header2")) + + responseWriterInterceptor.Header().Add("X-Interceptor-Response-Header3", "interceptor3") + }), + ), + ) + + stack.DecorateHandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + buf := new(bytes.Buffer) + _, err := buf.ReadFrom(req.Body) + assert.NoError(t, err) + assert.Equal(t, "bar", buf.String()) + rw.WriteHeader(http.StatusAlreadyReported) + + _, err = rw.Write([]byte("foo bar")) + assert.NoError(t, err) + assert.Equal(t, "interceptor1", rw.Header().Get("X-Interceptor-Response-Header1")) + + rw.Header().Add("X-Interceptor-Response-Header2", "interceptor2") + }).ServeHTTP(responseWriter, req) + + assert.Equal(t, "interceptor1", responseWriter.Header().Get("X-Interceptor-Response-Header1")) + assert.Equal(t, "interceptor2", responseWriter.Header().Get("X-Interceptor-Response-Header2")) + assert.Equal(t, "interceptor3", responseWriter.Header().Get("X-Interceptor-Response-Header3")) +} diff --git a/middleware/metrics.go b/middleware/metrics.go index f64fe92..60c4188 100644 --- a/middleware/metrics.go +++ b/middleware/metrics.go @@ -23,15 +23,15 @@ func Metrics(recorder metrics.Recorder, options ... metrics.Option) httpware.Mid start := time.Now() defer func() { - code := strconv.Itoa(writerInterceptor.statusCode) + code := strconv.Itoa(writerInterceptor.StatusCode) if !config.SplitStatus { - code = fmt.Sprintf("%dxx", writerInterceptor.statusCode/100) + code = fmt.Sprintf("%dxx", writerInterceptor.StatusCode/100) } config.Recorder.ObserveHTTPRequestDuration(req.Context(), handlerName, time.Since(start), req.Method, code) if config.ObserveResponseSize { - config.Recorder.ObserveHTTPResponseSize(req.Context(), handlerName, int64(writerInterceptor.bytesWritten), req.Method, code) + config.Recorder.ObserveHTTPResponseSize(req.Context(), handlerName, int64(len(writerInterceptor.Body)), req.Method, code) } }() diff --git a/middleware/response_writer_interceptor.go b/middleware/response_writer_interceptor.go index 4aa3e92..01daba9 100644 --- a/middleware/response_writer_interceptor.go +++ b/middleware/response_writer_interceptor.go @@ -1,26 +1,28 @@ package middleware -import "net/http" +import ( + "net/http" +) -type responseWriterInterceptor struct { +type ResponseWriterInterceptor struct { http.ResponseWriter - statusCode int - bytesWritten int + StatusCode int + Body []byte } -func (w *responseWriterInterceptor) WriteHeader(statusCode int) { - w.statusCode = statusCode +func (w *ResponseWriterInterceptor) WriteHeader(statusCode int) { + w.StatusCode = statusCode w.ResponseWriter.WriteHeader(statusCode) } -func (w *responseWriterInterceptor) Write(p []byte) (int, error) { - w.bytesWritten += len(p) +func (w *ResponseWriterInterceptor) Write(p []byte) (int, error) { + w.Body = append(w.Body, p...) return w.ResponseWriter.Write(p) } -func NewResponseWriterInterceptor(writer http.ResponseWriter) *responseWriterInterceptor { - return &responseWriterInterceptor{ - statusCode: http.StatusServiceUnavailable, +func NewResponseWriterInterceptor(writer http.ResponseWriter) *ResponseWriterInterceptor { + return &ResponseWriterInterceptor{ + StatusCode: http.StatusServiceUnavailable, ResponseWriter: writer, } } diff --git a/tripperware/interceptor.go b/tripperware/interceptor.go new file mode 100644 index 0000000..9eafd40 --- /dev/null +++ b/tripperware/interceptor.go @@ -0,0 +1,62 @@ +package tripperware + +import ( + "net/http" + + "github.com/gol4ng/httpware/v2" + "github.com/gol4ng/httpware/v2/interceptor" +) + +// Interceptor tripperware allow multiple req.Body read and allow to set callback before and after roundtrip +func Interceptor(options ...Option) httpware.Tripperware { + config := NewConfig(options...) + return func(next http.RoundTripper) http.RoundTripper { + return httpware.RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) { + req.Body = interceptor.NewCopyReadCloser(req.Body) + config.CallbackBefore(req) + defer func() { + config.CallbackAfter(resp, req) + }() + + return next.RoundTrip(req) + }) + } +} + +type Config struct { + CallbackBefore func(*http.Request) + CallbackAfter func(*http.Response, *http.Request) +} + +func (c *Config) apply(options ...Option) *Config { + for _, option := range options { + option(c) + } + return c +} + +// NewConfig returns a new interceptor configuration with all options applied +func NewConfig(options ...Option) *Config { + config := &Config{ + CallbackBefore: func(_ *http.Request) {}, + CallbackAfter: func(_ *http.Response, _ *http.Request) {}, + } + return config.apply(options...) +} + +// Option defines a interceptor tripperware configuration option +type Option func(*Config) + +// WithAfter will configure CallbackAfter interceptor option +func WithBefore(callbackBefore func(*http.Request)) Option { + return func(config *Config) { + config.CallbackBefore = callbackBefore + } +} + +// WithAfter will configure CallbackAfter interceptor option +func WithAfter(callbackAfter func(*http.Response, *http.Request)) Option { + return func(config *Config) { + config.CallbackAfter = callbackAfter + } +} diff --git a/tripperware/interceptor_test.go b/tripperware/interceptor_test.go new file mode 100644 index 0000000..63ad333 --- /dev/null +++ b/tripperware/interceptor_test.go @@ -0,0 +1,83 @@ +package tripperware_test + +import ( + "bytes" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/gol4ng/httpware/v2" + "github.com/gol4ng/httpware/v2/mocks" + "github.com/gol4ng/httpware/v2/tripperware" +) + +func TestInterceptor(t *testing.T) { + roundTripperMock := &mocks.RoundTripper{} + req := httptest.NewRequest(http.MethodPost, "http://fake-addr", bytes.NewBufferString("my_fake_body")) + resp := &http.Response{ + Status: "OK", + StatusCode: http.StatusOK, + ContentLength: 30, + } + + roundTripperMock.On("RoundTrip", mock.AnythingOfType("*http.Request")).Return(resp, nil).Run(func(args mock.Arguments) { + innerReq := args.Get(0).(*http.Request) + reqData, err := ioutil.ReadAll(innerReq.Body) + assert.Nil(t, err) + assert.Equal(t, "my_fake_body", string(reqData)) + }) + + resp2, err := tripperware.Interceptor( + tripperware.WithBefore(func(request *http.Request) { + reqData, err := ioutil.ReadAll(request.Body) + assert.Nil(t, err) + assert.Equal(t, "my_fake_body", string(reqData)) + }), + tripperware.WithAfter(func(response *http.Response, request *http.Request) { + + }), + )(roundTripperMock).RoundTrip(req) + assert.Nil(t, err) + assert.Equal(t, resp, resp2) + + reqData, err := ioutil.ReadAll(req.Body) + assert.Nil(t, err) + assert.Equal(t, "my_fake_body", string(reqData)) +} + +// ===================================================================================================================== +// ========================================= EXAMPLES ================================================================== +// ===================================================================================================================== + +func ExampleInterceptor() { + // we recommend to use TripperwareStack to simplify managing all wanted tripperware + // caution tripperware order matter + stack := httpware.TripperwareStack( + tripperware.Interceptor( + tripperware.WithBefore(func(request *http.Request) { + reqData, err := ioutil.ReadAll(request.Body) + fmt.Println("before callback", string(reqData), err) + }), + tripperware.WithAfter(func(response *http.Response, request *http.Request) { + reqData, err := ioutil.ReadAll(request.Body) + fmt.Println("after callback", string(reqData), err) + }), + ), + ) + + // create http client using the tripperwareStack as RoundTripper + client := http.Client{ + Transport: stack, + } + + _, _ = client.Post("fake-address.foo", "plain/text", bytes.NewBufferString("my_fake_body")) + + //Output: + //before callback my_fake_body + //after callback my_fake_body +}