-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy pathawsauth.go
207 lines (178 loc) · 7.35 KB
/
awsauth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package awsbase
import (
"context"
"errors"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/arn"
"github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/iam"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/smithy-go"
"github.com/hashicorp/aws-sdk-go-base/v2/logging"
multierror "github.com/hashicorp/go-multierror"
)
// getAccountIDAndPartition gets the account ID and associated partition of the
// credentials in use by trying a sequence of lookups in order:
//
//  1. EC2 metadata when authProviderName is ec2rolecreds.ProviderName,
//     otherwise iam:GetUser;
//  2. sts:GetCallerIdentity;
//  3. iam:ListRoles.
//
// The first lookup that yields a non-empty account ID wins, and its account ID
// and partition are returned with a nil error. If no lookup yields an account
// ID, the errors from every attempt are accumulated with multierror.Append and
// returned together with whatever the last lookup (iam:ListRoles) produced.
func getAccountIDAndPartition(ctx context.Context, iamClient *iam.Client, stsClient *sts.Client, authProviderName string) (string, string, error) {
	// The first lookup depends on which credentials provider supplied the
	// credentials: the ec2rolecreds provider goes through EC2 metadata,
	// everything else through iam:GetUser.
	first := func() (string, string, error) {
		if authProviderName == ec2rolecreds.ProviderName {
			return getAccountIDAndPartitionFromEC2Metadata(ctx)
		}
		return getAccountIDAndPartitionFromIAMGetUser(ctx, iamClient)
	}

	lookups := []func() (string, string, error){
		first,
		func() (string, string, error) {
			return getAccountIDAndPartitionFromSTSGetCallerIdentity(ctx, stsClient)
		},
		func() (string, string, error) {
			return getAccountIDAndPartitionFromIAMListRoles(ctx, iamClient)
		},
	}

	var accountID, partition string
	// Named errs (not errors) so the standard "errors" package is not shadowed.
	var errs error
	for _, lookup := range lookups {
		var err error
		accountID, partition, err = lookup()
		if accountID != "" {
			return accountID, partition, nil
		}
		errs = multierror.Append(errs, err)
	}
	return accountID, partition, errs
}
// getAccountIDAndPartitionFromEC2Metadata gets the account ID and associated
// partition from EC2 metadata, by parsing the instance profile ARN returned by
// the IMDS GetIAMInfo call.
//
// On any failure (the metadata call or ARN parsing) the error is logged at
// debug level and returned wrapped, with both strings empty.
func getAccountIDAndPartitionFromEC2Metadata(ctx context.Context) (accountID string, partition string, err error) {
	logger := logging.RetrieveLogger(ctx)
	logger.Debug(ctx, "Retrieving account information from EC2 Metadata")

	// NOTE(review): the IMDS client is built from an empty aws.Config, so it does
	// not inherit any settings from the caller's configuration — confirm intended.
	metadataClient := imds.NewFromConfig(aws.Config{})
	info, err := metadataClient.GetIAMInfo(ctx, &imds.GetIAMInfoInput{})
	if err != nil {
		// We can end up here if there's an issue with the instance metadata service
		// or if we're getting credentials from AdRoll's Hologram (in which case IAMInfo will
		// error out).
		logger.Debug(ctx, "Unable to retrieve account information from EC2 Metadata", map[string]any{
			"error": err,
		})
		return "", "", fmt.Errorf("retrieving account information via EC2 Metadata IAM information: %w", err)
	}

	accountID, partition, err = parseAccountIDAndPartitionFromARN(info.InstanceProfileArn)
	if err != nil {
		logger.Debug(ctx, "Unable to retrieve account information from EC2 Metadata", map[string]any{
			"error": err,
		})
		return "", "", fmt.Errorf("retrieving account information from EC2 Metadata: %w", err)
	}

	logger.Info(ctx, "Retrieved account information from EC2 Metadata")
	return accountID, partition, nil
}
// getAccountIDAndPartitionFromIAMGetUser gets the account ID and associated
// partition from the ARN of the calling IAM user, via iam:GetUser.
//
// API errors with code AccessDenied, InvalidClientTokenId or ValidationError
// are not treated as failures: they are logged at debug level and the function
// returns empty strings with a nil error, so the caller can fall back to
// another lookup. Any other API error, an empty response, or an unparsable
// user ARN is logged at debug level and returned wrapped.
func getAccountIDAndPartitionFromIAMGetUser(ctx context.Context, iamClient iam.GetUserAPIClient) (accountID string, partition string, err error) {
	logger := logging.RetrieveLogger(ctx)
	logger.Debug(ctx, "Retrieving account information via iam:GetUser")

	output, err := iamClient.GetUser(ctx, &iam.GetUserInput{})
	if err != nil {
		// AccessDenied and ValidationError can be raised if credentials belong
		// to a federated profile, so these (and InvalidClientTokenId) are ignored.
		var apiErr smithy.APIError
		if errors.As(err, &apiErr) {
			switch apiErr.ErrorCode() {
			case "AccessDenied", "InvalidClientTokenId", "ValidationError":
				logger.Debug(ctx, "Retrieving account information via iam:GetUser: ignoring error", map[string]any{
					"error": err,
				})
				return "", "", nil
			}
		}
		logger.Debug(ctx, "Unable to retrieve account information via iam:GetUser", map[string]any{
			"error": err,
		})
		return "", "", fmt.Errorf("retrieving account information via iam:GetUser: %w", err)
	}

	if output == nil || output.User == nil {
		logger.Debug(ctx, "Unable to retrieve account information via iam:GetUser", map[string]any{
			"error": "empty response",
		})
		return "", "", errors.New("retrieving account information via iam:GetUser: empty response")
	}

	accountID, partition, err = parseAccountIDAndPartitionFromARN(aws.ToString(output.User.Arn))
	if err != nil {
		logger.Debug(ctx, "Unable to retrieve account information via iam:GetUser", map[string]any{
			"error": err,
		})
		return "", "", fmt.Errorf("retrieving account information via iam:GetUser: %w", err)
	}

	logger.Info(ctx, "Retrieved account information via iam:GetUser")
	return accountID, partition, nil
}
// getAccountIDAndPartitionFromIAMListRoles gets the account ID and associated
// partition from listing IAM roles: it requests at most one role via
// iam:ListRoles and parses that role's ARN.
//
// An API error, a response with no roles, or an unparsable role ARN is logged
// at debug level and returned wrapped, with both strings empty.
func getAccountIDAndPartitionFromIAMListRoles(ctx context.Context, iamClient iam.ListRolesAPIClient) (accountID string, partition string, err error) {
	logger := logging.RetrieveLogger(ctx)
	logger.Debug(ctx, "Retrieving account information via iam:ListRoles")

	// Only one role's ARN is needed to learn the account, so cap the page at 1.
	output, err := iamClient.ListRoles(ctx, &iam.ListRolesInput{
		MaxItems: aws.Int32(1),
	})
	if err != nil {
		logger.Debug(ctx, "Unable to retrieve account information via iam:ListRoles", map[string]any{
			"error": err,
		})
		return "", "", fmt.Errorf("retrieving account information via iam:ListRoles: %w", err)
	}

	if output == nil || len(output.Roles) < 1 {
		logger.Debug(ctx, "Unable to retrieve account information via iam:ListRoles", map[string]any{
			"error": "empty response",
		})
		return "", "", errors.New("retrieving account information via iam:ListRoles: empty response")
	}

	accountID, partition, err = parseAccountIDAndPartitionFromARN(aws.ToString(output.Roles[0].Arn))
	if err != nil {
		logger.Debug(ctx, "Unable to retrieve account information via iam:ListRoles", map[string]any{
			"error": err,
		})
		return "", "", fmt.Errorf("retrieving account information via iam:ListRoles: %w", err)
	}

	logger.Info(ctx, "Retrieved account information via iam:ListRoles")
	return accountID, partition, nil
}
// getAccountIDAndPartitionFromSTSGetCallerIdentity gets the account ID and
// associated partition from the caller ARN returned by sts:GetCallerIdentity.
//
// An API error, a response without an ARN, or an unparsable ARN is logged at
// debug level and returned wrapped, with both strings empty.
func getAccountIDAndPartitionFromSTSGetCallerIdentity(ctx context.Context, stsClient *sts.Client) (accountID string, partition string, err error) {
	logger := logging.RetrieveLogger(ctx)
	logger.Debug(ctx, "Retrieving caller identity from STS")

	output, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
	if err != nil {
		logger.Debug(ctx, "Unable to retrieve caller identity from STS", map[string]any{
			"error": err,
		})
		return "", "", fmt.Errorf("retrieving caller identity from STS: %w", err)
	}

	if output == nil || output.Arn == nil {
		logger.Debug(ctx, "Unable to retrieve caller identity from STS", map[string]any{
			"error": "empty response",
		})
		return "", "", errors.New("retrieving caller identity from STS: empty response")
	}

	accountID, partition, err = parseAccountIDAndPartitionFromARN(aws.ToString(output.Arn))
	if err != nil {
		logger.Debug(ctx, "Unable to retrieve caller identity from STS", map[string]any{
			"error": err,
		})
		return "", "", fmt.Errorf("retrieving caller identity from STS: %w", err)
	}

	logger.Info(ctx, "Retrieved caller identity from STS")
	return accountID, partition, nil
}
// parseAccountIDAndPartitionFromARN returns the account ID and partition
// components of inputARN, in that order.
//
// If inputARN cannot be parsed, both strings are empty and the returned error
// names the offending input and wraps the underlying arn.Parse error, so
// callers can inspect it with errors.Is / errors.As. This function does not
// itself check that the parsed account ID is non-empty.
func parseAccountIDAndPartitionFromARN(inputARN string) (string, string, error) {
	// Named parsed (not arn) so the imported arn package is not shadowed.
	parsed, err := arn.Parse(inputARN)
	if err != nil {
		return "", "", fmt.Errorf("parsing ARN (%s): %w", inputARN, err)
	}
	return parsed.AccountID, parsed.Partition, nil
}