Skip to content

Commit

Permalink
implement enable middleware tripperware (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
instabledesign authored Jan 31, 2020
1 parent c7e09db commit 7947cd4
Show file tree
Hide file tree
Showing 14 changed files with 260 additions and 6 deletions.
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Package httpware is a collection of middleware (net/http.Handler wrapper) and tr
|**Metrics**|X|X|
|**Interceptor**|X|X|
|**Skip**|X|X|
|**Enable**|X|X|

## Installation

Expand Down Expand Up @@ -195,6 +196,31 @@ t1 t2 t3 are tripperwares
//A after
```

## Enable/skip middleware or tripperware

Some times you need to have a dynamic middleware|tripperware stack

For example you need to have a middleware activated on debug mode
If Enable false, the middleware will not be added to the middleware stack

```go
debug := true
stack := httpware.MiddlewareStack(
middleware.Enable(debug, middleware.CorrelationId()),
)
```

You can dynamically skip a middleware with your own rule.
If the callback return true it will skip the execution of targeted middleware

```go
stack := httpware.MiddlewareStack(
tripperware.Skip(func(request *http.Request) bool {
return request.URL.Path == "/home"
}, middleware.CorrelationId()),
)
```

## AppendIf PrependIf

For more convinience to build more complex middleware/tripperware stack you can use the AppendIf/PrependIf on Middleware and Middlewares
Expand Down
5 changes: 5 additions & 0 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ package httpware

import "net/http"

// NopMiddleware just return given http.Handler
func NopMiddleware(next http.Handler) http.Handler {
return next
}

// Middleware represents an http server middleware
// it wraps an http.Handler with another one
type Middleware func(http.Handler) http.Handler
Expand Down
2 changes: 1 addition & 1 deletion middleware/correlation_id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestCorrelationId(t *testing.T) {
// =====================================================================================================================

func ExampleCorrelationId() {
port := ":5001"
port := ":9103"
// we recommend to use MiddlewareStack to simplify managing all wanted middlewares
// caution middleware order matters
stack := httpware.MiddlewareStack(
Expand Down
14 changes: 14 additions & 0 deletions middleware/enable.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package middleware

import (
"github.com/gol4ng/httpware/v2"
)

// Enable middleware is used to conditionnaly add a middleware to a MiddlewareStack
// See Skip middleware to active a middleware in function of request
func Enable(enable bool, middleware httpware.Middleware) httpware.Middleware {
if enable {
return middleware
}
return httpware.NopMiddleware
}
89 changes: 89 additions & 0 deletions middleware/enable_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package middleware_test

import (
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/gol4ng/httpware/v2"
"github.com/gol4ng/httpware/v2/middleware"
"github.com/stretchr/testify/assert"
)

func TestEnable(t *testing.T) {
tests := []struct {
enable bool
expectedExecuted bool
}{
{
enable: true,
expectedExecuted: true,
},
{
enable: false,
expectedExecuted: false,
},
}

request := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil)
responseWriter := &httptest.ResponseRecorder{}

handler := http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {
assert.Equal(t, request, req)
writer.WriteHeader(http.StatusOK)
})

executed := false
dummyMiddleware := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
executed = true
next.ServeHTTP(writer, request)
})
}

for k, test := range tests {
executed = false
t.Run(fmt.Sprintf("test %d (%v)", k, test), func(t *testing.T) {
middleware.Enable(test.enable, dummyMiddleware)(handler).ServeHTTP(responseWriter, request)

assert.Equal(t, test.expectedExecuted, executed)
})
}
}

// =====================================================================================================================
// ========================================= EXAMPLES ==================================================================
// =====================================================================================================================

func ExampleEnable() {
port := ":9104"

enableDummyMiddleware := true // or false
dummyMiddleware := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
request.Header.Set("FakeHeader", "this header is set when not /home url")
next.ServeHTTP(writer, request)
})
}
stack := httpware.MiddlewareStack(
middleware.Enable(enableDummyMiddleware, dummyMiddleware),
)

// create a server in order to show it work
srv := http.NewServeMux()
srv.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {
fmt.Println("server receive request with request:", request.Header.Get("FakeHeader"))
})

go func() {
if err := http.ListenAndServe(port, stack.DecorateHandler(srv)); err != nil {
panic(err)
}
}()

_, _ = http.Get("http://localhost" + port + "/")

// Output:
//server receive request with request: this header is set when not /home url
}
2 changes: 1 addition & 1 deletion middleware/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func TestMetrics(t *testing.T) {
// =====================================================================================================================

func ExampleMetrics() {
port := ":5002"
port := ":9101"

recorder := prom.NewRecorder(prom.Config{}).RegisterOn(nil)

Expand Down
2 changes: 2 additions & 0 deletions middleware/skip.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"github.com/gol4ng/httpware/v2/skip"
)

// Skip middleware is used to conditionnaly activate a middleware in function of request
// See Enable middleware to conditionnaly add middleware to a stack
func Skip(condition skip.Condition, middleware httpware.Middleware) httpware.Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {
Expand Down
3 changes: 2 additions & 1 deletion middleware/skip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func TestSkip(t *testing.T) {
}

for k, test := range tests {
executed = false
t.Run(fmt.Sprintf("test %d (%v)", k, test), func(t *testing.T) {
middleware.Skip(func(request *http.Request) bool {
return test.conditionResult
Expand All @@ -58,7 +59,7 @@ func TestSkip(t *testing.T) {
// =====================================================================================================================

func ExampleSkip() {
port := ":9902"
port := ":9102"

dummyMiddleware := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
Expand Down
5 changes: 5 additions & 0 deletions tripperware.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ import (
"net/http"
)

// NopTripperware just return given http.RoundTripper
func NopTripperware(next http.RoundTripper) http.RoundTripper {
return next
}

// RoundTripFunc wraps a func to make it into an http.RoundTripper. Similar to http.HandleFunc.
type RoundTripFunc func(*http.Request) (*http.Response, error)

Expand Down
2 changes: 1 addition & 1 deletion tripperware/correlation_id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestCorrelationIdCustom(t *testing.T) {
// =====================================================================================================================

func ExampleCorrelationId() {
port := ":9901"
port := ":9001"

// create http client using the tripperwareStack as RoundTripper
client := http.Client{
Expand Down
14 changes: 14 additions & 0 deletions tripperware/enable.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package tripperware

import (
"github.com/gol4ng/httpware/v2"
)

// Enable tripperware is used to conditionnaly add a tripperware to a TripperwareStack
// See Skip tripperware to active a tripperware in function of request
func Enable(enable bool, tripperware httpware.Tripperware) httpware.Tripperware {
if enable {
return tripperware
}
return httpware.NopTripperware
}
95 changes: 95 additions & 0 deletions tripperware/enable_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package tripperware_test

import (
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/gol4ng/httpware/v2"
"github.com/gol4ng/httpware/v2/mocks"
"github.com/gol4ng/httpware/v2/tripperware"
"github.com/stretchr/testify/assert"
)

func TestEnable(t *testing.T) {
roundTripperMock := &mocks.RoundTripper{}
req := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil)
resp := &http.Response{
Status: "OK",
StatusCode: http.StatusOK,
ContentLength: 30,
}
roundTripperMock.On("RoundTrip", req).Return(resp, nil)

tests := []struct {
enable bool
expectedExecuted bool
}{
{
enable: true,
expectedExecuted: true,
},
{
enable: false,
expectedExecuted: false,
},
}

executed := false
dummyTripperware := func(next http.RoundTripper) http.RoundTripper {
return httpware.RoundTripFunc(func(request *http.Request) (*http.Response, error) {
executed = true
return next.RoundTrip(request)
})
}

for k, test := range tests {
executed = false
t.Run(fmt.Sprintf("test %d (%v)", k, test), func(t *testing.T) {
resp2, err := tripperware.Enable(test.enable, dummyTripperware, )(roundTripperMock).RoundTrip(req)

assert.Nil(t, err)
assert.Equal(t, resp, resp2)
assert.Equal(t, test.expectedExecuted, executed)
})
}
}

// =====================================================================================================================
// ========================================= EXAMPLES ==================================================================
// =====================================================================================================================

func ExampleEnable() {
port := ":9003"

enableDummyTripperware := true //false
dummyTripperware := func(next http.RoundTripper) http.RoundTripper {
return httpware.RoundTripFunc(func(request *http.Request) (*http.Response, error) {
request.Header.Set("FakeHeader", "this header is set when not /home url")
return next.RoundTrip(request)
})
}

// create http client using the tripperwareStack as RoundTripper
client := http.Client{
Transport: tripperware.Enable(enableDummyTripperware, dummyTripperware),
}

// create a server in order to show it work
srv := http.NewServeMux()
srv.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {
fmt.Println("server receive request with request:", request.Header.Get("FakeHeader"))
})

go func() {
if err := http.ListenAndServe(port, srv); err != nil {
panic(err)
}
}()

_, _ = client.Get("http://localhost" + port + "/")

// Output:
//server receive request with request: this header is set when not /home url
}
2 changes: 2 additions & 0 deletions tripperware/skip.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"github.com/gol4ng/httpware/v2/skip"
)

// Skip tripperware is used to conditionnaly activate a tripperware in function of request
// See Enable tripperware to conditionnaly add tripperware to a stack
func Skip(condition skip.Condition, tripperware httpware.Tripperware) httpware.Tripperware {
return func(next http.RoundTripper) http.RoundTripper {
return httpware.RoundTripFunc(func(request *http.Request) (*http.Response, error) {
Expand Down
5 changes: 3 additions & 2 deletions tripperware/skip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ func TestSkip(t *testing.T) {
}

for k, test := range tests {
executed = false
t.Run(fmt.Sprintf("test %d (%v)", k, test), func(t *testing.T) {
resp2, err := tripperware.Skip(func(request *http.Request) bool {
return test.conditionResult
}, dummyTripperware, )(roundTripperMock).RoundTrip(req)
}, dummyTripperware)(roundTripperMock).RoundTrip(req)

assert.Nil(t, err)
assert.Equal(t, resp, resp2)
Expand All @@ -62,7 +63,7 @@ func TestSkip(t *testing.T) {
// =====================================================================================================================

func ExampleSkip() {
port := ":9903"
port := ":9002"

dummyTripperware := func(next http.RoundTripper) http.RoundTripper {
return httpware.RoundTripFunc(func(request *http.Request) (*http.Response, error) {
Expand Down

0 comments on commit 7947cd4

Please sign in to comment.