Skip to content

Commit

Permalink
Merge pull request #6385 from Ticketmaster/use-sts-GetCallerIdentity
Browse files Browse the repository at this point in the history
provider/aws: Added sts:GetCallerIdentity to GetAccountId for federated logins
  • Loading branch information
radeksimko committed May 5, 2016
2 parents 351701e + a23bcf2 commit e32a8c1
Show file tree
Hide file tree
Showing 61 changed files with 28,003 additions and 863 deletions.
237 changes: 121 additions & 116 deletions Godeps/Godeps.json

Large diffs are not rendered by default.

17 changes: 13 additions & 4 deletions builtin/providers/aws/auth_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ import (
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/go-cleanhttp"
)

func GetAccountId(iamconn *iam.IAM, authProviderName string) (string, error) {
func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) {
// If we have creds from instance profile, we can use metadata API
if authProviderName == ec2rolecreds.ProviderName {
log.Println("[DEBUG] Trying to get account ID via AWS Metadata API")
Expand All @@ -42,16 +43,24 @@ func GetAccountId(iamconn *iam.IAM, authProviderName string) (string, error) {
return parseAccountIdFromArn(*outUser.User.Arn)
}

// Then try IAM ListRoles
awsErr, ok := err.(awserr.Error)
// AccessDenied and ValidationError can be raised
// if credentials belong to federated profile, so we ignore these
if !ok || (awsErr.Code() != "AccessDenied" && awsErr.Code() != "ValidationError") {
return "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err)
}

log.Printf("[DEBUG] Getting account ID via iam:GetUser failed: %s", err)
log.Println("[DEBUG] Trying to get account ID via iam:ListRoles instead")

// Then try STS GetCallerIdentity
log.Println("[DEBUG] Trying to get account ID via sts:GetCallerIdentity")
outCallerIdentity, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err == nil {
return *outCallerIdentity.Account, nil
}
log.Printf("[DEBUG] Getting account ID via sts:GetCallerIdentity failed: %s", err)

// Then try IAM ListRoles
log.Println("[DEBUG] Trying to get account ID via iam:ListRoles")
outRoles, err := iamconn.ListRoles(&iam.ListRolesInput{
MaxItems: aws.Int64(int64(1)),
})
Expand Down
87 changes: 69 additions & 18 deletions builtin/providers/aws/auth_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
)

func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) {
Expand All @@ -28,10 +29,10 @@ func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) {
defer awsTs()

iamEndpoints := []*iamEndpoint{}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, ec2rolecreds.ProviderName)
id, err := GetAccountId(iamConn, stsConn, ec2rolecreds.ProviderName)
if err != nil {
t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err)
}
Expand All @@ -55,10 +56,10 @@ func TestAWSGetAccountId_shouldBeValid_EC2RoleHasPriority(t *testing.T) {
Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, ec2rolecreds.ProviderName)
id, err := GetAccountId(iamConn, stsConn, ec2rolecreds.ProviderName)
if err != nil {
t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err)
}
Expand All @@ -76,10 +77,36 @@ func TestAWSGetAccountId_shouldBeValid_fromIamUser(t *testing.T) {
Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)

ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, "")
id, err := GetAccountId(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via GetUser failed: %s", err)
}

expectedAccountId := "123456789012"
if id != expectedAccountId {
t.Fatalf("Expected account ID: %s, given: %s", expectedAccountId, id)
}
}

func TestAWSGetAccountId_shouldBeValid_fromGetCallerIdentity(t *testing.T) {
iamEndpoints := []*iamEndpoint{
&iamEndpoint{
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"},
},
&iamEndpoint{
Request: &iamRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"},
Response: &iamResponse{200, stsResponse_GetCallerIdentity_valid, "text/xml"},
},
}
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via GetUser failed: %s", err)
}
Expand All @@ -96,15 +123,19 @@ func TestAWSGetAccountId_shouldBeValid_fromIamListRoles(t *testing.T) {
Request: &iamRequest{"POST", "/", "Action=GetUser&Version=2010-05-08"},
Response: &iamResponse{403, iamResponse_GetUser_unauthorized, "text/xml"},
},
&iamEndpoint{
Request: &iamRequest{"POST", "/", "Action=GetCallerIdentity&Version=2011-06-15"},
Response: &iamResponse{403, stsResponse_GetCallerIdentity_unauthorized, "text/xml"},
},
&iamEndpoint{
Request: &iamRequest{"POST", "/", "Action=ListRoles&MaxItems=1&Version=2010-05-08"},
Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, "")
id, err := GetAccountId(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via ListRoles failed: %s", err)
}
Expand All @@ -126,10 +157,10 @@ func TestAWSGetAccountId_shouldBeValid_federatedRole(t *testing.T) {
Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, "")
id, err := GetAccountId(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via ListRoles failed: %s", err)
}
Expand All @@ -151,10 +182,10 @@ func TestAWSGetAccountId_shouldError_unauthorizedFromIam(t *testing.T) {
Response: &iamResponse{403, iamResponse_ListRoles_unauthorized, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, "")
id, err := GetAccountId(iamConn, stsConn, "")
if err == nil {
t.Fatal("Expected error when getting account ID")
}
Expand Down Expand Up @@ -586,15 +617,15 @@ func invalidAwsEnv(t *testing.T) func() {
return ts.Close
}

// getMockedAwsIamApi establishes a httptest server to simulate behaviour
// of a real AWS' IAM server
func getMockedAwsIamApi(endpoints []*iamEndpoint) (func(), *iam.IAM) {
// getMockedAwsIamStsApi establishes a httptest server to simulate behaviour
// of a real AWS' IAM & STS server
func getMockedAwsIamStsApi(endpoints []*iamEndpoint) (func(), *iam.IAM, *sts.STS) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
buf := new(bytes.Buffer)
buf.ReadFrom(r.Body)
requestBody := buf.String()

log.Printf("[DEBUG] Received IAM API %q request to %q: %s",
log.Printf("[DEBUG] Received API %q request to %q: %s",
r.Method, r.RequestURI, requestBody)

for _, e := range endpoints {
Expand Down Expand Up @@ -624,8 +655,8 @@ func getMockedAwsIamApi(endpoints []*iamEndpoint) (func(), *iam.IAM) {
CredentialsChainVerboseErrors: aws.Bool(true),
})
iamConn := iam.New(sess)

return ts.Close, iamConn
stsConn := sts.New(sess)
return ts.Close, iamConn, stsConn
}

func getEnv() *currentEnv {
Expand Down Expand Up @@ -718,6 +749,26 @@ const iamResponse_GetUser_unauthorized = `<ErrorResponse xmlns="https://iam.amaz
<RequestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestId>
</ErrorResponse>`

const stsResponse_GetCallerIdentity_valid = `<GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<GetCallerIdentityResult>
<Arn>arn:aws:iam::123456789012:user/Alice</Arn>
<UserId>AKIAI44QH8DHBEXAMPLE</UserId>
<Account>123456789012</Account>
</GetCallerIdentityResult>
<ResponseMetadata>
<RequestId>01234567-89ab-cdef-0123-456789abcdef</RequestId>
</ResponseMetadata>
</GetCallerIdentityResponse>`

const stsResponse_GetCallerIdentity_unauthorized = `<ErrorResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<Error>
<Type>Sender</Type>
<Code>AccessDenied</Code>
<Message>User: arn:aws:iam::123456789012:user/Bob is not authorized to perform: sts:GetCallerIdentity</Message>
</Error>
<RequestId>01234567-89ab-cdef-0123-456789abcdef</RequestId>
</ErrorResponse>`

const iamResponse_GetUser_federatedFailure = `<ErrorResponse xmlns="https://iam.amazonaws.com/doc/2010-05-08/">
<Error>
<Type>Sender</Type>
Expand Down
25 changes: 16 additions & 9 deletions builtin/providers/aws/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import (
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/sns"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sts"
)

type Config struct {
Expand Down Expand Up @@ -94,8 +95,10 @@ type AWSClient struct {
s3conn *s3.S3
sqsconn *sqs.SQS
snsconn *sns.SNS
stsconn *sts.STS
redshiftconn *redshift.Redshift
r53conn *route53.Route53
accountid string
region string
rdsconn *rds.RDS
iamconn *iam.IAM
Expand Down Expand Up @@ -174,6 +177,9 @@ func (c *Config) Client() (interface{}, error) {
awsIamSess := sess.Copy(&aws.Config{Endpoint: aws.String(c.IamEndpoint)})
client.iamconn = iam.New(awsIamSess)

log.Println("[INFO] Initializing STS connection")
client.stsconn = sts.New(sess)

err = c.ValidateCredentials(client.iamconn)
if err != nil {
errs = append(errs, err)
Expand All @@ -187,6 +193,11 @@ func (c *Config) Client() (interface{}, error) {
// http://docs.aws.amazon.com/general/latest/gr/sigv4_changes.html
usEast1Sess := sess.Copy(&aws.Config{Region: aws.String("us-east-1")})

accountId, err := GetAccountId(client.iamconn, client.stsconn, cp.ProviderName)
if err == nil {
client.accountid = accountId
}

log.Println("[INFO] Initializing DynamoDB connection")
dynamoSess := sess.Copy(&aws.Config{Endpoint: aws.String(c.DynamoDBEndpoint)})
client.dynamodbconn = dynamodb.New(dynamoSess)
Expand Down Expand Up @@ -217,7 +228,7 @@ func (c *Config) Client() (interface{}, error) {
log.Println("[INFO] Initializing Elastic Beanstalk Connection")
client.elasticbeanstalkconn = elasticbeanstalk.New(sess)

authErr := c.ValidateAccountId(client.iamconn, cp.ProviderName)
authErr := c.ValidateAccountId(client.accountid)
if authErr != nil {
errs = append(errs, authErr)
}
Expand Down Expand Up @@ -343,32 +354,28 @@ func (c *Config) ValidateCredentials(iamconn *iam.IAM) error {

// ValidateAccountId returns a context-specific error if the configured account
// id is explicitly forbidden or not authorised; and nil if it is authorised.
func (c *Config) ValidateAccountId(iamconn *iam.IAM, authProviderName string) error {
func (c *Config) ValidateAccountId(accountId string) error {
if c.AllowedAccountIds == nil && c.ForbiddenAccountIds == nil {
return nil
}

log.Printf("[INFO] Validating account ID")
account_id, err := GetAccountId(iamconn, authProviderName)
if err != nil {
return err
}

if c.ForbiddenAccountIds != nil {
for _, id := range c.ForbiddenAccountIds {
if id == account_id {
if id == accountId {
return fmt.Errorf("Forbidden account ID (%s)", id)
}
}
}

if c.AllowedAccountIds != nil {
for _, id := range c.AllowedAccountIds {
if id == account_id {
if id == accountId {
return nil
}
}
return fmt.Errorf("Account ID not allowed (%s)", account_id)
return fmt.Errorf("Account ID not allowed (%s)", accountId)
}

return nil
Expand Down
26 changes: 24 additions & 2 deletions vendor/github.com/aws/aws-sdk-go/aws/client/default_retryer.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit e32a8c1

Please sign in to comment.