diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml new file mode 100644 index 0000000..eb251d5 --- /dev/null +++ b/.github/workflows/main.yaml @@ -0,0 +1,17 @@ +name: CI + +on: [push] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version: '1.20.x' + - run: go mod download + - run: go install github.com/securego/gosec/v2/cmd/gosec@latest + - run: gosec ./... + - run: go test ./... diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3aff283 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.DS_Store +coverage.out +coverage.html \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..9395292 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,13 @@ +{ + "cSpell.words": [ + "AKIATESTACCESSKEY", + "Fprintln", + "gosec", + "Passwordless", + "rolename", + "securego", + "useast", + "uswest", + "YYYYMMDDTHHMMSSZ" + ] +} \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..5efcb6b --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,11 @@ +# Changelog +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [1.0.0] - 2023-08-05 +### Added +- initial implementation \ No newline at end of file diff --git a/MIT-LICENSE.txt b/MIT-LICENSE.txt new file mode 100644 index 0000000..2f79820 --- /dev/null +++ b/MIT-LICENSE.txt @@ -0,0 +1,20 @@ +Copyright (c) 2023 Klemen Kozelj + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..b2d4df5 --- /dev/null +++ b/README.md @@ -0,0 +1,73 @@ +# go-aws-sts-authenticator + +[![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) ![CI](https://github.com/KlemenKozelj/go-aws-sts-authenticator/actions/workflows/main.yml/badge.svg) ![Vulnerabilities](https://snyk.io/test/github/KlemenKozelj/go-aws-sts-authenticator/badge.svg) + +**go-aws-sts-authenticator** is a lightweight open-source project developed in Go. The primary objective of this project is to simplify API request authentication by reusing AWS IAM and STS services. It was inspired by Mongo's Atlas password-less database authentication mechanism, please see the credit section below for details. + +### Install +``` +go get github.com/klemenkozelj/go-aws-sts-authenticator +``` + +### Usage +![](assets/example-flow.gif) +[Script](example/main.go) for demonstrative execution that showcases how STS parameters retrieval on client its verification on server: +``` +git clone https://github.com/KlemenKozelj/go-aws-sts-authenticator.git +cd go-aws-sts-authenticator/example +go run main.go // AWS IAM ARN = arn:aws:iam::123456789012:user/klemen.kozelj@gmail.com +``` +#### On Client +Client needs to have configured AWS IAM identity (user or role). +```golang +// NewClient creates a new StsAuthenticator client with default AWS environment configuration. +signatureClient, err := client.NewClient(ctx) + +// Generates STS temporary credentials and from them derives authentication parameters. +xAmzDate, authorization, xAmzSecurityToken, awsRegion, err := signatureClient.GetStsParameters(ctx) + +// We can also directly sign http.Request and parameters will be attached to request headers. +err := signatureClient.SignRequest(req) +``` + +#### On Server +The server is calling the public AWS STS (Security Token Service) endpoint with the provided authentication parameters and does not require any AWS configuration. +```golang +// Validates authentication parameters, if successful callers IAM ARN identity is located in awsGetCallerIdentityResponse.GetCallerIdentityResult.Arn +awsIdentity, err := server.StsGetCallerIdentity(xAmzDate, authorization, xAmzSecurityToken, awsRegion) +``` +Validation is most frequently done in the context of an HTTP web server, http.Handler middleware functions are included for this purpose: +```golang +mux := http.NewServeMux() + +// Request is authenticated if callers AWS IAM ARN is equal to arn:aws:iam::1234567890:user/username +stsAuthenticationMiddleware := server.AuthenticateAwsIamIdentity( + server.DefaultGetRequestParameters, // type GetRequestParameters func(r *http.Request) (awsRegion, xAmzDate, xAmzSecurityToken, authorization string, err error) + server.DefaultIsIamIdentityValid("arn:aws:iam::1234567890:user/username"), // type IsIamIdentityValid func(awsIamIdentityArn string) (valid bool) +) + +handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprintln(w, "Hello World") +}) + +mux.Handle("/", stsAuthenticationMiddleware(handler)) +``` +### External Dependencies +- [aws-sdk-go-v2](https://github.com/aws/aws-sdk-go-v2) + +### Potential Improvements +- [ ] improve code testability with mocking aws-sdk-go-v2 +- [ ] tests for client.go +- [ ] exponential back-off retry strategy for server's http client +- [ ] possibility to insert custom server's http client +- [ ] in-code documentation, benchmark and examples + +### Credits + +Mongo's Atlas password-less database authentication mechanism: +- [Set Up Passwordless Authentication with AWS IAM](https://www.mongodb.com/docs/atlas/security/passwordless-authentication/) +- [Using AWS IAM Authentication with MongoDB 4.4 in Atlas to Build Modern Secure Applications](https://www.youtube.com/watch?v=99iV9lCctrU) + +Scanned with: +- [GoSec - Golang Security Checker](https://github.com/securego/gosec) \ No newline at end of file diff --git a/assets/example-flow.gif b/assets/example-flow.gif new file mode 100644 index 0000000..9f60dad Binary files /dev/null and b/assets/example-flow.gif differ diff --git a/client/client.go b/client/client.go new file mode 100644 index 0000000..430a9f5 --- /dev/null +++ b/client/client.go @@ -0,0 +1,93 @@ +package client + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "net/http" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/sts" +) + +const body string = "Action=GetCallerIdentity&Version=2011-06-15" + +var requestBodySha256 string = getSHA256Hash(body) + +type StsAuthenticator struct { + awsRegion string + signer *v4.Signer + stsClient *sts.Client +} + +func NewClient(ctx context.Context) (*StsAuthenticator, error) { + client := StsAuthenticator{ + signer: v4.NewSigner(), + } + configuration, err := config.LoadDefaultConfig(ctx) + if err != nil { + return nil, err + } + client.awsRegion = configuration.Region + client.stsClient = sts.NewFromConfig(configuration) + return &client, nil +} + +func (c StsAuthenticator) getTemporarilyAwsCredentials(ctx context.Context) (*aws.Credentials, error) { + result, err := c.stsClient.GetSessionToken(ctx, &sts.GetSessionTokenInput{}) + if err != nil { + return nil, err + } + awsCredentials := credentials.NewStaticCredentialsProvider( + *result.Credentials.AccessKeyId, + *result.Credentials.SecretAccessKey, + *result.Credentials.SessionToken, + ) + credentials, err := awsCredentials.Retrieve(ctx) + if err != nil { + return nil, err + } + return &credentials, nil +} + +func (c StsAuthenticator) GetStsParameters(ctx context.Context) (xAmzDate, authorization, xAmzSecurityToken, awsRegion string, err error) { + awsCredentials, err := c.getTemporarilyAwsCredentials(ctx) + if err != nil { + return "", "", "", "", err + } + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("https://sts.%s.amazonaws.com", c.awsRegion), strings.NewReader(body)) + if err != nil { + return "", "", "", "", err + } + err2 := c.signer.SignHTTP(ctx, *awsCredentials, req, requestBodySha256, "sts", c.awsRegion, time.Now()) + if err2 != nil { + return "", "", "", "", err2 + } + return req.Header.Get("X-Amz-Date"), req.Header.Get("Authorization"), req.Header.Get("X-Amz-Security-Token"), c.awsRegion, nil +} + +func (c StsAuthenticator) SignRequest(req *http.Request) error { + awsCredentials, err := c.getTemporarilyAwsCredentials(req.Context()) + if err != nil { + return err + } + err2 := c.signer.SignHTTP(req.Context(), *awsCredentials, req, requestBodySha256, "sts", c.awsRegion, time.Now()) + if err2 != nil { + return err2 + } + return nil +} + +func getSHA256Hash(input string) string { + hash := sha256.New() + hash.Write([]byte(input)) + hashBytes := hash.Sum(nil) + hashString := hex.EncodeToString(hashBytes) + return hashString +} diff --git a/example/main.go b/example/main.go new file mode 100644 index 0000000..1b60582 --- /dev/null +++ b/example/main.go @@ -0,0 +1,60 @@ +package main + +import ( + "context" + "fmt" + "net/http" + + "github.com/klemenkozelj/go-aws-http-authenticator/client" + "github.com/klemenkozelj/go-aws-http-authenticator/server" +) + +func main() { + + xAmzDate, authorization, xAmzSecurityToken, awsRegion := clientExample() + + serverExample(xAmzDate, authorization, xAmzSecurityToken, awsRegion) + + mux := http.NewServeMux() + + awsAuth := server.AuthenticateAwsIamIdentity( + server.DefaultGetRequestParameters, + server.DefaultIsIamIdentityValid("arn:aws:iam::1234567890:user/username"), + ) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Your handler is running!") + }) + + mux.Handle("/", awsAuth(handler)) + +} + +func clientExample() (string, string, string, string) { + + ctx := context.Background() + + signatureClient, err := client.NewClient(ctx) + if err != nil { + panic(err) + } + + xAmzDate, authorization, xAmzSecurityToken, awsRegion, err := signatureClient.GetStsParameters(ctx) + if err != nil { + panic(err) + } + + return xAmzDate, authorization, xAmzSecurityToken, awsRegion + +} + +func serverExample(date, authorization, token, region string) { + + identity, err := server.StsGetCallerIdentity("https://sts."+region+".amazonaws.com", date, authorization, token) + if err != nil { + panic(err) + } + + fmt.Println("AWS IAM ARN = ", identity.GetCallerIdentityResult.Arn) + +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..fa8567a --- /dev/null +++ b/go.mod @@ -0,0 +1,21 @@ +module github.com/klemenkozelj/go-aws-http-authenticator + +go 1.20 + +require ( + github.com/aws/aws-sdk-go-v2 v1.19.0 + github.com/aws/aws-sdk-go-v2/config v1.18.29 + github.com/aws/aws-sdk-go-v2/credentials v1.13.28 + github.com/aws/aws-sdk-go-v2/service/sts v1.20.0 +) + +require ( + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.35 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.29 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.3.36 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.29 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.12.13 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.13 // indirect + github.com/aws/smithy-go v1.13.5 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..db5f1d1 --- /dev/null +++ b/go.sum @@ -0,0 +1,33 @@ +github.com/aws/aws-sdk-go-v2 v1.19.0 h1:klAT+y3pGFBU/qVf1uzwttpBbiuozJYWzNLHioyDJ+k= +github.com/aws/aws-sdk-go-v2 v1.19.0/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= +github.com/aws/aws-sdk-go-v2/config v1.18.29 h1:yA+bSSRGhBwWuprG9I4VgxfK//NBLZ/0BGOHiV3f9oM= +github.com/aws/aws-sdk-go-v2/config v1.18.29/go.mod h1:bJT6P8A+KU1qvNMp8aj+/NmaI06Z670dHNoWsrLOgMg= +github.com/aws/aws-sdk-go-v2/credentials v1.13.28 h1:WM9tEHgoOh5ThJZ042UKnSx7TXGSC/bz63X3fsrQL2o= +github.com/aws/aws-sdk-go-v2/credentials v1.13.28/go.mod h1:86BSbSeamnVVdr1hPfBZVN8SXM7KxSAZAvhNxVfi8fU= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.5 h1:kP3Me6Fy3vdi+9uHd7YLr6ewPxRL+PU6y15urfTaamU= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.5/go.mod h1:Gj7tm95r+QsDoN2Fhuz/3npQvcZbkEf5mL70n3Xfluc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.35 h1:hMUCiE3Zi5AHrRNGf5j985u0WyqI6r2NULhUfo0N/No= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.35/go.mod h1:ipR5PvpSPqIqL5Mi82BxLnfMkHVbmco8kUwO2xrCi0M= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.29 h1:yOpYx+FTBdpk/g+sBU6Cb1H0U/TLEcYYp66mYqsPpcc= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.29/go.mod h1:M/eUABlDbw2uVrdAn+UsI6M727qp2fxkp8K0ejcBDUY= +github.com/aws/aws-sdk-go-v2/internal/ini v1.3.36 h1:8r5m1BoAWkn0TDC34lUculryf7nUF25EgIMdjvGCkgo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.3.36/go.mod h1:Rmw2M1hMVTwiUhjwMoIBFWFJMhvJbct06sSidxInkhY= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.29 h1:IiDolu/eLmuB18DRZibj77n1hHQT7z12jnGO7Ze3pLc= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.29/go.mod h1:fDbkK4o7fpPXWn8YAPmTieAMuB9mk/VgvW64uaUqxd4= +github.com/aws/aws-sdk-go-v2/service/sso v1.12.13 h1:sWDv7cMITPcZ21QdreULwxOOAmE05JjEsT6fCDtDA9k= +github.com/aws/aws-sdk-go-v2/service/sso v1.12.13/go.mod h1:DfX0sWuT46KpcqbMhJ9QWtxAIP1VozkDWf8VAkByjYY= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.13 h1:BFubHS/xN5bjl818QaroN6mQdjneYQ+AOx44KNXlyH4= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.13/go.mod h1:BzqsVVFduubEmzrVtUFQQIQdFqvUItF8XUq2EnS8Wog= +github.com/aws/aws-sdk-go-v2/service/sts v1.20.0 h1:jKmIOO+dFvCPuIhhM8u0Dy3dtd590n2kEDSYiGHoI98= +github.com/aws/aws-sdk-go-v2/service/sts v1.20.0/go.mod h1:yVGZA1CPkmUhBdA039jXNJJG7/6t+G+EBWmFq23xqnY= +github.com/aws/smithy-go v1.13.5 h1:hgz0X/DX0dGqTYpGALqXJoRKRj5oQ7150i5FdTePzO8= +github.com/aws/smithy-go v1.13.5/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/server/errors.go b/server/errors.go new file mode 100644 index 0000000..24b1415 --- /dev/null +++ b/server/errors.go @@ -0,0 +1,20 @@ +package server + +type ErrorType int + +const ( + AwsStsInvalidParameter ErrorType = iota + AwsStsRequestError + AwsStsServerError + AWSStsServerRejection + AWSStsServerResponse +) + +type errorCustom struct { + Type ErrorType + Err error +} + +func (se errorCustom) Error() string { + return se.Err.Error() +} diff --git a/server/getCallerIdentity.go b/server/getCallerIdentity.go new file mode 100644 index 0000000..8e7ee88 --- /dev/null +++ b/server/getCallerIdentity.go @@ -0,0 +1,64 @@ +package server + +import ( + "encoding/xml" + "fmt" + "net/http" + "strings" +) + +var httpClient *http.Client = &http.Client{} + +type awsGetCallerIdentityResponse struct { + XMLName xml.Name `xml:"GetCallerIdentityResponse"` + GetCallerIdentityResult struct { + Arn string `xml:"Arn"` + UserId string `xml:"UserId"` + Account string `xml:"Account"` + } `xml:"GetCallerIdentityResult"` + ResponseMetadata struct { + RequestId string `xml:"RequestId"` + } `xml:"ResponseMetadata"` +} + +func StsGetCallerIdentity(url, xAmzDate, authorization, xAmzSecurityToken string) (*awsGetCallerIdentityResponse, error) { + + paramNames := []string{"url", "x-amz-date", "authorization", "x-amz-security-token"} + parameters := []string{url, xAmzDate, authorization, xAmzSecurityToken} + for i, param := range parameters { + if param == "" { + return nil, errorCustom{Type: AwsStsInvalidParameter, Err: fmt.Errorf("%s is empty", paramNames[i])} + } + } + + body := "Action=GetCallerIdentity&Version=2011-06-15" + request, err := http.NewRequest(http.MethodPost, url, strings.NewReader(body)) + if err != nil { + return nil, errorCustom{Type: AwsStsRequestError, Err: err} + } + + request.Header.Set("X-Amz-Date", xAmzDate) + request.Header.Set("Authorization", authorization) + request.Header.Set("X-Amz-Security-Token", xAmzSecurityToken) + request.Header.Set("Content-Type", "application/x-www-form-urlencoded") + request.Header.Set("Host", "sts.amazonaws.com") + request.Header.Set("Content-Length", fmt.Sprintf("%d", len(body))) + + response, err := httpClient.Do(request) + if err != nil { + return nil, errorCustom{Type: AwsStsServerError, Err: err} + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return nil, errorCustom{Type: AWSStsServerRejection, Err: fmt.Errorf("aws sts server response status: %s", response.Status)} + } + + var responseAwsCallerIdentity awsGetCallerIdentityResponse + err2 := xml.NewDecoder(response.Body).Decode(&responseAwsCallerIdentity) + if err2 != nil { + return nil, errorCustom{Type: AWSStsServerResponse, Err: err2} + } + + return &responseAwsCallerIdentity, nil +} diff --git a/server/getCallerIdentity_test.go b/server/getCallerIdentity_test.go new file mode 100644 index 0000000..be499da --- /dev/null +++ b/server/getCallerIdentity_test.go @@ -0,0 +1,116 @@ +package server + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +func TestStsGetCallerIdentity(t *testing.T) { + + var res *awsGetCallerIdentityResponse + var err error + var errServer errorCustom + + // Sending empty string url parameter + res, err = StsGetCallerIdentity("", "", "", "") + if res != nil { + t.Errorf("Error is expected so response should be nil") + } + errors.As(err, &errServer) + if errServer.Type != AwsStsInvalidParameter { + t.Errorf("Server error type invalid request error was expected because url was empty string") + } + + // Sending empty string date parameter + res, err = StsGetCallerIdentity("url", "", "", "") + if res != nil { + t.Errorf("Error is expected so response should be nil") + } + errors.As(err, &errServer) + if errServer.Type != AwsStsInvalidParameter { + t.Errorf("Server error type invalid request error was expected because date was empty string") + } + + // Sending empty string authorization parameter + res, err = StsGetCallerIdentity("url", "date", "", "") + if res != nil { + t.Errorf("Error is expected so response should be nil") + } + errors.As(err, &errServer) + if errServer.Type != AwsStsInvalidParameter { + t.Errorf("Server error type invalid request error was expected because authorization was empty string") + } + + // Sending empty string token parameter + res, err = StsGetCallerIdentity("url", "date", "authorization", "") + if res != nil { + t.Errorf("Error is expected so response should be nil") + } + errors.As(err, &errServer) + if errServer.Type != AwsStsInvalidParameter { + t.Errorf("Server error type invalid request error was expected because token was empty string") + } + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + queryParams := r.URL.Query() + switch val := queryParams.Get("case"); val { + case "notAcceptable": + w.WriteHeader(http.StatusNotAcceptable) + case "badResponse": + xmlResponse := "this is not expected xml payload" + w.Header().Set("Content-Type", "application/xml") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, xmlResponse) + case "goodResponse": + xmlResponse := ` + + + arn:aws:iam::1234567890:user/username + AKIATESTACCESSKEY + 1234567890 + + + 7ae1ff87-8867-4b21-916b-4b44bef35345 + + ` + w.Header().Set("Content-Type", "application/xml") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, xmlResponse) + default: + t.FailNow() + } + })) + defer testServer.Close() + + // Server returns 406 status code not acceptable response + res, err = StsGetCallerIdentity(testServer.URL+"?case=notAcceptable", "date", "authorization", "token") + if res != nil { + t.Errorf("Error is expected so response should be nil") + } + errors.As(err, &errServer) + if errServer.Type != AWSStsServerRejection { + t.Errorf("Server error type invalid request error was expected because server returns 406 status code") + } + + // Server returns bad response + res, err = StsGetCallerIdentity(testServer.URL+"?case=badResponse", "date", "authorization", "token") + if res != nil { + t.Errorf("Error is expected so response should be nil") + } + errors.As(err, &errServer) + if errServer.Type != AWSStsServerResponse { + t.Errorf("Server error type invalid request error was expected because server returns bad response") + } + + // Server returns good response + res, err = StsGetCallerIdentity(testServer.URL+"?case=goodResponse", "date", "authorization", "token") + if err != nil { + t.Errorf("Good response is expected so response should be nil") + } + if res.GetCallerIdentityResult.Arn != "arn:aws:iam::1234567890:user/username" { + t.Errorf("Returned server response should include users arn") + } +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..dbb8936 --- /dev/null +++ b/server/server.go @@ -0,0 +1,71 @@ +package server + +import ( + "fmt" + "net/http" +) + +type GetRequestParameters func(r *http.Request) (awsRegion, xAmzDate, xAmzSecurityToken, authorization string, err error) +type IsIamIdentityValid func(awsIamIdentityArn string) (valid bool) + +func DefaultGetRequestParameters(r *http.Request) (awsRegion, xAmzDate, xAmzSecurityToken, authorization string, err error) { + xAmzDate = r.Header.Get("x-amz-date") + if xAmzDate == "" { + return "", "", "", "", errorCustom{Type: AwsStsInvalidParameter, Err: fmt.Errorf("missing x-amz-date header")} + } + xAmzSecurityToken = r.Header.Get("x-amz-security-token") + if xAmzSecurityToken == "" { + return "", "", "", "", errorCustom{Type: AwsStsInvalidParameter, Err: fmt.Errorf("missing x-amz-security-token header")} + } + authorization = r.Header.Get("authorization") + if authorization == "" { + return "", "", "", "", errorCustom{Type: AwsStsInvalidParameter, Err: fmt.Errorf("missing authorization header")} + } + awsRegion = getAwsIamRegion(authorization) + if awsRegion == "" { + return "", "", "", "", errorCustom{Type: AwsStsInvalidParameter, Err: fmt.Errorf("missing AWS IAM region")} + } + return awsRegion, xAmzDate, xAmzSecurityToken, authorization, nil +} + +func DefaultIsIamIdentityValid(awsIamIdentityArns ...string) IsIamIdentityValid { + authorized := make(map[string]bool) + for _, awsIamArn := range awsIamIdentityArns { + if !isValidAwsIamArn(awsIamArn) { + panic("invalid AWS IAM ARN specified " + awsIamArn) + } + authorized[awsIamArn] = true + } + return func(awsIamIdentityArn string) bool { + return authorized[awsIamIdentityArn] + } +} + +var getAwsStsUrl = func(awsRegion string) string { + return fmt.Sprintf("https://sts.%s.amazonaws.com", awsRegion) +} + +func AuthenticateAwsIamIdentity(getRequestParameters GetRequestParameters, isIamIdentityValid IsIamIdentityValid) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + awsRegion, xAmzDate, xAmzSecurityToken, authorization, err := getRequestParameters(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + awsCallerIdentity, err := StsGetCallerIdentity(getAwsStsUrl(awsRegion), xAmzDate, xAmzSecurityToken, authorization) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if !isIamIdentityValid(awsCallerIdentity.GetCallerIdentityResult.Arn) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + }) + } +} diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 0000000..3d58bb7 --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,180 @@ +package server + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +const ( + awsRegionBase = "eu-central-1" + xAmzDateBase = "20230730T101440Z" + xAmzSecurityTokenBase = "thisIsSecurityToken" + authorizationBase = "AWS4-HMAC-SHA256 Credential=CREDENTIALS/20230730/" + awsRegionBase + "/sts/aws4_request, SignedHeaders=content-length;host;x-amz-date;x-amz-security-token, Signature=signature" +) + +func TestDefaultGetRequestParameters(t *testing.T) { + + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("https://sts.%s.amazonaws.com", awsRegionBase), nil) + var err error + + _, _, _, _, err = DefaultGetRequestParameters(req) + if err == nil { + t.Errorf("Expected error since no default headers were provided") + } + req.Header.Set("x-amz-date", xAmzDateBase) + _, _, _, _, err = DefaultGetRequestParameters(req) + if err == nil { + t.Errorf("Expected error since only x-amz-date was provided") + } + req.Header.Set("x-amz-security-token", xAmzSecurityTokenBase) + _, _, _, _, err = DefaultGetRequestParameters(req) + if err == nil { + t.Errorf("Expected error since only x-amz-date and x-amz-security-token were provided") + } + + req.Header.Set("authorization", authorizationBase) + awsRegion, xAmzDate, xAmzSecurityToken, authorization, err := DefaultGetRequestParameters(req) + if err != nil { + t.Errorf("Expected no error since all parameters were provided") + } + if awsRegion != awsRegionBase { + t.Errorf("Expected %s, but got %s", awsRegionBase, awsRegion) + } + if xAmzDate != xAmzDateBase { + t.Errorf("Expected %s, but got %s", xAmzDateBase, xAmzDate) + } + if xAmzSecurityToken != xAmzSecurityTokenBase { + t.Errorf("Expected %s, but got %s", xAmzSecurityTokenBase, xAmzSecurityToken) + } + if authorization != authorizationBase { + t.Errorf("Expected %s, but got %s", authorizationBase, authorization) + } + +} + +func TestDefaultIsIamIdentityValid(t *testing.T) { + + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic with invalid ARN, but got none") + } + }() + DefaultIsIamIdentityValid("invalid-arn") + + validator := DefaultIsIamIdentityValid( + "arn:aws:iam::1234567890:user/username1", + "arn:aws:iam::1234567890:role/username2", + ) + if !validator("arn:aws:iam::1234567890:user/username1") { + t.Errorf("Expected username1 to be valid, but got invalid") + } + if !validator("arn:aws:iam::1234567890:role/username2") { + t.Errorf("Expected username2 to be valid, but got invalid") + } + if validator("arn:aws:iam::1234567890:role/username3") { + t.Errorf("Expected username3 invalid ARN to be invalid, but got valid") + } + if validator("invalid-arn") { + t.Errorf("Expected invalid arn invalid ARN to be invalid, but got valid") + } +} + +func TestAuthenticateAwsIamIdentity(t *testing.T) { + + testAwsStsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Header.Get("Authorization") { + case "wrongStsCredentials": + w.WriteHeader(http.StatusUnauthorized) + return + case "invalidIamArn": + xmlResponse := ` + + + arn:aws:iam::1234567890:user/username + AKIATESTACCESSKEY + 1234567890 + + + 7ae1ff87-8867-4b21-916b-4b44bef35345 + + ` + w.Header().Set("Content-Type", "application/xml") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, xmlResponse) + return + case "validIamArn": + xmlResponse := ` + + + arn:aws:iam::1234567890:user/username1 + AKIATESTACCESSKEY + 1234567890 + + + 7ae1ff87-8867-4b21-916b-4b44bef35345 + + ` + w.Header().Set("Content-Type", "application/xml") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, xmlResponse) + return + } + t.FailNow() + })) + defer testAwsStsServer.Close() + getAwsStsUrl = func(awsRegion string) string { + return testAwsStsServer.URL + } + + authenticateMiddleware := AuthenticateAwsIamIdentity( + func(r *http.Request) (awsRegion, xAmzDate, xAmzSecurityToken, authorization string, err error) { + switch r.URL.Query().Get("scenario") { + case "badRequest": + return "", "", "", "", fmt.Errorf("bad request") + case "wrongStsCredentials": + return "eu-central-1", "20230730T101440Z", "wrongStsCredentials", "AWS4-HMAC-SHA256...", nil + case "invalidIamArn": + return "eu-central-1", "20230730T101440Z", "invalidIamArn", "AWS4-HMAC-SHA256...", nil + case "validIamArn": + return "eu-central-1", "20230730T101440Z", "validIamArn", "AWS4-HMAC-SHA256...", nil + } + t.FailNow() + return "", "", "", "", nil + }, + func(awsIamIdentityArn string) bool { + return awsIamIdentityArn == "arn:aws:iam::1234567890:user/username1" + }, + ) + + mockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + requestRecorder1 := httptest.NewRecorder() + request1 := httptest.NewRequest(http.MethodGet, "http://localhost:8080/?scenario=badRequest", nil) + authenticateMiddleware(mockHandler).ServeHTTP(requestRecorder1, request1) + if requestRecorder1.Code != http.StatusBadRequest { + t.Errorf("Expected status code %d, but got %d", http.StatusBadRequest, requestRecorder1.Code) + } + + requestRecorder2 := httptest.NewRecorder() + request2 := httptest.NewRequest(http.MethodGet, "http://localhost:8080/?scenario=wrongStsCredentials", nil) + authenticateMiddleware(mockHandler).ServeHTTP(requestRecorder2, request2) + if requestRecorder2.Code != http.StatusInternalServerError { + t.Errorf("Expected status code %d, but got %d", http.StatusInternalServerError, requestRecorder2.Code) + } + + requestRecorder3 := httptest.NewRecorder() + request3 := httptest.NewRequest(http.MethodGet, "http://localhost:8080/?scenario=invalidIamArn", nil) + authenticateMiddleware(mockHandler).ServeHTTP(requestRecorder3, request3) + if requestRecorder3.Code != http.StatusUnauthorized { + t.Errorf("Expected status code %d, but got %d", http.StatusUnauthorized, requestRecorder3.Code) + } + + requestRecorder4 := httptest.NewRecorder() + request4 := httptest.NewRequest(http.MethodGet, "http://localhost:8080/?scenario=validIamArn", nil) + authenticateMiddleware(mockHandler).ServeHTTP(requestRecorder4, request4) + if requestRecorder4.Code != http.StatusOK { + t.Errorf("Expected status code %d, but got %d", http.StatusOK, requestRecorder4.Code) + } +} diff --git a/server/utils.go b/server/utils.go new file mode 100644 index 0000000..546f0fc --- /dev/null +++ b/server/utils.go @@ -0,0 +1,32 @@ +package server + +import ( + "regexp" + "strings" +) + +func isValidAwsIamArn(awsIamIdentityArn string) bool { + iamArnRegex := regexp.MustCompile(`^arn:aws:iam::[0-9]+:(user|role)/([a-zA-Z0-9_\.\-]+)$`) + return iamArnRegex.MatchString(awsIamIdentityArn) +} + +func isValidAwsRegion(awsRegion string) bool { + awsRegionRegex := regexp.MustCompile(`^[a-z]{2}-[a-z]+-\d{1}$`) + return awsRegionRegex.MatchString(awsRegion) +} + +func getAwsIamRegion(authorization string) string { + authorizationComponents := strings.Split(authorization, " ") + if len(authorizationComponents) != 4 { + return "" + } + authorizationCredentials := strings.Split(authorizationComponents[1], "/") + if len(authorizationCredentials) != 5 { + return "" + } + region := authorizationCredentials[2] + if !isValidAwsRegion(region) { + return "" + } + return region +} diff --git a/server/utils_test.go b/server/utils_test.go new file mode 100644 index 0000000..2498717 --- /dev/null +++ b/server/utils_test.go @@ -0,0 +1,81 @@ +package server + +import "testing" + +func TestIsValidAwsIamArn(t *testing.T) { + testCases := []struct { + input string + expected bool + }{ + {"", false}, + {"arn:aws:iam::1234567890:user", false}, + {"arn:aws:iam::1234567890:/username", false}, + {"arn:aws:iam:::user/username", false}, + {"arn:aws:::1234567890:user/username", false}, + {"arn::iam::1234567890:user/username", false}, + {"arn:aws:s3:::bucket-name", false}, + {"arn:aws:dynamodb:region:account-id:table/table-name", false}, + {"arn:aws:iam::1234567890:user/username/username", false}, + {"arn:aws:iam::1234567890:user/username", true}, + {"arn:aws:iam::1234567890:role/rolename", true}, + } + + for _, tc := range testCases { + result := isValidAwsIamArn(tc.input) + if result != tc.expected { + t.Errorf("Input: %s, Expected: %t, Got: %t", tc.input, tc.expected, result) + } + } +} +func TestIsValidAwsRegion(t *testing.T) { + testCases := []struct { + input string + expected bool + }{ + {"", false}, + {"us-west-2a", false}, + {"us-west-", false}, + {"uswest1", false}, + {"us-east-01", false}, + {"us-west", false}, + {"useast-1", false}, + {"us-east-10", false}, + {"us-east-a", false}, + {"us-east-01a", false}, + {"us-east-1", true}, + {"us-west-2", true}, + {"eu-central-1", true}, + {"ap-southeast-2", true}, + {"eu-west-1", true}, + {"sa-east-1", true}, + {"us-west-1", true}, + } + + for _, tc := range testCases { + result := isValidAwsRegion(tc.input) + if result != tc.expected { + t.Errorf("Input: %s, Expected: %t, Got: %t", tc.input, tc.expected, result) + } + } +} + +func TestGetAwsIamRegion(t *testing.T) { + testCases := []struct { + input string + awsRegion string + }{ + {"", ""}, + {"AWS4-HMAC-SHA256-Credential=credential/20160126/us-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-security-token, Signature=signature", ""}, + {"AWS4-HMAC-SHA256-Credential=credential/20160126/us-east-1/sts/aws4_request, Signature=signature", ""}, + {"Credential=credential/20160126/us-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-security-token", ""}, + {"AWS4-HMAC-SHA256-Credential=credential/20160126/us-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-security-token", ""}, + {"AWS4-HMAC-SHA256 Credential=credential/20160126/us-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-security-token, Signature=signature", "us-east-1"}, + } + + for _, tc := range testCases { + awsRegion := getAwsIamRegion(tc.input) + if awsRegion != tc.awsRegion { + t.Errorf("Input: %s, Expected region: %s, Got: %s", tc.input, tc.awsRegion, awsRegion) + } + } +}