From 19dee9e315d99bb2cb4013e80f8c303dcb76d276 Mon Sep 17 00:00:00 2001 From: Quentin Neyrat Date: Fri, 12 Jun 2020 12:07:03 +0200 Subject: [PATCH] Add auth middleware/tripperware (#32) Co-authored-by: Anthony Moutte --- auth/authenticator.go | 11 ++ auth/context.go | 27 ++++ auth/context_test.go | 35 +++++ auth/credential.go | 7 + auth/http.go | 41 ++++++ auth/http_test.go | 57 ++++++++ middleware/auth.go | 107 ++++++++++++++ middleware/auth_test.go | 291 +++++++++++++++++++++++++++++++++++++++ mocks/Authenticator.go | 36 +++++ tripperware/auth.go | 52 +++++++ tripperware/auth_test.go | 62 +++++++++ tripperware_test.go | 13 +- 12 files changed, 730 insertions(+), 9 deletions(-) create mode 100644 auth/authenticator.go create mode 100644 auth/context.go create mode 100644 auth/context_test.go create mode 100644 auth/credential.go create mode 100644 auth/http.go create mode 100644 auth/http_test.go create mode 100644 middleware/auth.go create mode 100644 middleware/auth_test.go create mode 100644 mocks/Authenticator.go create mode 100644 tripperware/auth.go create mode 100644 tripperware/auth_test.go diff --git a/auth/authenticator.go b/auth/authenticator.go new file mode 100644 index 0000000..6bc2555 --- /dev/null +++ b/auth/authenticator.go @@ -0,0 +1,11 @@ +package auth + +type Authenticator interface { + Authenticate(Credential) (Credential, error) +} + +type AuthenticatorFunc func(Credential) (Credential, error) + +func (a AuthenticatorFunc) Authenticate(credential Credential) (Credential, error) { + return a(credential) +} diff --git a/auth/context.go b/auth/context.go new file mode 100644 index 0000000..e2a278d --- /dev/null +++ b/auth/context.go @@ -0,0 +1,27 @@ +package auth + +import ( + "context" +) + +var credentialContextKey struct{} + +func CredentialToContext(ctx context.Context, credential Credential) context.Context { + return context.WithValue(ctx, credentialContextKey, credential) +} + +func CredentialFromContext(ctx context.Context) Credential { + if ctx == nil { + return nil + } + value := ctx.Value(credentialContextKey) + if value == nil { + return nil + } + credential, ok := value.(Credential) + if !ok { + return nil + } + + return credential +} diff --git a/auth/context_test.go b/auth/context_test.go new file mode 100644 index 0000000..d84fa28 --- /dev/null +++ b/auth/context_test.go @@ -0,0 +1,35 @@ +package auth + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_Credential_Context(t *testing.T) { + tests := []struct { + context context.Context + expectedCredential Credential + }{ + { + context: nil, + expectedCredential: nil, + }, + { + context: context.Background(), + expectedCredential: nil, + }, + { + context: CredentialToContext(context.Background(), Credential("my_value")), + expectedCredential: "my_value", + }, + } + + for i, tt := range tests { + t.Run(fmt.Sprint(i), func(t *testing.T) { + assert.Equal(t, tt.expectedCredential, CredentialFromContext(tt.context)) + }) + } +} diff --git a/auth/credential.go b/auth/credential.go new file mode 100644 index 0000000..ffddbd9 --- /dev/null +++ b/auth/credential.go @@ -0,0 +1,7 @@ +package auth + +type Credential interface{} + +type CredentialProvider func() Credential + +type CredentialSetter func(Credential) diff --git a/auth/http.go b/auth/http.go new file mode 100644 index 0000000..b3319c8 --- /dev/null +++ b/auth/http.go @@ -0,0 +1,41 @@ +package auth + +import ( + "net/http" +) + +const ( + AuthorizationHeader = "Authorization" + XAuthorizationHeader = "X-Authorization" +) + +func FromHeader(request *http.Request) CredentialProvider { + return func() Credential { + return ExtractFromHeader(request) + } +} + +func ExtractFromHeader(request *http.Request) Credential { + if request == nil { + return "" + } + + tokenHeader := request.Header.Get(AuthorizationHeader) + if tokenHeader == "" { + tokenHeader = request.Header.Get(XAuthorizationHeader) + } + + return tokenHeader +} + +func AddHeader(request *http.Request) CredentialSetter { + return func(credential Credential) { + if request == nil { + return + } + if creds, ok := credential.(string); ok { + request.Header.Set(AuthorizationHeader, creds) + request.Header.Set(XAuthorizationHeader, creds) + } + } +} diff --git a/auth/http_test.go b/auth/http_test.go new file mode 100644 index 0000000..100e922 --- /dev/null +++ b/auth/http_test.go @@ -0,0 +1,57 @@ +package auth_test + +import ( + "fmt" + "net/http" + "testing" + + "github.com/gol4ng/httpware/v2/auth" + "github.com/stretchr/testify/assert" +) + +func TestFromHeader(t *testing.T) { + tests := []struct { + request *http.Request + expectedCredential string + }{ + { + request: nil, + expectedCredential: "", + }, + { + request: &http.Request{Header: http.Header{ + "Authorization": []string{"foo"}, + },}, + expectedCredential: "foo", + }, + { + request: &http.Request{Header: http.Header{ + "X-Authorization": []string{"foo"}, + },}, + expectedCredential: "foo", + }, + { + request: &http.Request{Header: http.Header{ + "Authorization": []string{"foo"}, + "X-Authorization": []string{"bar"}, + },}, + expectedCredential: "foo", + }, + } + + for i, tt := range tests { + t.Run(fmt.Sprint(i), func(t *testing.T) { + assert.Equal(t, auth.Credential(tt.expectedCredential), auth.FromHeader(tt.request)()) + }) + } +} + +func TestAddHeader(t *testing.T) { + req := &http.Request{ + Header: make(http.Header), + } + + credSetter := auth.AddHeader(req) + credSetter("foo") + assert.Equal(t, "foo", req.Header.Get("Authorization")) +} diff --git a/middleware/auth.go b/middleware/auth.go new file mode 100644 index 0000000..06021bc --- /dev/null +++ b/middleware/auth.go @@ -0,0 +1,107 @@ +package middleware + +import ( + "context" + "net/http" + + "github.com/gol4ng/httpware/v2" + "github.com/gol4ng/httpware/v2/auth" +) + +// Authentication middleware delegate the authentication process to the Authenticator +func Authentication(authenticator auth.Authenticator, options ...AuthOption) httpware.Middleware { + config := NewAuthConfig(options...) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) { + newCtx, err := config.authenticateFunc(config.credentialFinder, authenticator, req) + if err == nil { + config.successMiddleware(next).ServeHTTP(writer, req.WithContext(newCtx)) + return + } else if config.errorHandler(err, writer, req) { + return + } + + next.ServeHTTP(writer, req.WithContext(newCtx)) + }) + } +} + +type CredentialFinder func(r *http.Request) auth.Credential +type AuthenticateFunc func(credentialFinder CredentialFinder, authenticator auth.Authenticator, req *http.Request) (context.Context, error) +type ErrorHandler func(err error, writer http.ResponseWriter, req *http.Request) bool + +// AuthOption defines a interceptor middleware configuration option +type AuthOption func(*AuthConfig) + +type AuthConfig struct { + credentialFinder CredentialFinder + authenticateFunc AuthenticateFunc + errorHandler ErrorHandler + successMiddleware httpware.Middleware +} + +func (o *AuthConfig) apply(options ...AuthOption) { + for _, option := range options { + option(o) + } +} + +func NewAuthConfig(options ...AuthOption) *AuthConfig { + opts := &AuthConfig{ + credentialFinder: DefaultCredentialFinder, + authenticateFunc: DefaultAuthFunc, + errorHandler: DefaultErrorHandler, + successMiddleware: httpware.NopMiddleware, + } + opts.apply(options...) + return opts +} + +func DefaultCredentialFinder(request *http.Request) auth.Credential { + return auth.FromHeader(request)() +} + +func DefaultAuthFunc(credentialFinder CredentialFinder, authenticator auth.Authenticator, request *http.Request) (context.Context, error) { + credential := credentialFinder(request) + if authenticator != nil { + creds, err := authenticator.Authenticate(credential) + if err != nil { + return request.Context(), err + } + credential = creds + } + return auth.CredentialToContext(request.Context(), credential), nil +} + +func DefaultErrorHandler(err error, writer http.ResponseWriter, _ *http.Request) bool { + http.Error(writer, err.Error(), http.StatusUnauthorized) + return true +} + +// WithCredentialFinder will configure AuthenticateFunc option +func WithCredentialFinder(credentialFinder CredentialFinder) AuthOption { + return func(config *AuthConfig) { + config.credentialFinder = credentialFinder + } +} + +// WithAuthenticateFunc will configure AuthenticateFunc option +func WithAuthenticateFunc(authenticateFunc AuthenticateFunc) AuthOption { + return func(config *AuthConfig) { + config.authenticateFunc = authenticateFunc + } +} + +// WithErrorHandler will configure ErrorHandler option +func WithErrorHandler(errorHandler ErrorHandler) AuthOption { + return func(config *AuthConfig) { + config.errorHandler = errorHandler + } +} + +// WithSuccessMiddleware will configure successMiddleware option +func WithSuccessMiddleware(middleware httpware.Middleware) AuthOption { + return func(config *AuthConfig) { + config.successMiddleware = middleware + } +} diff --git a/middleware/auth_test.go b/middleware/auth_test.go new file mode 100644 index 0000000..a5595db --- /dev/null +++ b/middleware/auth_test.go @@ -0,0 +1,291 @@ +package middleware_test + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gol4ng/httpware/v2/auth" + "github.com/gol4ng/httpware/v2/middleware" + "github.com/gol4ng/httpware/v2/mocks" + "github.com/stretchr/testify/assert" +) + +func credentialFinderMock(_ *http.Request) auth.Credential { + return "my_credential" +} + +func TestDefaultCredentialFinder(t *testing.T) { + tests := []struct { + authorizationHeader string + xAuthorizationHeader string + expectedCredential string + }{ + { + authorizationHeader: "", + xAuthorizationHeader: "", + expectedCredential: "", + }, + { + authorizationHeader: "Foo", + xAuthorizationHeader: "", + expectedCredential: "Foo", + }, + { + authorizationHeader: "", + xAuthorizationHeader: "Foo", + expectedCredential: "Foo", + }, + { + authorizationHeader: "Foo", + xAuthorizationHeader: "Bar", + expectedCredential: "Foo", + }, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("%s%s", tt.authorizationHeader, tt.xAuthorizationHeader), func(t *testing.T) { + request := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil) + request.Header.Set(auth.AuthorizationHeader, tt.authorizationHeader) + request.Header.Set(auth.XAuthorizationHeader, tt.xAuthorizationHeader) + + assert.Equal(t, auth.Credential(tt.expectedCredential), middleware.DefaultCredentialFinder(request)) + }) + } +} + +func TestDefaultAuthFunc(t *testing.T) { + request := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil) + + nexCtx, err := middleware.DefaultAuthFunc(credentialFinderMock, nil, request) + assert.NoError(t, err) + assert.Equal(t, auth.Credential("my_credential"), auth.CredentialFromContext(nexCtx)) +} + +func TestDefaultAuthFunc_WithAuthenticator(t *testing.T) { + request := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil) + + authenticator := &mocks.Authenticator{} + authenticator.On("Authenticate", "my_credential").Return("my_authenticate_credential", nil) + + nexCtx, err := middleware.DefaultAuthFunc(credentialFinderMock, authenticator, request) + assert.NoError(t, err) + assert.Equal(t, auth.Credential("my_authenticate_credential"), auth.CredentialFromContext(nexCtx)) + authenticator.AssertExpectations(t) +} + +func TestDefaultAuthFunc_WithAuthenticator_Error(t *testing.T) { + request := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil) + + err := errors.New("my_authenticate_error") + authenticator := &mocks.Authenticator{} + authenticator.On("Authenticate", "my_credential").Return("my_authenticate_credential", err) + + nexCtx, err := middleware.DefaultAuthFunc(credentialFinderMock, authenticator, request) + assert.EqualError(t, err, "my_authenticate_error") + assert.Equal(t, nil, auth.CredentialFromContext(nexCtx)) + authenticator.AssertExpectations(t) +} + +func TestDefaultErrorHandler(t *testing.T) { + request, _ := http.NewRequest("", "", nil) + response := httptest.NewRecorder() + + middleware.DefaultErrorHandler(errors.New("my_fake_error"), response, request) + + assert.Equal(t, http.StatusUnauthorized, response.Code) + assert.Equal(t, "my_fake_error\n", response.Body.String()) +} + +func TestAuthentication(t *testing.T) { + var innerContext context.Context + request, _ := http.NewRequest(http.MethodGet, "http://fake-addr", nil) + + handlerCalled := false + handler := http.HandlerFunc(func(_ http.ResponseWriter, innerRequest *http.Request) { + handlerCalled = true + innerContext = innerRequest.Context() + }) + + authMiddleware := middleware.Authentication(auth.AuthenticatorFunc(func(credential auth.Credential) (auth.Credential, error) { + return "my_allowed_credential", nil + })) + + authMiddleware(handler).ServeHTTP(nil, request) + assert.True(t, handlerCalled) + assert.NotEqual(t, request.Context(), innerContext) + assert.Equal(t, "my_allowed_credential", auth.CredentialFromContext(innerContext)) +} + +func TestAuthenticationWithAuthenticateFunc(t *testing.T) { + var innerContext context.Context + request, _ := http.NewRequest(http.MethodGet, "http://fake-addr", nil) + + handlerCalled := false + handler := http.HandlerFunc(func(_ http.ResponseWriter, innerRequest *http.Request) { + handlerCalled = true + innerContext = innerRequest.Context() + }) + + authMiddleware := middleware.Authentication(nil, middleware.WithAuthenticateFunc(func(_ middleware.CredentialFinder, _ auth.Authenticator, req *http.Request) (context.Context, error) { + return req.Context(), nil + })) + + authMiddleware(handler).ServeHTTP(nil, request) + assert.True(t, handlerCalled) + assert.Equal(t, request.Context(), innerContext) + assert.Equal(t, nil, auth.CredentialFromContext(innerContext)) +} + +func TestAuthenticationWithCredentialFinder(t *testing.T) { + var innerContext context.Context + request, _ := http.NewRequest(http.MethodGet, "http://fake-addr", nil) + + handlerCalled := false + handler := http.HandlerFunc(func(_ http.ResponseWriter, innerRequest *http.Request) { + handlerCalled = true + innerContext = innerRequest.Context() + }) + + authMiddleware := middleware.Authentication( + nil, + middleware.WithCredentialFinder(func(_ *http.Request) auth.Credential { + return "my_custom_credential" + }), + ) + + authMiddleware(handler).ServeHTTP(nil, request) + assert.True(t, handlerCalled) + assert.NotEqual(t, request.Context(), innerContext) + assert.Equal(t, "my_custom_credential", auth.CredentialFromContext(innerContext)) +} + +func TestAuthenticationWithSuccessMiddleware(t *testing.T) { + var innerContext context.Context + request, _ := http.NewRequest(http.MethodGet, "http://fake-addr", nil) + + handlerCalled := false + handler := http.HandlerFunc(func(_ http.ResponseWriter, innerRequest *http.Request) { + handlerCalled = true + innerContext = innerRequest.Context() + }) + + authMiddleware := middleware.Authentication( + nil, + middleware.WithAuthenticateFunc(func(_ middleware.CredentialFinder, _ auth.Authenticator, req *http.Request) (context.Context, error) { + return req.Context(), nil + }), + middleware.WithSuccessMiddleware(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) { + assert.Nil(t, writer) + assert.Equal(t, request, req) + // we not call next handler for example + }) + }), + ) + + authMiddleware(handler).ServeHTTP(nil, request) + assert.False(t, handlerCalled) + assert.NotEqual(t, request.Context(), innerContext) + assert.Equal(t, nil, auth.CredentialFromContext(innerContext)) +} + +func TestAuthentication_Error(t *testing.T) { + var innerErr error + request, _ := http.NewRequest(http.MethodGet, "http://fake-addr", nil) + + handlerCalled := false + handler := http.HandlerFunc(func(_ http.ResponseWriter, innerRequest *http.Request) { + handlerCalled = true + }) + + authMiddleware := middleware.Authentication( + nil, + middleware.WithAuthenticateFunc(func(_ middleware.CredentialFinder, _ auth.Authenticator, req *http.Request) (context.Context, error) { + return req.Context(), errors.New("my_authenticate_error") + }), + middleware.WithErrorHandler(func(err error, _ http.ResponseWriter, _ *http.Request) bool { + innerErr = err + return true + }), + ) + + authMiddleware(handler).ServeHTTP(nil, request) + assert.False(t, handlerCalled) + assert.EqualError(t, innerErr, "my_authenticate_error") +} + +func TestAuthentication_hydrate_header(t *testing.T) { + tests := []struct { + authorizationHeader string + xAuthorizationHeader string + expectedCredential string + }{ + { + authorizationHeader: "", + xAuthorizationHeader: "", + expectedCredential: "", + }, + { + authorizationHeader: "Foo", + xAuthorizationHeader: "", + expectedCredential: "Foo", + }, + { + authorizationHeader: "", + xAuthorizationHeader: "Foo", + expectedCredential: "Foo", + }, + { + authorizationHeader: "Foo", + xAuthorizationHeader: "Bar", + expectedCredential: "Foo", + }, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("%s%s", tt.authorizationHeader, tt.xAuthorizationHeader), func(t *testing.T) { + var innerContext context.Context + request := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil) + request.Header.Set(auth.AuthorizationHeader, tt.authorizationHeader) + request.Header.Set(auth.XAuthorizationHeader, tt.xAuthorizationHeader) + + handlerCalled := false + handler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + handlerCalled = true + innerContext = r.Context() + }) + + authMiddleware := middleware.Authentication(nil) + + authMiddleware(handler).ServeHTTP(nil, request) + + assert.True(t, handlerCalled) + assert.Equal(t, auth.Credential(tt.expectedCredential), auth.CredentialFromContext(innerContext)) + }) + } +} + +func TestAuthentication_Unauthorize(t *testing.T) { + var innerContext context.Context + request := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil) + recorder := httptest.NewRecorder() + + handlerCalled := false + handler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + handlerCalled = true + innerContext = r.Context() + }) + + authenticator := &mocks.Authenticator{} + authenticator.On("Authenticate", "").Return("my_authenticated_credential", errors.New("my_authenticated_error")) + authMiddleware := middleware.Authentication(authenticator) + + authMiddleware(handler).ServeHTTP(recorder, request) + + assert.False(t, handlerCalled) + assert.Equal(t, nil, auth.CredentialFromContext(innerContext)) + assert.Equal(t, http.StatusUnauthorized, recorder.Code) + assert.Equal(t, "my_authenticated_error\n", recorder.Body.String()) +} diff --git a/mocks/Authenticator.go b/mocks/Authenticator.go new file mode 100644 index 0000000..200e774 --- /dev/null +++ b/mocks/Authenticator.go @@ -0,0 +1,36 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import ( + "github.com/gol4ng/httpware/v2/auth" + "github.com/stretchr/testify/mock" +) + +// Authenticator is an autogenerated mock type for the Authenticator type +type Authenticator struct { + mock.Mock +} + +// Authenticate provides a mock function with given fields: _a0 +func (_m *Authenticator) Authenticate(_a0 auth.Credential) (auth.Credential, error) { + ret := _m.Called(_a0) + + var r0 auth.Credential + if rf, ok := ret.Get(0).(func(auth.Credential) auth.Credential); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(auth.Credential) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(auth.Credential) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/tripperware/auth.go b/tripperware/auth.go new file mode 100644 index 0000000..e323df9 --- /dev/null +++ b/tripperware/auth.go @@ -0,0 +1,52 @@ +package tripperware + +import ( + "net/http" + + "github.com/gol4ng/httpware/v2" + "github.com/gol4ng/httpware/v2/auth" +) + +func AuthenticationForwarder(options ...AuthOption) httpware.Tripperware { + config := NewAuthConfig(options...) + return func(next http.RoundTripper) http.RoundTripper { + return httpware.RoundTripFunc(func(req *http.Request) (*http.Response, error) { + config.credentialForwarder(req) + return next.RoundTrip(req) + }) + } +} + +type credentialForwarder func(req *http.Request) + +// AuthOption defines a interceptor tripperware configuration option +type AuthOption func(*AuthConfig) + +type AuthConfig struct { + credentialForwarder credentialForwarder +} + +func (o *AuthConfig) apply(options ...AuthOption) { + for _, option := range options { + option(o) + } +} + +func NewAuthConfig(options ...AuthOption) *AuthConfig { + opts := &AuthConfig{ + credentialForwarder: DefaultCredentialForwarder, + } + opts.apply(options...) + return opts +} + +func DefaultCredentialForwarder(req *http.Request) { + auth.AddHeader(req)(auth.CredentialFromContext(req.Context())) +} + +// WithCredentialForwarder will configure credentialForwarder option +func WithCredentialForwarder(authFunc credentialForwarder) AuthOption { + return func(config *AuthConfig) { + config.credentialForwarder = authFunc + } +} diff --git a/tripperware/auth_test.go b/tripperware/auth_test.go new file mode 100644 index 0000000..3bce064 --- /dev/null +++ b/tripperware/auth_test.go @@ -0,0 +1,62 @@ +package tripperware_test + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gol4ng/httpware/v2/auth" + "github.com/gol4ng/httpware/v2/mocks" + "github.com/gol4ng/httpware/v2/tripperware" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestAuthenticationForwarder(t *testing.T) { + tests := []struct { + context context.Context + expectedAuthorization string + }{ + { + context: context.TODO(), + expectedAuthorization: "", + }, + { + context: auth.CredentialToContext(context.TODO(), "my_credential"), + expectedAuthorization: "my_credential", + }, + } + for i, tt := range tests { + t.Run(fmt.Sprint(i), func(t *testing.T) { + roundTripperMock := &mocks.RoundTripper{} + request := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil) + request = request.WithContext(tt.context) + + roundTripperMock.On("RoundTrip", mock.AnythingOfType("*http.Request")).Return(nil, nil).Run(func(args mock.Arguments) { + innerReq := args.Get(0).(*http.Request) + assert.Equal(t, tt.expectedAuthorization, innerReq.Header.Get(auth.AuthorizationHeader)) + assert.Equal(t, tt.expectedAuthorization, innerReq.Header.Get(auth.XAuthorizationHeader)) + }) + + _, _ = tripperware.AuthenticationForwarder()(roundTripperMock).RoundTrip(request) + roundTripperMock.AssertExpectations(t) + }) + } +} + +func TestAuthenticationForwarder_CustomCredentialForwarder(t *testing.T) { + roundTripperMock := &mocks.RoundTripper{} + request := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil) + + roundTripperMock.On("RoundTrip", mock.AnythingOfType("*http.Request")).Return(nil, nil).Run(func(args mock.Arguments) { + innerReq := args.Get(0).(*http.Request) + assert.Equal(t, "my-custom-credential", innerReq.Header.Get("my-auth-header")) + }) + + _, _ = tripperware.AuthenticationForwarder(tripperware.WithCredentialForwarder(func(req *http.Request) { + req.Header.Set("my-auth-header", "my-custom-credential") + }))(roundTripperMock).RoundTrip(request) + roundTripperMock.AssertExpectations(t) +} diff --git a/tripperware_test.go b/tripperware_test.go index 0fca450..e6ff83b 100644 --- a/tripperware_test.go +++ b/tripperware_test.go @@ -60,18 +60,13 @@ func TestTripperware_RoundTrip(t *testing.T) { func TestTripperware_DecorateClient(t *testing.T) { req, _ := http.NewRequest("GET", "http://localhost/", nil) - resp := &http.Response{} - - roundTripperMock := &mocks.RoundTripper{} - roundTripperMock.On("RoundTrip", req).Return(resp, nil) + resp := &http.Response{Status: "My_response"} tripperware := httpware.Tripperware(func(tripper http.RoundTripper) http.RoundTripper { assert.Equal(t, http.DefaultTransport, tripper) - return httpware.RoundTripFunc(func(r *http.Request) (*http.Response, error) { - assert.Equal(t, req, r) - // we already check that tripper == http.DefaultTransport - // so we can replace the call with the mocked one - return roundTripperMock.RoundTrip(r) + return httpware.RoundTripFunc(func(innerRequest *http.Request) (*http.Response, error) { + assert.Equal(t, req, innerRequest) + return resp, nil }) })