diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 9c2398bc8..ea409d14b 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -13,8 +13,6 @@ const ( STSThrottling = "sts_throttling" Unknown = "uknown_user" Success = "success" - STSGlobal = "sts_global" - STSRegional = "sts_regional" ) var authenticatorMetrics Metrics @@ -72,21 +70,21 @@ func createMetrics(reg prometheus.Registerer) Metrics { Namespace: Namespace, Name: "sts_connection_failures_total", Help: "Sts call could not succeed or timedout", - }, []string{"StsEndpointType"}, + }, []string{"StsRegion"}, ), StsThrottling: factory.NewCounterVec( prometheus.CounterOpts{ Namespace: Namespace, Name: "sts_throttling_total", Help: "Sts call got throttled", - }, []string{"StsEndpointType"}, + }, []string{"StsRegion"}, ), StsResponses: factory.NewCounterVec( prometheus.CounterOpts{ Namespace: Namespace, Name: "sts_responses_total", Help: "Sts responses with error code label", - }, []string{"ResponseCode", "StsEndpointType"}, + }, []string{"ResponseCode", "StsRegion"}, ), Latency: factory.NewHistogramVec( prometheus.HistogramOpts{ diff --git a/pkg/token/token.go b/pkg/token/token.go index 478a1aad2..d704b3a97 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -568,19 +568,16 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { req.Header.Set(clusterIDHeader, v.clusterID) req.Header.Set("accept", "application/json") - stsEndpointType := metrics.STSRegional - if parsedURL.Host == "sts.amazonaws.com" { - stsEndpointType = metrics.STSGlobal - } + stsRegion := getStsRegion(parsedURL.Host) response, err := v.client.Do(req) if err != nil { - metrics.Get().StsConnectionFailure.WithLabelValues(stsEndpointType).Inc() + metrics.Get().StsConnectionFailure.WithLabelValues(stsRegion).Inc() // special case to avoid printing the full URL if possible if urlErr, ok := err.(*url.Error); ok { - return nil, NewSTSError(fmt.Sprintf("error during GET: %v on %s endpoint", urlErr.Err, stsEndpointType)) + return nil, NewSTSError(fmt.Sprintf("error during GET: %v on %s endpoint", urlErr.Err, stsRegion)) } - return nil, NewSTSError(fmt.Sprintf("error during GET: %v on %s endpoint", err, stsEndpointType)) + return nil, NewSTSError(fmt.Sprintf("error during GET: %v on %s endpoint", err, stsRegion)) } defer response.Body.Close() @@ -589,16 +586,16 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { return nil, NewSTSError(fmt.Sprintf("error reading HTTP result: %v", err)) } - metrics.Get().StsResponses.WithLabelValues(fmt.Sprint(response.StatusCode), stsEndpointType).Inc() + metrics.Get().StsResponses.WithLabelValues(fmt.Sprint(response.StatusCode), stsRegion).Inc() if response.StatusCode != 200 { responseStr := string(responseBody[:]) // refer to https://docs.aws.amazon.com/STS/latest/APIReference/CommonErrors.html and log // response body for STS Throttling is {"Error":{"Code":"Throttling","Message":"Rate exceeded","Type":"Sender"},"RequestId":"xxx"} if strings.Contains(responseStr, "Throttling") { - metrics.Get().StsThrottling.WithLabelValues(stsEndpointType).Inc() + metrics.Get().StsThrottling.WithLabelValues(stsRegion).Inc() return nil, NewSTSThrottling(responseStr) } - return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d) on %s endpoint. Body: %s", response.StatusCode, stsEndpointType, responseStr)) + return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d) on %s endpoint. Body: %s", response.StatusCode, stsRegion, responseStr)) } var callerIdentity getCallerIdentityWrapper @@ -669,3 +666,11 @@ func hasSignedClusterIDHeader(paramsLower *url.Values) bool { } return false } + +func getStsRegion(host string) string { + parts := strings.Split(host, ".") + if host == "sts.amazonaws.com" { + return "global" + } + return parts[1] +}