-
-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #605 from DimaSalakhov/basicAuth
Add basic auth middleware
- Loading branch information
Showing
3 changed files
with
166 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
This package provides a Basic Authentication middleware. | ||
|
||
It'll try to compare credentials from Authentication request header to a username/password pair in middleware constructor. | ||
|
||
More details about this type of authentication can be found in [Mozilla article](https://developer.mozilla.org/en-US/docs/Web/HTTP/Authentication). | ||
|
||
## Usage | ||
|
||
```go | ||
import httptransport "github.com/go-kit/kit/transport/http" | ||
|
||
httptransport.NewServer( | ||
AuthMiddleware(cfg.auth.user, cfg.auth.password, "Example Realm")(makeUppercaseEndpoint()), | ||
decodeMappingsRequest, | ||
httptransport.EncodeJSONResponse, | ||
httptransport.ServerBefore(httptransport.PopulateRequestContext), | ||
) | ||
``` | ||
|
||
For AuthMiddleware to be able to pick up the Authentication header from an HTTP request we need to pass it through the context with something like ```httptransport.ServerBefore(httptransport.PopulateRequestContext)```. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
package basic | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"crypto/sha256" | ||
"crypto/subtle" | ||
"encoding/base64" | ||
"fmt" | ||
"net/http" | ||
"strings" | ||
|
||
"github.com/go-kit/kit/endpoint" | ||
httptransport "github.com/go-kit/kit/transport/http" | ||
) | ||
|
||
// AuthError represents an authorization error. | ||
type AuthError struct { | ||
Realm string | ||
} | ||
|
||
// StatusCode is an implementation of the StatusCoder interface in go-kit/http. | ||
func (AuthError) StatusCode() int { | ||
return http.StatusUnauthorized | ||
} | ||
|
||
// Error is an implementation of the Error interface. | ||
func (AuthError) Error() string { | ||
return http.StatusText(http.StatusUnauthorized) | ||
} | ||
|
||
// Headers is an implementation of the Headerer interface in go-kit/http. | ||
func (e AuthError) Headers() http.Header { | ||
return http.Header{ | ||
"Content-Type": []string{"text/plain; charset=utf-8"}, | ||
"X-Content-Type-Options": []string{"nosniff"}, | ||
"WWW-Authenticate": []string{fmt.Sprintf(`Basic realm=%q`, e.Realm)}, | ||
} | ||
} | ||
|
||
// parseBasicAuth parses an HTTP Basic Authentication string. | ||
// "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ([]byte("Aladdin"), []byte("open sesame"), true). | ||
func parseBasicAuth(auth string) (username, password []byte, ok bool) { | ||
const prefix = "Basic " | ||
if !strings.HasPrefix(auth, prefix) { | ||
return | ||
} | ||
c, err := base64.StdEncoding.DecodeString(auth[len(prefix):]) | ||
if err != nil { | ||
return | ||
} | ||
|
||
s := bytes.IndexByte(c, ':') | ||
if s < 0 { | ||
return | ||
} | ||
return c[:s], c[s+1:], true | ||
} | ||
|
||
// Returns a hash of a given slice. | ||
func toHashSlice(s []byte) []byte { | ||
hash := sha256.Sum256(s) | ||
return hash[:] | ||
} | ||
|
||
// AuthMiddleware returns a Basic Authentication middleware for a particular user and password. | ||
func AuthMiddleware(requiredUser, requiredPassword, realm string) endpoint.Middleware { | ||
requiredUserBytes := toHashSlice([]byte(requiredUser)) | ||
requiredPasswordBytes := toHashSlice([]byte(requiredPassword)) | ||
|
||
return func(next endpoint.Endpoint) endpoint.Endpoint { | ||
return func(ctx context.Context, request interface{}) (interface{}, error) { | ||
auth, ok := ctx.Value(httptransport.ContextKeyRequestAuthorization).(string) | ||
if !ok { | ||
return nil, AuthError{realm} | ||
} | ||
|
||
givenUser, givenPassword, ok := parseBasicAuth(auth) | ||
if !ok { | ||
return nil, AuthError{realm} | ||
} | ||
|
||
givenUserBytes := toHashSlice(givenUser) | ||
givenPasswordBytes := toHashSlice(givenPassword) | ||
|
||
if subtle.ConstantTimeCompare(givenUserBytes, requiredUserBytes) == 0 || | ||
subtle.ConstantTimeCompare(givenPasswordBytes, requiredPasswordBytes) == 0 { | ||
return nil, AuthError{realm} | ||
} | ||
|
||
return next(ctx, request) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
package basic | ||
|
||
import ( | ||
"context" | ||
"encoding/base64" | ||
"fmt" | ||
"testing" | ||
|
||
httptransport "github.com/go-kit/kit/transport/http" | ||
) | ||
|
||
func TestWithBasicAuth(t *testing.T) { | ||
requiredUser := "test-user" | ||
requiredPassword := "test-pass" | ||
realm := "test realm" | ||
|
||
type want struct { | ||
result interface{} | ||
err error | ||
} | ||
tests := []struct { | ||
name string | ||
authHeader interface{} | ||
want want | ||
}{ | ||
{"Isn't valid with nil header", nil, want{nil, AuthError{realm}}}, | ||
{"Isn't valid with non-string header", 42, want{nil, AuthError{realm}}}, | ||
{"Isn't valid without authHeader", "", want{nil, AuthError{realm}}}, | ||
{"Isn't valid for wrong user", makeAuthString("wrong-user", requiredPassword), want{nil, AuthError{realm}}}, | ||
{"Isn't valid for wrong password", makeAuthString(requiredUser, "wrong-password"), want{nil, AuthError{realm}}}, | ||
{"Is valid for correct creds", makeAuthString(requiredUser, requiredPassword), want{true, nil}}, | ||
} | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
ctx := context.WithValue(context.TODO(), httptransport.ContextKeyRequestAuthorization, tt.authHeader) | ||
|
||
result, err := AuthMiddleware(requiredUser, requiredPassword, realm)(passedValidation)(ctx, nil) | ||
if result != tt.want.result || err != tt.want.err { | ||
t.Errorf("WithBasicAuth() = result: %v, err: %v, want result: %v, want error: %v", result, err, tt.want.result, tt.want.err) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func makeAuthString(user string, password string) string { | ||
data := []byte(fmt.Sprintf("%s:%s", user, password)) | ||
return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString(data)) | ||
} | ||
|
||
func passedValidation(ctx context.Context, request interface{}) (response interface{}, err error) { | ||
return true, nil | ||
} |