diff --git a/.gitignore b/.gitignore index 6e85044..9b60f02 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ # Output of the go coverage tool, specifically when used with LiteIDE *.out + vendor diff --git a/auth/authenticator.go b/auth/authenticator.go index 6bc2555..9d2e55d 100644 --- a/auth/authenticator.go +++ b/auth/authenticator.go @@ -1,11 +1,15 @@ package auth +import ( + "context" +) + type Authenticator interface { - Authenticate(Credential) (Credential, error) + Authenticate(context.Context, Credential) (Credential, error) } -type AuthenticatorFunc func(Credential) (Credential, error) +type AuthenticatorFunc func(context.Context, Credential) (Credential, error) -func (a AuthenticatorFunc) Authenticate(credential Credential) (Credential, error) { - return a(credential) +func (a AuthenticatorFunc) Authenticate(ctx context.Context, credential Credential) (Credential, error) { + return a(ctx, credential) } diff --git a/middleware/auth.go b/middleware/auth.go index 14b2a50..b72cbb3 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -74,17 +74,18 @@ func WithSuccessMiddleware(middleware httpware.Middleware) AuthOption { // NewAuthenticateFunc is an AuthenticateFunc that find, authenticate and hydrate credentials on the request context func NewAuthenticateFunc(authenticator auth.Authenticator, options ...AuthFuncOption) AuthenticateFunc { - config := newAuthFuncConfig(options...) + config := NewAuthFuncConfig(options...) return func(request *http.Request) (*http.Request, error) { + ctx := request.Context() credential := config.credentialFinder(request) if authenticator != nil { - creds, err := authenticator.Authenticate(credential) + creds, err := authenticator.Authenticate(ctx, credential) if err != nil { return request, err } credential = creds } - return request.WithContext(auth.CredentialToContext(request.Context(), credential)), nil + return request.WithContext(auth.CredentialToContext(ctx, credential)), nil } } @@ -101,7 +102,7 @@ func (o *AuthFuncConfig) apply(options ...AuthFuncOption) { } } -func newAuthFuncConfig(options ...AuthFuncOption) *AuthFuncConfig { +func NewAuthFuncConfig(options ...AuthFuncOption) *AuthFuncConfig { opts := &AuthFuncConfig{ credentialFinder: DefaultCredentialFinder, } diff --git a/middleware/auth_test.go b/middleware/auth_test.go index 6a9cb90..e46c45f 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -164,7 +164,7 @@ func TestNewAuthenticateFunc(t *testing.T) { request.Header.Set("Authorization", "my_credential") authenticator := &mocks.Authenticator{} - authenticator.On("Authenticate", "my_credential").Return("my_authenticate_credential", nil) + authenticator.On("Authenticate", context.TODO(), "my_credential").Return("my_authenticate_credential", nil) authenticateFunc := middleware.NewAuthenticateFunc(authenticator) @@ -179,7 +179,7 @@ func TestNewAuthenticateFunc_WithCredentialFinder(t *testing.T) { request := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil) authenticator := &mocks.Authenticator{} - authenticator.On("Authenticate", "my_credential_finder_value").Return("my_authenticate_credential", nil) + authenticator.On("Authenticate", context.TODO(), "my_credential_finder_value").Return("my_authenticate_credential", nil) authenticateFunc := middleware.NewAuthenticateFunc( authenticator, @@ -201,7 +201,7 @@ func TestNewAuthenticateFunc_Error(t *testing.T) { err := errors.New("my_authenticate_error") authenticator := &mocks.Authenticator{} - authenticator.On("Authenticate", "my_credential").Return("my_authenticate_credential", err) + authenticator.On("Authenticate", context.TODO(), "my_credential").Return("my_authenticate_credential", err) authenticateFunc := middleware.NewAuthenticateFunc(authenticator) diff --git a/mocks/Authenticator.go b/mocks/Authenticator.go index a4a4000..e9f0049 100644 --- a/mocks/Authenticator.go +++ b/mocks/Authenticator.go @@ -1,10 +1,13 @@ -// Code generated by mockery v1.0.0. DO NOT EDIT. +// Code generated by mockery v2.4.0. DO NOT EDIT. package mocks import ( - "github.com/gol4ng/httpware/v4/auth" - "github.com/stretchr/testify/mock" + context "context" + + auth "github.com/gol4ng/httpware/v4/auth" + + mock "github.com/stretchr/testify/mock" ) // Authenticator is an autogenerated mock type for the Authenticator type @@ -12,13 +15,13 @@ type Authenticator struct { mock.Mock } -// Authenticate provides a mock function with given fields: _a0 -func (_m *Authenticator) Authenticate(_a0 auth.Credential) (auth.Credential, error) { - ret := _m.Called(_a0) +// Authenticate provides a mock function with given fields: _a0, _a1 +func (_m *Authenticator) Authenticate(_a0 context.Context, _a1 auth.Credential) (auth.Credential, error) { + ret := _m.Called(_a0, _a1) var r0 auth.Credential - if rf, ok := ret.Get(0).(func(auth.Credential) auth.Credential); ok { - r0 = rf(_a0) + if rf, ok := ret.Get(0).(func(context.Context, auth.Credential) auth.Credential); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(auth.Credential) @@ -26,8 +29,8 @@ func (_m *Authenticator) Authenticate(_a0 auth.Credential) (auth.Credential, err } var r1 error - if rf, ok := ret.Get(1).(func(auth.Credential) error); ok { - r1 = rf(_a0) + if rf, ok := ret.Get(1).(func(context.Context, auth.Credential) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) }