diff --git a/pkg/credentials/iam_aws.go b/pkg/credentials/iam_aws.go index 5732f2e4b..34c593fdf 100644 --- a/pkg/credentials/iam_aws.go +++ b/pkg/credentials/iam_aws.go @@ -22,6 +22,7 @@ import ( "encoding/json" "errors" "fmt" + "net" "net/http" "net/url" "os" @@ -57,14 +58,35 @@ const ( ) // https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-iam-roles.html -func getEndpoint(endpoint string) (string, bool) { +func getEndpoint(endpoint string) (string, bool, error) { + ecsFullURI := os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") + ecsURI := os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") + if endpoint != "" { - return endpoint, os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" + return endpoint, ecsURI != "" || ecsFullURI != "", nil + } + if ecsFullURI != "" { + u, err := url.Parse(ecsFullURI) + if err != nil { + return "", false, err + } + host := u.Hostname() + if host == "" { + return "", false, fmt.Errorf("can't parse host from uri: %s", ecsFullURI) + } + + if loopback, err := isLoopback(host); loopback { + return ecsFullURI, true, nil + } else if err != nil { + return "", false, err + } else { + return "", false, fmt.Errorf("host is not on a loopback address: %s", host) + } } - if ecsURI := os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"); ecsURI != "" { - return fmt.Sprintf("%s%s", defaultECSRoleEndpoint, ecsURI), true + if ecsURI != "" { + return fmt.Sprintf("%s%s", defaultECSRoleEndpoint, ecsURI), true, nil } - return defaultIAMRoleEndpoint, false + return defaultIAMRoleEndpoint, false, nil } // NewIAM returns a pointer to a new Credentials object wrapping the IAM. @@ -82,9 +104,14 @@ func NewIAM(endpoint string) *Credentials { // Error will be returned if the request fails, or unable to extract // the desired func (m *IAM) Retrieve() (Value, error) { - endpoint, isEcsTask := getEndpoint(m.endpoint) var roleCreds ec2RoleCredRespBody var err error + + endpoint, isEcsTask, err := getEndpoint(m.endpoint) + if err != nil { + return Value{}, err + } + if isEcsTask { roleCreds, err = getEcsTaskCredentials(m.Client, endpoint) } else { @@ -248,3 +275,18 @@ func getCredentials(client *http.Client, endpoint string) (ec2RoleCredRespBody, return respCreds, nil } + +// isLoopback identifies if a host is on a loopback address +func isLoopback(host string) (bool, error) { + ips, err := net.LookupHost(host) + if err != nil { + return false, err + } + for _, ip := range ips { + if !net.ParseIP(ip).IsLoopback() { + return false, nil + } + } + + return true, nil +} diff --git a/pkg/credentials/iam_aws_test.go b/pkg/credentials/iam_aws_test.go index 90f980693..d502a8ace 100644 --- a/pkg/credentials/iam_aws_test.go +++ b/pkg/credentials/iam_aws_test.go @@ -243,3 +243,33 @@ func TestEcsTask(t *testing.T) { t.Error("Expected creds to be expired.") } } + +func TestEcsTaskFullURI(t *testing.T) { + server := initEcsTaskTestServer("2014-12-16T01:51:37Z") + defer server.Close() + p := &IAM{ + Client: http.DefaultClient, + } + os.Setenv("AWS_CONTAINER_CREDENTIALS_FULL_URI", + fmt.Sprintf("%s%s", server.URL, "/v2/credentials?id=task_credential_id")) + creds, err := p.Retrieve() + os.Unsetenv("AWS_CONTAINER_CREDENTIALS_FULL_URI") + if err != nil { + t.Errorf("Unexpected failure %s", err) + } + if "accessKey" != creds.AccessKeyID { + t.Errorf("Expected \"accessKey\", got %s", creds.AccessKeyID) + } + + if "secret" != creds.SecretAccessKey { + t.Errorf("Expected \"secret\", got %s", creds.SecretAccessKey) + } + + if "token" != creds.SessionToken { + t.Errorf("Expected \"token\", got %s", creds.SessionToken) + } + + if !p.IsExpired() { + t.Error("Expected creds to be expired.") + } +}