Skip to content

Commit

Permalink
update code and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sushanth0910 committed Nov 13, 2024
1 parent e0a7fff commit f52d258
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 5 deletions.
17 changes: 12 additions & 5 deletions pkg/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) {
if err = v.verifyHost(parsedURL.Host); err != nil {
return nil, err
}
stsRegion, err := getStsRegion(parsedURL.Host)

if parsedURL.Path != "/" {
return nil, FormatError{"unexpected path in pre-signed URL"}
Expand Down Expand Up @@ -568,8 +569,6 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) {
req.Header.Set(clusterIDHeader, v.clusterID)
req.Header.Set("accept", "application/json")

stsRegion := getStsRegion(parsedURL.Host)

response, err := v.client.Do(req)
if err != nil {
metrics.Get().StsConnectionFailure.WithLabelValues(stsRegion).Inc()
Expand Down Expand Up @@ -667,10 +666,18 @@ func hasSignedClusterIDHeader(paramsLower *url.Values) bool {
return false
}

func getStsRegion(host string) string {
func getStsRegion(host string) (string, error) {
if host == "" {
return "", fmt.Errorf("host is empty")
}

parts := strings.Split(host, ".")
if len(parts) < 3 {
return "", fmt.Errorf("invalid host format: %v", host)
}

if host == "sts.amazonaws.com" {
return "global"
return "global", nil
}
return parts[1]
return parts[1], nil
}
25 changes: 25 additions & 0 deletions pkg/token/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -646,3 +646,28 @@ func TestGetWithSTS(t *testing.T) {
})
}
}

func TestGetStsRegion(t *testing.T) {
tests := []struct {
host string
expected string
wantErr bool
}{
{"sts.amazonaws.com", "global", false}, // Global endpoint
{"sts.us-west-2.amazonaws.com", "us-west-2", false}, // Valid regional endpoint
{"sts.eu-central-1.amazonaws.com", "eu-central-1", false}, // Another valid regional endpoint
{"", "", true}, // Empty input (expect error)
{"sts", "", true}, // Malformed input (expect error)
{"sts.wrongformat", "", true}, // Malformed input (expect error)
}

for _, test := range tests {
result, err := getStsRegion(test.host)
if (err != nil) != test.wantErr {
t.Errorf("getStsRegion(%q) error = %v, wantErr %v", test.host, err, test.wantErr)
}
if result != test.expected {
t.Errorf("getStsRegion(%q) = %q; expected %q", test.host, result, test.expected)
}
}
}

0 comments on commit f52d258

Please sign in to comment.