From 7947cd46221a0fd5c67a6d90428770ba086163a0 Mon Sep 17 00:00:00 2001 From: instabledesign Date: Fri, 31 Jan 2020 11:56:34 +0100 Subject: [PATCH] implement enable middleware tripperware (#29) --- README.md | 26 ++++++++ middleware.go | 5 ++ middleware/correlation_id_test.go | 2 +- middleware/enable.go | 14 +++++ middleware/enable_test.go | 89 ++++++++++++++++++++++++++++ middleware/metrics_test.go | 2 +- middleware/skip.go | 2 + middleware/skip_test.go | 3 +- tripperware.go | 5 ++ tripperware/correlation_id_test.go | 2 +- tripperware/enable.go | 14 +++++ tripperware/enable_test.go | 95 ++++++++++++++++++++++++++++++ tripperware/skip.go | 2 + tripperware/skip_test.go | 5 +- 14 files changed, 260 insertions(+), 6 deletions(-) create mode 100644 middleware/enable.go create mode 100644 middleware/enable_test.go create mode 100644 tripperware/enable.go create mode 100644 tripperware/enable_test.go diff --git a/README.md b/README.md index 305458c..75451ac 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ Package httpware is a collection of middleware (net/http.Handler wrapper) and tr |**Metrics**|X|X| |**Interceptor**|X|X| |**Skip**|X|X| +|**Enable**|X|X| ## Installation @@ -195,6 +196,31 @@ t1 t2 t3 are tripperwares //A after ``` +## Enable/skip middleware or tripperware + +Some times you need to have a dynamic middleware|tripperware stack + +For example you need to have a middleware activated on debug mode +If Enable false, the middleware will not be added to the middleware stack + +```go + debug := true + stack := httpware.MiddlewareStack( + middleware.Enable(debug, middleware.CorrelationId()), + ) +``` + +You can dynamically skip a middleware with your own rule. +If the callback return true it will skip the execution of targeted middleware + +```go + stack := httpware.MiddlewareStack( + tripperware.Skip(func(request *http.Request) bool { + return request.URL.Path == "/home" + }, middleware.CorrelationId()), + ) +``` + ## AppendIf PrependIf For more convinience to build more complex middleware/tripperware stack you can use the AppendIf/PrependIf on Middleware and Middlewares diff --git a/middleware.go b/middleware.go index b1df357..95db1de 100644 --- a/middleware.go +++ b/middleware.go @@ -2,6 +2,11 @@ package httpware import "net/http" +// NopMiddleware just return given http.Handler +func NopMiddleware(next http.Handler) http.Handler { + return next +} + // Middleware represents an http server middleware // it wraps an http.Handler with another one type Middleware func(http.Handler) http.Handler diff --git a/middleware/correlation_id_test.go b/middleware/correlation_id_test.go index c529dcb..f56d892 100644 --- a/middleware/correlation_id_test.go +++ b/middleware/correlation_id_test.go @@ -44,7 +44,7 @@ func TestCorrelationId(t *testing.T) { // ===================================================================================================================== func ExampleCorrelationId() { - port := ":5001" + port := ":9103" // we recommend to use MiddlewareStack to simplify managing all wanted middlewares // caution middleware order matters stack := httpware.MiddlewareStack( diff --git a/middleware/enable.go b/middleware/enable.go new file mode 100644 index 0000000..6abfe6a --- /dev/null +++ b/middleware/enable.go @@ -0,0 +1,14 @@ +package middleware + +import ( + "github.com/gol4ng/httpware/v2" +) + +// Enable middleware is used to conditionnaly add a middleware to a MiddlewareStack +// See Skip middleware to active a middleware in function of request +func Enable(enable bool, middleware httpware.Middleware) httpware.Middleware { + if enable { + return middleware + } + return httpware.NopMiddleware +} diff --git a/middleware/enable_test.go b/middleware/enable_test.go new file mode 100644 index 0000000..8870844 --- /dev/null +++ b/middleware/enable_test.go @@ -0,0 +1,89 @@ +package middleware_test + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gol4ng/httpware/v2" + "github.com/gol4ng/httpware/v2/middleware" + "github.com/stretchr/testify/assert" +) + +func TestEnable(t *testing.T) { + tests := []struct { + enable bool + expectedExecuted bool + }{ + { + enable: true, + expectedExecuted: true, + }, + { + enable: false, + expectedExecuted: false, + }, + } + + request := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil) + responseWriter := &httptest.ResponseRecorder{} + + handler := http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) { + assert.Equal(t, request, req) + writer.WriteHeader(http.StatusOK) + }) + + executed := false + dummyMiddleware := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + executed = true + next.ServeHTTP(writer, request) + }) + } + + for k, test := range tests { + executed = false + t.Run(fmt.Sprintf("test %d (%v)", k, test), func(t *testing.T) { + middleware.Enable(test.enable, dummyMiddleware)(handler).ServeHTTP(responseWriter, request) + + assert.Equal(t, test.expectedExecuted, executed) + }) + } +} + +// ===================================================================================================================== +// ========================================= EXAMPLES ================================================================== +// ===================================================================================================================== + +func ExampleEnable() { + port := ":9104" + + enableDummyMiddleware := true // or false + dummyMiddleware := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + request.Header.Set("FakeHeader", "this header is set when not /home url") + next.ServeHTTP(writer, request) + }) + } + stack := httpware.MiddlewareStack( + middleware.Enable(enableDummyMiddleware, dummyMiddleware), + ) + + // create a server in order to show it work + srv := http.NewServeMux() + srv.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { + fmt.Println("server receive request with request:", request.Header.Get("FakeHeader")) + }) + + go func() { + if err := http.ListenAndServe(port, stack.DecorateHandler(srv)); err != nil { + panic(err) + } + }() + + _, _ = http.Get("http://localhost" + port + "/") + + // Output: + //server receive request with request: this header is set when not /home url +} diff --git a/middleware/metrics_test.go b/middleware/metrics_test.go index 21248ad..71ec5cc 100644 --- a/middleware/metrics_test.go +++ b/middleware/metrics_test.go @@ -61,7 +61,7 @@ func TestMetrics(t *testing.T) { // ===================================================================================================================== func ExampleMetrics() { - port := ":5002" + port := ":9101" recorder := prom.NewRecorder(prom.Config{}).RegisterOn(nil) diff --git a/middleware/skip.go b/middleware/skip.go index 17e7222..740b84f 100644 --- a/middleware/skip.go +++ b/middleware/skip.go @@ -7,6 +7,8 @@ import ( "github.com/gol4ng/httpware/v2/skip" ) +// Skip middleware is used to conditionnaly activate a middleware in function of request +// See Enable middleware to conditionnaly add middleware to a stack func Skip(condition skip.Condition, middleware httpware.Middleware) httpware.Middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) { diff --git a/middleware/skip_test.go b/middleware/skip_test.go index 5bf81ff..30fe8c9 100644 --- a/middleware/skip_test.go +++ b/middleware/skip_test.go @@ -43,6 +43,7 @@ func TestSkip(t *testing.T) { } for k, test := range tests { + executed = false t.Run(fmt.Sprintf("test %d (%v)", k, test), func(t *testing.T) { middleware.Skip(func(request *http.Request) bool { return test.conditionResult @@ -58,7 +59,7 @@ func TestSkip(t *testing.T) { // ===================================================================================================================== func ExampleSkip() { - port := ":9902" + port := ":9102" dummyMiddleware := func(next http.Handler) http.Handler { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { diff --git a/tripperware.go b/tripperware.go index 2cb987f..d5a06f1 100644 --- a/tripperware.go +++ b/tripperware.go @@ -4,6 +4,11 @@ import ( "net/http" ) +// NopTripperware just return given http.RoundTripper +func NopTripperware(next http.RoundTripper) http.RoundTripper { + return next +} + // RoundTripFunc wraps a func to make it into an http.RoundTripper. Similar to http.HandleFunc. type RoundTripFunc func(*http.Request) (*http.Response, error) diff --git a/tripperware/correlation_id_test.go b/tripperware/correlation_id_test.go index 55db770..5bcf3f4 100644 --- a/tripperware/correlation_id_test.go +++ b/tripperware/correlation_id_test.go @@ -97,7 +97,7 @@ func TestCorrelationIdCustom(t *testing.T) { // ===================================================================================================================== func ExampleCorrelationId() { - port := ":9901" + port := ":9001" // create http client using the tripperwareStack as RoundTripper client := http.Client{ diff --git a/tripperware/enable.go b/tripperware/enable.go new file mode 100644 index 0000000..3b916f8 --- /dev/null +++ b/tripperware/enable.go @@ -0,0 +1,14 @@ +package tripperware + +import ( + "github.com/gol4ng/httpware/v2" +) + +// Enable tripperware is used to conditionnaly add a tripperware to a TripperwareStack +// See Skip tripperware to active a tripperware in function of request +func Enable(enable bool, tripperware httpware.Tripperware) httpware.Tripperware { + if enable { + return tripperware + } + return httpware.NopTripperware +} diff --git a/tripperware/enable_test.go b/tripperware/enable_test.go new file mode 100644 index 0000000..bfa8c89 --- /dev/null +++ b/tripperware/enable_test.go @@ -0,0 +1,95 @@ +package tripperware_test + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gol4ng/httpware/v2" + "github.com/gol4ng/httpware/v2/mocks" + "github.com/gol4ng/httpware/v2/tripperware" + "github.com/stretchr/testify/assert" +) + +func TestEnable(t *testing.T) { + roundTripperMock := &mocks.RoundTripper{} + req := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil) + resp := &http.Response{ + Status: "OK", + StatusCode: http.StatusOK, + ContentLength: 30, + } + roundTripperMock.On("RoundTrip", req).Return(resp, nil) + + tests := []struct { + enable bool + expectedExecuted bool + }{ + { + enable: true, + expectedExecuted: true, + }, + { + enable: false, + expectedExecuted: false, + }, + } + + executed := false + dummyTripperware := func(next http.RoundTripper) http.RoundTripper { + return httpware.RoundTripFunc(func(request *http.Request) (*http.Response, error) { + executed = true + return next.RoundTrip(request) + }) + } + + for k, test := range tests { + executed = false + t.Run(fmt.Sprintf("test %d (%v)", k, test), func(t *testing.T) { + resp2, err := tripperware.Enable(test.enable, dummyTripperware, )(roundTripperMock).RoundTrip(req) + + assert.Nil(t, err) + assert.Equal(t, resp, resp2) + assert.Equal(t, test.expectedExecuted, executed) + }) + } +} + +// ===================================================================================================================== +// ========================================= EXAMPLES ================================================================== +// ===================================================================================================================== + +func ExampleEnable() { + port := ":9003" + + enableDummyTripperware := true //false + dummyTripperware := func(next http.RoundTripper) http.RoundTripper { + return httpware.RoundTripFunc(func(request *http.Request) (*http.Response, error) { + request.Header.Set("FakeHeader", "this header is set when not /home url") + return next.RoundTrip(request) + }) + } + + // create http client using the tripperwareStack as RoundTripper + client := http.Client{ + Transport: tripperware.Enable(enableDummyTripperware, dummyTripperware), + } + + // create a server in order to show it work + srv := http.NewServeMux() + srv.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { + fmt.Println("server receive request with request:", request.Header.Get("FakeHeader")) + }) + + go func() { + if err := http.ListenAndServe(port, srv); err != nil { + panic(err) + } + }() + + _, _ = client.Get("http://localhost" + port + "/") + + // Output: + //server receive request with request: this header is set when not /home url +} diff --git a/tripperware/skip.go b/tripperware/skip.go index 553674a..4158f20 100644 --- a/tripperware/skip.go +++ b/tripperware/skip.go @@ -7,6 +7,8 @@ import ( "github.com/gol4ng/httpware/v2/skip" ) +// Skip tripperware is used to conditionnaly activate a tripperware in function of request +// See Enable tripperware to conditionnaly add tripperware to a stack func Skip(condition skip.Condition, tripperware httpware.Tripperware) httpware.Tripperware { return func(next http.RoundTripper) http.RoundTripper { return httpware.RoundTripFunc(func(request *http.Request) (*http.Response, error) { diff --git a/tripperware/skip_test.go b/tripperware/skip_test.go index 1778422..5307f6b 100644 --- a/tripperware/skip_test.go +++ b/tripperware/skip_test.go @@ -45,10 +45,11 @@ func TestSkip(t *testing.T) { } for k, test := range tests { + executed = false t.Run(fmt.Sprintf("test %d (%v)", k, test), func(t *testing.T) { resp2, err := tripperware.Skip(func(request *http.Request) bool { return test.conditionResult - }, dummyTripperware, )(roundTripperMock).RoundTrip(req) + }, dummyTripperware)(roundTripperMock).RoundTrip(req) assert.Nil(t, err) assert.Equal(t, resp, resp2) @@ -62,7 +63,7 @@ func TestSkip(t *testing.T) { // ===================================================================================================================== func ExampleSkip() { - port := ":9903" + port := ":9002" dummyTripperware := func(next http.RoundTripper) http.RoundTripper { return httpware.RoundTripFunc(func(request *http.Request) (*http.Response, error) {