diff --git a/auth/providers/azure/azure_test.go b/auth/providers/azure/azure_test.go index e56d7867..f698d592 100644 --- a/auth/providers/azure/azure_test.go +++ b/auth/providers/azure/azure_test.go @@ -46,7 +46,7 @@ var jsonLib = jsoniter.ConfigCompatibleWithStandardLibrary const ( username = "nahid" objectID = "abc-123d4" - loginResp = `{ "token_type": "Bearer", "expires_in": 8459, "access_token": "%v"}` + loginResp = `{ "token_type": "Bearer", "expires_on": 1732881796, "access_token": "%v"}` accessToken = `{ "aud": "client_id", "iss" : "%v", "upn": "nahid", "groups": [ "1", "2", "3"] }` accessTokenWithOid = `{ "aud": "client_id","iss" : "%v", "oid": "abc-123d4", "groups": [ "1", "2", "3"] }` accessTokenWithUpnAndOid = `{ "aud": "client_id","iss" : "%v", "upn": "nahid", "oid": "abc-123d4", "groups": [ "1", "2", "3"] }` diff --git a/auth/providers/azure/graph/aks_tokenprovider_test.go b/auth/providers/azure/graph/aks_tokenprovider_test.go index dbe48367..c52badf6 100644 --- a/auth/providers/azure/graph/aks_tokenprovider_test.go +++ b/auth/providers/azure/graph/aks_tokenprovider_test.go @@ -28,7 +28,7 @@ func TestAKSTokenProvider(t *testing.T) { inputAccessToken = "inputAccessToken" oboAccessToken = "oboAccessToken" tenantID = "tenantID" - oboResponse = `{"token_type":"Bearer","expires_in":3599,"access_token":"%s"}` + oboResponse = `{"token_type":"Bearer","expires_on":1732881796,"access_token":"%s"}` expectedContentType = "application/json" expectedTokneType = "Bearer" ) diff --git a/auth/providers/azure/graph/clientcredential_tokenprovider_test.go b/auth/providers/azure/graph/clientcredential_tokenprovider_test.go index 7e6073ed..acc69a95 100644 --- a/auth/providers/azure/graph/clientcredential_tokenprovider_test.go +++ b/auth/providers/azure/graph/clientcredential_tokenprovider_test.go @@ -30,7 +30,7 @@ func TestClientCredentialTokenProvider(t *testing.T) { clientID = "fakeID" clientSecret = "fakeSecret" scope = "https://graph.microsoft.com/.default" - oboResponse = `{"token_type":"Bearer","expires_in":3599,"access_token":"%s"}` + oboResponse = `{"token_type":"Bearer","expires_on":1732881796,"access_token":"%s"}` expectedContentType = "application/x-www-form-urlencoded" expectedGrantType = "client_credentials" expectedTokneType = "Bearer" diff --git a/auth/providers/azure/graph/graph.go b/auth/providers/azure/graph/graph.go index 35bb9066..50eb5358 100644 --- a/auth/providers/azure/graph/graph.go +++ b/auth/providers/azure/graph/graph.go @@ -72,7 +72,8 @@ var ( ) const ( - expiryDelta = 60 * time.Second + // Time delta to refresh token before expiry + tokenExpiryDelta = 300 * time.Second getMemberGroupsTimeout = 23 * time.Second getterName = "ms-graph" arcAuthMode = "arc" @@ -327,11 +328,13 @@ func (u *UserInfo) RefreshToken(ctx context.Context, token string) error { if err != nil { return errors.Errorf("%s: failed to refresh token: %s", u.tokenProvider.Name(), err) } + klog.Infof("Token received, expires_at %d", resp.ExpiresOn) // Set the authorization headers for future requests u.headers.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token)) - expIn := time.Duration(resp.Expires) * time.Second - u.expires = time.Now().Add(expIn - expiryDelta) - klog.Infof("Token refreshed successfully on %s. Expire at:%s", time.Now(), u.expires) + // Use ExpiresOn to set the expiration time + expOn := time.Unix(int64(resp.ExpiresOn), 0) + u.expires = expOn.Add(-tokenExpiryDelta) + klog.Infof("Token refreshed successfully at %s. Expire at set to: %s", time.Now(), u.expires) } return nil diff --git a/auth/providers/azure/graph/graph_test.go b/auth/providers/azure/graph/graph_test.go index 73a35636..4a76a8f5 100644 --- a/auth/providers/azure/graph/graph_test.go +++ b/auth/providers/azure/graph/graph_test.go @@ -106,12 +106,12 @@ func TestLogin(t *testing.T) { validToken := "blackbriar" validBody := `{ "token_type": "Bearer", - "expires_in": 3599, - "access_token": "%s" + "access_token": "%s", + "expires_on": %d }` - ts, u := getAuthServerAndUserInfo(http.StatusOK, fmt.Sprintf(validBody, validToken), "jason", "bourne") + expiresOn := time.Now().Add(time.Second * 3599) + ts, u := getAuthServerAndUserInfo(http.StatusOK, fmt.Sprintf(validBody, validToken, expiresOn.Unix()), "jason", "bourne") defer ts.Close() - err := u.RefreshToken(ctx, "") if err != nil { t.Errorf("Error when trying to log in: %s", err) @@ -122,6 +122,12 @@ func TestLogin(t *testing.T) { if !time.Now().Before(u.expires) { t.Errorf("Expiry not set properly. Expected it to be after the current time. Actual: %v", u.expires) } + // Normalize to second precision for comparison + expectedExpiresOn := expiresOn.Add(-tokenExpiryDelta).Truncate(time.Second) + actualExpires := u.expires.Truncate(time.Second) + if !expectedExpiresOn.Equal(actualExpires) { + t.Errorf("Expiry not set properly. Expected it to be %v equal to expiresOn. Actual: %v", expectedExpiresOn, actualExpires) + } }) t.Run("unsuccessful login", func(t *testing.T) { @@ -494,7 +500,7 @@ func TestGetGroups(t *testing.T) { mux := http.NewServeMux() mux.Handle("/login", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) - _, _ = w.Write([]byte(`{ "token_type": "Bearer", "expires_in": 8459, "access_token": "secret"}`)) + _, _ = w.Write([]byte(`{ "token_type": "Bearer", "expires_on": 1732881796, "access_token": "secret"}`)) })) mux.Handle("/users/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) diff --git a/auth/providers/azure/graph/msi_tokenprovider.go b/auth/providers/azure/graph/msi_tokenprovider.go index 8d3f4e26..b79102a2 100644 --- a/auth/providers/azure/graph/msi_tokenprovider.go +++ b/auth/providers/azure/graph/msi_tokenprovider.go @@ -96,9 +96,10 @@ func (u *msiTokenProvider) Acquire(ctx context.Context, token string) (AuthRespo } authResp.TokenType = tokenResp.TokenType - authResp.Expires, err = strconv.Atoi(tokenResp.ExpiresIn) + // This is the actual time the token expires in Unix time + authResp.ExpiresOn, err = strconv.Atoi(tokenResp.ExpiresOn) if err != nil { - return authResp, errors.Wrapf(err, "Failed to decode expiry date") + return authResp, errors.Wrapf(err, "Failed to decode expires_on field for token") } authResp.Token = tokenResp.AccessToken diff --git a/auth/providers/azure/graph/msi_tokenprovider_test.go b/auth/providers/azure/graph/msi_tokenprovider_test.go index b7a35fc0..e3ce8be4 100644 --- a/auth/providers/azure/graph/msi_tokenprovider_test.go +++ b/auth/providers/azure/graph/msi_tokenprovider_test.go @@ -28,9 +28,10 @@ func TestMSITokenProvider(t *testing.T) { const ( inputAccessToken = "inputAccessToken" msiAccessToken = "msiAccessToken" - tokenResponse = `{"token_type":"Bearer","expires_in":"3599","access_token":"%s"}` + tokenResponse = `{"access_token":"%s","expires_in":"86700","refresh_token":"","expires_on":"%d","not_before":"1732795096","resource":"https://management.azure.com","token_type":"Bearer"}` expectedContentType = "application/json" expectedTokenType = "Bearer" + expectedExpiresOn = 1732881796 ) t.Run("Upon Success Response", func(t *testing.T) { @@ -41,7 +42,7 @@ func TestMSITokenProvider(t *testing.T) { if req.Header.Get("Content-Type") != expectedContentType { t.Errorf("expected content type: %s, actual: %s", expectedContentType, req.Header.Get("Content-Type")) } - _, _ = rw.Write([]byte(fmt.Sprintf(tokenResponse, msiAccessToken))) + _, _ = rw.Write([]byte(fmt.Sprintf(tokenResponse, msiAccessToken, expectedExpiresOn))) }) defer stopMSITestServer(t, s) @@ -59,6 +60,10 @@ func TestMSITokenProvider(t *testing.T) { if resp.TokenType != expectedTokenType { t.Errorf("expected token type: Bearer, actual: %s", resp.TokenType) } + + if resp.ExpiresOn != expectedExpiresOn { + t.Errorf("expected expires on: %d, actual: %d", expectedExpiresOn, resp.ExpiresOn) + } }) t.Run("Upon Error Response", func(t *testing.T) { diff --git a/auth/providers/azure/graph/obo_tokenprovider_test.go b/auth/providers/azure/graph/obo_tokenprovider_test.go index ce49c868..4f47d564 100644 --- a/auth/providers/azure/graph/obo_tokenprovider_test.go +++ b/auth/providers/azure/graph/obo_tokenprovider_test.go @@ -31,7 +31,7 @@ func TestOBOTokenProvider(t *testing.T) { clientID = "fakeID" clientSecret = "fakeSecret" scope = "https://graph.microsoft.com/.default" - oboResponse = `{"token_type":"Bearer","expires_in":3599,"access_token":"%s"}` + oboResponse = `{"token_type":"Bearer","expires_on":1732881796,"access_token":"%s"}` expectedContentType = "application/x-www-form-urlencoded" expectedTokenUse = "on_behalf_of" expectedGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" diff --git a/auth/providers/azure/graph/types.go b/auth/providers/azure/graph/types.go index 0263f95f..e033f2b8 100644 --- a/auth/providers/azure/graph/types.go +++ b/auth/providers/azure/graph/types.go @@ -19,8 +19,9 @@ package graph // AuthResponse represents a response from the MS Graph auth API type AuthResponse struct { TokenType string `json:"token_type"` - Expires int `json:"expires_in"` Token string `json:"access_token"` + // This is the actual time the token expires on in Unix time + ExpiresOn int `json:"expires_on"` } // NOTE: These below are partial implementations of the API objects containing diff --git a/authz/providers/azure/azure_test.go b/authz/providers/azure/azure_test.go index 700d3cfc..93801cb6 100644 --- a/authz/providers/azure/azure_test.go +++ b/authz/providers/azure/azure_test.go @@ -38,7 +38,7 @@ import ( ) const ( - loginResp = `{ "token_type": "Bearer", "expires_in": 8459, "access_token": "%v"}` + loginResp = `{ "token_type": "Bearer", "expires_on": 1732881796, "access_token": "%v"}` httpClientRetryCount = 2 ) diff --git a/authz/providers/azure/rbac/rbac.go b/authz/providers/azure/rbac/rbac.go index 88c726d9..48bd31f8 100644 --- a/authz/providers/azure/rbac/rbac.go +++ b/authz/providers/azure/rbac/rbac.go @@ -49,13 +49,14 @@ import ( ) const ( - managedClusters = "Microsoft.ContainerService/managedClusters" - fleets = "Microsoft.ContainerService/fleets" - connectedClusters = "Microsoft.Kubernetes/connectedClusters" - checkAccessPath = "/providers/Microsoft.Authorization/checkaccess" - checkAccessAPIVersion = "2018-09-01-preview" - remainingSubReadARMHeader = "x-ms-ratelimit-remaining-subscription-reads" - expiryDelta = 60 * time.Second + managedClusters = "Microsoft.ContainerService/managedClusters" + fleets = "Microsoft.ContainerService/fleets" + connectedClusters = "Microsoft.Kubernetes/connectedClusters" + checkAccessPath = "/providers/Microsoft.Authorization/checkaccess" + checkAccessAPIVersion = "2018-09-01-preview" + remainingSubReadARMHeader = "x-ms-ratelimit-remaining-subscription-reads" + // Time delta to refresh token before expiry + tokenExpiryDelta = 300 * time.Second checkaccessContextTimeout = 23 * time.Second correlationRequestIDHeader = "x-ms-correlation-request-id" ) @@ -224,9 +225,11 @@ func (a *AccessInfo) RefreshToken(ctx context.Context) error { // Set the authorization headers for future requests a.headers.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token)) - expIn := time.Duration(resp.Expires) * time.Second - a.expiresAt = time.Now().Add(expIn - expiryDelta) - klog.Infof("Token refreshed successfully on %s. Expire at:%s", time.Now(), a.expiresAt) + + // Use ExpiresOn to set the expiration time + expOn := time.Unix(int64(resp.ExpiresOn), 0) + a.expiresAt = expOn.Add(-tokenExpiryDelta) + klog.Infof("Token refreshed successfully at %s. Expire at set to: %s", time.Now(), a.expiresAt) } return nil diff --git a/authz/providers/azure/rbac/rbac_test.go b/authz/providers/azure/rbac/rbac_test.go index a5fd7eb5..2c90c7b0 100644 --- a/authz/providers/azure/rbac/rbac_test.go +++ b/authz/providers/azure/rbac/rbac_test.go @@ -177,10 +177,11 @@ func TestLogin(t *testing.T) { validToken := "blackbriar" validBody := `{ "token_type": "Bearer", - "expires_in": 3599, - "access_token": "%s" + "access_token": "%s", + "expires_on": %d }` - ts, u := getAuthServerAndAccessInfo(http.StatusOK, fmt.Sprintf(validBody, validToken), "jason", "bourne") + expiresOn := time.Now().Add(time.Second * 3599) + ts, u := getAuthServerAndAccessInfo(http.StatusOK, fmt.Sprintf(validBody, validToken, expiresOn.Unix()), "jason", "bourne") defer ts.Close() ctx := context.Background() @@ -194,6 +195,13 @@ func TestLogin(t *testing.T) { if !time.Now().Before(u.expiresAt) { t.Errorf("Expiry not set properly. Expected it to be after the current time. Actual: %v", u.expiresAt) } + + // Normalize to second precision for comparison + expectedExpiresOn := expiresOn.Add(-tokenExpiryDelta).Truncate(time.Second) + actualExpires := u.expiresAt.Truncate(time.Second) + if !expectedExpiresOn.Equal(actualExpires) { + t.Errorf("Expiry not set properly. Expected it to be %v equal to expiresOn. Actual: %v", expectedExpiresOn, actualExpires) + } }) t.Run("unsuccessful login", func(t *testing.T) {