Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add auth middleware/tripperware #32

Merged
merged 11 commits into from
Jun 12, 2020
7 changes: 7 additions & 0 deletions auth/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package auth

import (
"context"
)

type AuthFunc func(ctx context.Context) (context.Context, error)
instabledesign marked this conversation as resolved.
Show resolved Hide resolved
27 changes: 27 additions & 0 deletions auth/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package auth

import (
"context"
)

var credentialContextKey struct{}

func CredentialToContext(ctx context.Context, credential Credential) context.Context {
return context.WithValue(ctx, credentialContextKey, credential)
}

func CredentialFromContext(ctx context.Context) Credential {
instabledesign marked this conversation as resolved.
Show resolved Hide resolved
if ctx == nil {
return ""
}
value := ctx.Value(credentialContextKey)
if value == nil {
return ""
}
credential, ok := value.(Credential)
if !ok {
return ""
}

return credential
}
39 changes: 39 additions & 0 deletions auth/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package auth

import (
"context"
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func Test_Credential_Context(t *testing.T) {
tests := []struct {
context context.Context
expectedCredential string
}{
{
context: nil,
expectedCredential: "",
},
{
context: context.Background(),
expectedCredential: "",
},
{
context: context.WithValue(context.Background(), credentialContextKey, "not a credential"),
expectedCredential: "",
},
{
context: CredentialToContext(context.Background(), Credential("my_value")),
expectedCredential: "my_value",
},
}

for i, tt := range tests {
t.Run(fmt.Sprint(i), func(t *testing.T) {
assert.Equal(t, Credential(tt.expectedCredential), CredentialFromContext(tt.context))
})
}
}
7 changes: 7 additions & 0 deletions auth/credential.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package auth

type Credential string
instabledesign marked this conversation as resolved.
Show resolved Hide resolved

type CredentialProvider func() Credential

type CredentialSetter func(Credential)
36 changes: 36 additions & 0 deletions auth/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package auth

import (
"net/http"
)

const (
AuthorizationHeader = "Authorization"
XAuthorizationHeader = "X-Authorization"
)

func FromHeader(request *http.Request) CredentialProvider {
return func() Credential {
if request == nil {
return ""
}

tokenHeader := request.Header.Get(AuthorizationHeader)
if tokenHeader == "" {
tokenHeader = request.Header.Get(XAuthorizationHeader)
}

return Credential(tokenHeader)
}
}

func AddHeader(request *http.Request) CredentialSetter {
return func(credential Credential) {
if request == nil {
return
}

request.Header.Set(AuthorizationHeader, string(credential))
request.Header.Set(XAuthorizationHeader, string(credential))
}
}
57 changes: 57 additions & 0 deletions auth/http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package auth_test

import (
"fmt"
"net/http"
"testing"

"github.com/gol4ng/httpware/v2/auth"
"github.com/stretchr/testify/assert"
)

func TestFromHeader(t *testing.T) {
tests := []struct {
request *http.Request
expectedCredential string
}{
{
request: nil,
expectedCredential: "",
},
{
request: &http.Request{Header: http.Header{
"Authorization": []string{"foo"},
},},
expectedCredential: "foo",
},
{
request: &http.Request{Header: http.Header{
"X-Authorization": []string{"foo"},
},},
expectedCredential: "foo",
},
{
request: &http.Request{Header: http.Header{
"Authorization": []string{"foo"},
"X-Authorization": []string{"bar"},
},},
expectedCredential: "foo",
},
}

for i, tt := range tests {
t.Run(fmt.Sprint(i), func(t *testing.T) {
assert.Equal(t, auth.Credential(tt.expectedCredential), auth.FromHeader(tt.request)())
})
}
}

func TestAddHeader(t *testing.T) {
req := &http.Request{
Header: make(http.Header),
}

credSetter := auth.AddHeader(req)
credSetter("foo")
assert.Equal(t, "foo", req.Header.Get("Authorization"))
}
73 changes: 73 additions & 0 deletions middleware/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package middleware

import (
"context"
"net/http"

"github.com/gol4ng/httpware/v2"
"github.com/gol4ng/httpware/v2/auth"
)

// Authentication middleware delegate the authentication process to a authFunc configured
func Authentication(options ...AuthOption) httpware.Middleware {
config := NewAuthConfig(options...)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {
newCtx, err := config.authFunc(req)
if err != nil && config.errorHandler(err, writer, req) {
return
}

next.ServeHTTP(writer, req.WithContext(newCtx))
})
}
}

type authFunc func(req *http.Request) (context.Context, error)
type errorHandler func(err error, writer http.ResponseWriter, req *http.Request) bool

// AuthOption defines a interceptor middleware configuration option
type AuthOption func(*AuthConfig)

type AuthConfig struct {
authFunc authFunc
errorHandler errorHandler
}

func (o *AuthConfig) apply(options ...AuthOption) {
for _, option := range options {
option(o)
}
}

func NewAuthConfig(options ...AuthOption) *AuthConfig {
opts := &AuthConfig{
authFunc: DefaultAuthFunc,
errorHandler: DefaultErrorHandler,
}
opts.apply(options...)
return opts
}

func DefaultAuthFunc(req *http.Request) (context.Context, error) {
return auth.CredentialToContext(req.Context(), auth.FromHeader(req)()), nil
}

func DefaultErrorHandler(err error, writer http.ResponseWriter, _ *http.Request) bool {
http.Error(writer, err.Error(), http.StatusUnauthorized)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://developer.mozilla.org/en-US/docs/Web/HTTP/Authentication

An HTTP response having a 401 statusCode should also return a WWW-Authenticate HTTP header containing the available authorization methods for the requested resource 👍

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We ignore the type of authentication method was try so we cannot return the appropriate WWW-Authenticate value
It's currently a developer side override

return true
}

// WithAuthFunc will configure authFunc option
func WithAuthFunc(authFunc authFunc) AuthOption {
return func(config *AuthConfig) {
config.authFunc = authFunc
}
}

// WithErrorHandler will configure errorHandler option
func WithErrorHandler(errorHandler errorHandler) AuthOption {
return func(config *AuthConfig) {
config.errorHandler = errorHandler
}
}
108 changes: 108 additions & 0 deletions middleware/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package middleware_test

import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/gol4ng/httpware/v2/auth"
"github.com/gol4ng/httpware/v2/middleware"
"github.com/stretchr/testify/assert"
)

func TestAuthentication_hydrate_header(t *testing.T) {
tests := []struct {
authorizationHeader string
xAuthorizationHeader string
expectedCredential string
}{
{
authorizationHeader: "",
xAuthorizationHeader: "",
expectedCredential: "",
},
{
authorizationHeader: "Foo",
xAuthorizationHeader: "",
expectedCredential: "Foo",
},
{
authorizationHeader: "",
xAuthorizationHeader: "Foo",
expectedCredential: "Foo",
},
{
authorizationHeader: "Foo",
xAuthorizationHeader: "Bar",
expectedCredential: "Foo",
},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%s%s", tt.authorizationHeader, tt.xAuthorizationHeader), func(t *testing.T) {
var innerContext context.Context
request := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil)
request.Header.Set(auth.AuthorizationHeader, tt.authorizationHeader)
request.Header.Set(auth.XAuthorizationHeader, tt.xAuthorizationHeader)

handlerCalled := false
handler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
handlerCalled = true
innerContext = r.Context()
})

middleware.Authentication()(handler).ServeHTTP(nil, request)

assert.True(t, handlerCalled)
assert.Equal(t, auth.Credential(tt.expectedCredential), auth.CredentialFromContext(innerContext))
})
}
}

func TestAuthentication_Unauthorize(t *testing.T) {
var innerContext context.Context
request := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil)
recorder := httptest.NewRecorder()

handlerCalled := false
handler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
handlerCalled = true
innerContext = r.Context()
})

middleware.Authentication(middleware.WithAuthFunc(func(req *http.Request) (context.Context, error) {
return req.Context(), errors.New("my_authenticate_error")
}))(handler).ServeHTTP(recorder, request)

assert.False(t, handlerCalled)
assert.Equal(t, auth.Credential(""), auth.CredentialFromContext(innerContext))
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
}

func TestAuthentication_Custom_Error_Handler(t *testing.T) {
var innerContext context.Context
request := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil)
recorder := httptest.NewRecorder()

handlerCalled := false
handler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
handlerCalled = true
innerContext = r.Context()
})

middleware.Authentication(
middleware.WithAuthFunc(func(req *http.Request) (context.Context, error) {
return req.Context(), errors.New("my_authenticate_error")
}),
middleware.WithErrorHandler(func(err error, writer http.ResponseWriter, req *http.Request) bool {
_, _ = writer.Write([]byte("my_fake_response"))
return true
}),
)(handler).ServeHTTP(recorder, request)

assert.False(t, handlerCalled)
assert.Equal(t, auth.Credential(""), auth.CredentialFromContext(innerContext))
assert.Equal(t, "my_fake_response", recorder.Body.String())
}
17 changes: 17 additions & 0 deletions tripperware/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package tripperware

import (
"net/http"

"github.com/gol4ng/httpware/v2"
"github.com/gol4ng/httpware/v2/auth"
)

func AuthenticationForwarder() httpware.Tripperware {
return func(next http.RoundTripper) http.RoundTripper {
return httpware.RoundTripFunc(func(req *http.Request) (*http.Response, error) {
auth.AddHeader(req)(auth.CredentialFromContext(req.Context()))
instabledesign marked this conversation as resolved.
Show resolved Hide resolved
return next.RoundTrip(req)
})
}
}
Loading