Skip to content

Commit

Permalink
Change Append/prepend to pointer receiver and add AppendIf/PrependIf (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
instabledesign authored Jan 28, 2020
1 parent 4c7d3ea commit c7e09db
Show file tree
Hide file tree
Showing 6 changed files with 394 additions and 34 deletions.
24 changes: 22 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,15 @@ The `Append` and `Prepend` functions will convert a tripperware into a `[]Trippe

```go
roundTripper := tripperware.CorrelationId()
roundTripper.Prepend(func(t http.RoundTripper) http.RoundTripper {
stack := roundTripper.Prepend(func(t http.RoundTripper) http.RoundTripper {
return httpware.RoundTripFunc(func(req *http.Request) (*http.Response, error) {
fmt.Println("http request headers :", req.Header)
return t.RoundTrip(req)
})
})

client := &http.Client{
Transport:roundTripper,
Transport:stack,
}

_, _ = client.Get("fake-address.foo")
Expand Down Expand Up @@ -194,3 +194,23 @@ t1 t2 t3 are tripperwares
//B after
//A after
```

## AppendIf PrependIf

For more convinience to build more complex middleware/tripperware stack you can use the AppendIf/PrependIf on Middleware and Middlewares

```go
debug := true
stack := httpware.TripperwareStack(
myTripperware(),
)
stack.AppendIf(debug, myOtherTripperware())
stack.PrependIf(debug, myOtherTripperware2())
```

```go
debug := true
myTripper := myTripperware(),
stack := myTripper.AppendIf(debug, myOtherTripperware())
stack.PrependIf(debug, myOtherTripperware2())
```
48 changes: 44 additions & 4 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,29 @@ func (m Middleware) Append(middlewares ...Middleware) Middlewares {
return append([]Middleware{m}, middlewares...)
}

// AppendIf will add given middlewares after existing one if condition=true
// t1.AppendIf(true, t2, t3) => [t1, t2, t3]
// t1.AppendIf(false, t2, t3) => [t1]
// t1.AppendIf(true, t2, t3).DecorateHandler(<yourHandler>) == t1(t2(t3(<yourHandler>)))
func (m Middleware) AppendIf(condition bool, middlewares ...Middleware) Middlewares {
return (&Middlewares{m}).AppendIf(condition, middlewares...)
}

// Prepend will add given middlewares before existing one
// t1.Prepend(t2, t3) => [t2, t3, t1]
// t1.Prepend(t2, t3).DecorateHandler(<yourHandler>) == t2(t3(t1(<yourHandler>)))
func (m Middleware) Prepend(middlewares ...Middleware) Middlewares {
return append(middlewares, m)
}

// PrependIf will add given middlewares before existing one if condition=true
// t1.PrependIf(true, t2, t3) => [t2, t3, t1]
// t1.PrependIf(false, t2, t3) => [t1]
// t1.PrependIf(true, t2, t3).DecorateHandler(<yourHandler>) == t2(t3(t1(<yourHandler>)))
func (m Middleware) PrependIf(condition bool, middlewares ...Middleware) Middlewares {
return (&Middlewares{m}).PrependIf(condition, middlewares...)
}

// [t1, t2, t3].DecorateHandler(<yourHandler>) == t1(t2(t3(<yourHandler>)))
type Middlewares []Middleware

Expand All @@ -40,15 +56,39 @@ func (m Middlewares) DecorateHandlerFunc(handler http.HandlerFunc) http.Handler
// Append will add given middleware after existing one
// [t1, t2].Append(t3, t4) => [t1, t2, t3, t4]
// [t1, t2].Append(t3, t4).DecorateHandler(<yourHandler>) == t1(t2(t3(t4(<yourHandler>))))
func (m Middlewares) Append(middleware ...Middleware) Middlewares {
return append(m, middleware...)
func (m *Middlewares) Append(middleware ...Middleware) Middlewares {
*m = append(*m, middleware...)
return *m
}

// AppendIf will add given middleware after existing one if condition=true
// [t1, t2].AppendIf(true, t3, t4) => [t1, t2, t3, t4]
// [t1, t2].AppendIf(false, t3, t4) => [t1, t2]
// [t1, t2].AppendIf(t3, t4).DecorateHandler(<yourHandler>) == t1(t2(t3(t4(<yourHandler>))))
func (m *Middlewares) AppendIf(condition bool, middleware ...Middleware) Middlewares {
if condition {
*m = append(*m, middleware...)
}
return *m
}

// Prepend will add given middleware before existing one
// [t1, t2].Prepend(t3, t4) => [t3, t4, t1, t2]
// [t1, t2].Prepend(t3, t4).DecorateHandler(<yourHandler>) == t3(t4(t1(t2(<yourHandler>))))
func (m Middlewares) Prepend(middleware ...Middleware) Middlewares {
return append(middleware, m...)
func (m *Middlewares) Prepend(middleware ...Middleware) Middlewares {
*m = append(middleware, *m...)
return *m
}

// PrependIf will add given middleware before existing one if condition=true
// [t1, t2].PrependIf(true, t3, t4) => [t3, t4, t1, t2]
// [t1, t2].PrependIf(false, t3, t4) => [t1, t2]
// [t1, t2].PrependIf(true, t3, t4).DecorateHandler(<yourHandler>) == t3(t4(t1(t2(<yourHandler>))))
func (m *Middlewares) PrependIf(condition bool, middleware ...Middleware) Middlewares {
if condition {
*m = append(middleware, *m...)
}
return *m
}

// MiddlewareStack allows you to stack multiple middleware collection in a specific order
Expand Down
6 changes: 3 additions & 3 deletions middleware/skip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func ExampleSkip() {
// create a server in order to show it work
srv := http.NewServeMux()
srv.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {
fmt.Println("server receive request with request:", request.Header.Get("FakeHeader"))
fmt.Printf("server receive request %s with request: %s\n", request.URL.Path, request.Header.Get("FakeHeader"))
})

go func() {
Expand All @@ -88,6 +88,6 @@ func ExampleSkip() {
_, _ = http.Get("http://localhost" + port + "/home")

// Output:
//server receive request with request: this header is set when not /home url
//server receive request with request:
//server receive request / with request: this header is set when not /home url
//server receive request /home with request:
}
155 changes: 145 additions & 10 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ func getMiddleware(t *testing.T, i *int, iBefore int, iAfter int) httpware.Middl
})
}

func getMiddlewareShouldNotBeCalled(t *testing.T) httpware.Middleware {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Fail(t, "this middleware should not be called")
h.ServeHTTP(w, r)
})
}
}

func TestMiddleware_Append(t *testing.T) {
req := &http.Request{}
responseWriterMock := &httptest.ResponseRecorder{}
Expand All @@ -44,11 +53,41 @@ func TestMiddleware_Append(t *testing.T) {

middleware := getMiddleware(t, i, 0, 6)

middleware.Append(
stack := middleware.Append(
// the middleware will be add here
getMiddleware(t, i, 1, 5),
getMiddleware(t, i, 2, 4),
).DecorateHandler(handler).ServeHTTP(responseWriterMock, req)
)

stack.DecorateHandler(handler).ServeHTTP(responseWriterMock, req)
}

func TestMiddleware_AppendIf(t *testing.T) {
req := &http.Request{}
responseWriterMock := &httptest.ResponseRecorder{}
responseBody := "fake response"

i := new(int)
*i = 0
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, 1, *i)
*i++
assert.IsType(t, responseWriterMock, w)
assert.Equal(t, req, r)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(responseBody))
})

middleware := getMiddleware(t, i, 0, 2)

stack := middleware.AppendIf(
false,
// the middleware will be add here if condition=true
getMiddlewareShouldNotBeCalled(t),
getMiddlewareShouldNotBeCalled(t),
)

stack.DecorateHandler(handler).ServeHTTP(responseWriterMock, req)
}

func TestMiddleware_Prepend(t *testing.T) {
Expand All @@ -69,11 +108,41 @@ func TestMiddleware_Prepend(t *testing.T) {

middleware := getMiddleware(t, i, 2, 4)

middleware.Prepend(
stack := middleware.Prepend(
getMiddleware(t, i, 0, 6),
getMiddleware(t, i, 1, 5),
// the middleware will be add here
).DecorateHandler(handler).ServeHTTP(responseWriterMock, req)
)

stack.DecorateHandler(handler).ServeHTTP(responseWriterMock, req)
}

func TestMiddleware_PrependIf(t *testing.T) {
req := &http.Request{}
responseWriterMock := &httptest.ResponseRecorder{}
responseBody := "fake response"

i := new(int)
*i = 0
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, 1, *i)
*i++
assert.IsType(t, responseWriterMock, w)
assert.Equal(t, req, r)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(responseBody))
})

middleware := getMiddleware(t, i, 0, 2)

stack := middleware.PrependIf(
false,
getMiddlewareShouldNotBeCalled(t),
getMiddlewareShouldNotBeCalled(t),
// the middleware will be add here if condition=true
)

stack.DecorateHandler(handler).ServeHTTP(responseWriterMock, req)
}

func TestMiddlewares_DecorateHandler(t *testing.T) {
Expand Down Expand Up @@ -142,16 +211,49 @@ func TestMiddlewares_Append(t *testing.T) {
_, _ = w.Write([]byte(responseBody))
})

middlewares := httpware.MiddlewareStack(
stack := httpware.MiddlewareStack(
getMiddleware(t, i, 0, 8),
getMiddleware(t, i, 1, 7),
)

middlewares.Append(
stack.Append(
// the middlewares will be add here
getMiddleware(t, i, 2, 6),
getMiddleware(t, i, 3, 5),
).DecorateHandler(handler).ServeHTTP(responseWriterMock, req)
)

stack.DecorateHandler(handler).ServeHTTP(responseWriterMock, req)
}

func TestMiddlewares_AppendIf(t *testing.T) {
req := &http.Request{}
responseWriterMock := &httptest.ResponseRecorder{}
responseBody := "fake response"

i := new(int)
*i = 0
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, 2, *i)
*i++
assert.IsType(t, responseWriterMock, w)
assert.Equal(t, req, r)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(responseBody))
})

stack := httpware.MiddlewareStack(
getMiddleware(t, i, 0, 4),
getMiddleware(t, i, 1, 3),
)

stack.AppendIf(
false,
// the middlewares will be add here if condition=true
getMiddlewareShouldNotBeCalled(t),
getMiddlewareShouldNotBeCalled(t),
)

stack.DecorateHandler(handler).ServeHTTP(responseWriterMock, req)
}

func TestMiddlewares_Prepend(t *testing.T) {
Expand All @@ -170,16 +272,49 @@ func TestMiddlewares_Prepend(t *testing.T) {
_, _ = w.Write([]byte(responseBody))
})

middlewares := httpware.MiddlewareStack(
stack := httpware.MiddlewareStack(
getMiddleware(t, i, 2, 6),
getMiddleware(t, i, 3, 5),
)

middlewares.Prepend(
stack.Prepend(
getMiddleware(t, i, 0, 8),
getMiddleware(t, i, 1, 7),
// the middlewares will be add here
).DecorateHandler(handler).ServeHTTP(responseWriterMock, req)
)

stack.DecorateHandler(handler).ServeHTTP(responseWriterMock, req)
}

func TestMiddlewares_PrependIf(t *testing.T) {
req := &http.Request{}
responseWriterMock := &httptest.ResponseRecorder{}
responseBody := "fake response"

i := new(int)
*i = 0
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, 2, *i)
*i++
assert.IsType(t, responseWriterMock, w)
assert.Equal(t, req, r)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(responseBody))
})

stack := httpware.MiddlewareStack(
getMiddleware(t, i, 0, 4),
getMiddleware(t, i, 1, 3),
)

stack.PrependIf(
false,
getMiddlewareShouldNotBeCalled(t),
getMiddlewareShouldNotBeCalled(t),
// the middlewares will be add here if condition=true
)

stack.DecorateHandler(handler).ServeHTTP(responseWriterMock, req)
}

// =====================================================================================================================
Expand Down
Loading

0 comments on commit c7e09db

Please sign in to comment.