Skip to content

Commit

Permalink
Add auth middleware/tripperware (#32)
Browse files Browse the repository at this point in the history
Co-authored-by: Anthony Moutte <instabledesign@gmail.com>
  • Loading branch information
qneyrat and instabledesign authored Jun 12, 2020
1 parent 7947cd4 commit 19dee9e
Show file tree
Hide file tree
Showing 12 changed files with 730 additions and 9 deletions.
11 changes: 11 additions & 0 deletions auth/authenticator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package auth

type Authenticator interface {
Authenticate(Credential) (Credential, error)
}

type AuthenticatorFunc func(Credential) (Credential, error)

func (a AuthenticatorFunc) Authenticate(credential Credential) (Credential, error) {
return a(credential)
}
27 changes: 27 additions & 0 deletions auth/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package auth

import (
"context"
)

var credentialContextKey struct{}

func CredentialToContext(ctx context.Context, credential Credential) context.Context {
return context.WithValue(ctx, credentialContextKey, credential)
}

func CredentialFromContext(ctx context.Context) Credential {
if ctx == nil {
return nil
}
value := ctx.Value(credentialContextKey)
if value == nil {
return nil
}
credential, ok := value.(Credential)
if !ok {
return nil
}

return credential
}
35 changes: 35 additions & 0 deletions auth/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package auth

import (
"context"
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func Test_Credential_Context(t *testing.T) {
tests := []struct {
context context.Context
expectedCredential Credential
}{
{
context: nil,
expectedCredential: nil,
},
{
context: context.Background(),
expectedCredential: nil,
},
{
context: CredentialToContext(context.Background(), Credential("my_value")),
expectedCredential: "my_value",
},
}

for i, tt := range tests {
t.Run(fmt.Sprint(i), func(t *testing.T) {
assert.Equal(t, tt.expectedCredential, CredentialFromContext(tt.context))
})
}
}
7 changes: 7 additions & 0 deletions auth/credential.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package auth

type Credential interface{}

type CredentialProvider func() Credential

type CredentialSetter func(Credential)
41 changes: 41 additions & 0 deletions auth/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package auth

import (
"net/http"
)

const (
AuthorizationHeader = "Authorization"
XAuthorizationHeader = "X-Authorization"
)

func FromHeader(request *http.Request) CredentialProvider {
return func() Credential {
return ExtractFromHeader(request)
}
}

func ExtractFromHeader(request *http.Request) Credential {
if request == nil {
return ""
}

tokenHeader := request.Header.Get(AuthorizationHeader)
if tokenHeader == "" {
tokenHeader = request.Header.Get(XAuthorizationHeader)
}

return tokenHeader
}

func AddHeader(request *http.Request) CredentialSetter {
return func(credential Credential) {
if request == nil {
return
}
if creds, ok := credential.(string); ok {
request.Header.Set(AuthorizationHeader, creds)
request.Header.Set(XAuthorizationHeader, creds)
}
}
}
57 changes: 57 additions & 0 deletions auth/http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package auth_test

import (
"fmt"
"net/http"
"testing"

"github.com/gol4ng/httpware/v2/auth"
"github.com/stretchr/testify/assert"
)

func TestFromHeader(t *testing.T) {
tests := []struct {
request *http.Request
expectedCredential string
}{
{
request: nil,
expectedCredential: "",
},
{
request: &http.Request{Header: http.Header{
"Authorization": []string{"foo"},
},},
expectedCredential: "foo",
},
{
request: &http.Request{Header: http.Header{
"X-Authorization": []string{"foo"},
},},
expectedCredential: "foo",
},
{
request: &http.Request{Header: http.Header{
"Authorization": []string{"foo"},
"X-Authorization": []string{"bar"},
},},
expectedCredential: "foo",
},
}

for i, tt := range tests {
t.Run(fmt.Sprint(i), func(t *testing.T) {
assert.Equal(t, auth.Credential(tt.expectedCredential), auth.FromHeader(tt.request)())
})
}
}

func TestAddHeader(t *testing.T) {
req := &http.Request{
Header: make(http.Header),
}

credSetter := auth.AddHeader(req)
credSetter("foo")
assert.Equal(t, "foo", req.Header.Get("Authorization"))
}
107 changes: 107 additions & 0 deletions middleware/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package middleware

import (
"context"
"net/http"

"github.com/gol4ng/httpware/v2"
"github.com/gol4ng/httpware/v2/auth"
)

// Authentication middleware delegate the authentication process to the Authenticator
func Authentication(authenticator auth.Authenticator, options ...AuthOption) httpware.Middleware {
config := NewAuthConfig(options...)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {
newCtx, err := config.authenticateFunc(config.credentialFinder, authenticator, req)
if err == nil {
config.successMiddleware(next).ServeHTTP(writer, req.WithContext(newCtx))
return
} else if config.errorHandler(err, writer, req) {
return
}

next.ServeHTTP(writer, req.WithContext(newCtx))
})
}
}

type CredentialFinder func(r *http.Request) auth.Credential
type AuthenticateFunc func(credentialFinder CredentialFinder, authenticator auth.Authenticator, req *http.Request) (context.Context, error)
type ErrorHandler func(err error, writer http.ResponseWriter, req *http.Request) bool

// AuthOption defines a interceptor middleware configuration option
type AuthOption func(*AuthConfig)

type AuthConfig struct {
credentialFinder CredentialFinder
authenticateFunc AuthenticateFunc
errorHandler ErrorHandler
successMiddleware httpware.Middleware
}

func (o *AuthConfig) apply(options ...AuthOption) {
for _, option := range options {
option(o)
}
}

func NewAuthConfig(options ...AuthOption) *AuthConfig {
opts := &AuthConfig{
credentialFinder: DefaultCredentialFinder,
authenticateFunc: DefaultAuthFunc,
errorHandler: DefaultErrorHandler,
successMiddleware: httpware.NopMiddleware,
}
opts.apply(options...)
return opts
}

func DefaultCredentialFinder(request *http.Request) auth.Credential {
return auth.FromHeader(request)()
}

func DefaultAuthFunc(credentialFinder CredentialFinder, authenticator auth.Authenticator, request *http.Request) (context.Context, error) {
credential := credentialFinder(request)
if authenticator != nil {
creds, err := authenticator.Authenticate(credential)
if err != nil {
return request.Context(), err
}
credential = creds
}
return auth.CredentialToContext(request.Context(), credential), nil
}

func DefaultErrorHandler(err error, writer http.ResponseWriter, _ *http.Request) bool {
http.Error(writer, err.Error(), http.StatusUnauthorized)
return true
}

// WithCredentialFinder will configure AuthenticateFunc option
func WithCredentialFinder(credentialFinder CredentialFinder) AuthOption {
return func(config *AuthConfig) {
config.credentialFinder = credentialFinder
}
}

// WithAuthenticateFunc will configure AuthenticateFunc option
func WithAuthenticateFunc(authenticateFunc AuthenticateFunc) AuthOption {
return func(config *AuthConfig) {
config.authenticateFunc = authenticateFunc
}
}

// WithErrorHandler will configure ErrorHandler option
func WithErrorHandler(errorHandler ErrorHandler) AuthOption {
return func(config *AuthConfig) {
config.errorHandler = errorHandler
}
}

// WithSuccessMiddleware will configure successMiddleware option
func WithSuccessMiddleware(middleware httpware.Middleware) AuthOption {
return func(config *AuthConfig) {
config.successMiddleware = middleware
}
}
Loading

0 comments on commit 19dee9e

Please sign in to comment.