Skip to content

Commit

Permalink
Merge pull request #398 from saisankargochhayat/sai/guardFix
Browse files Browse the repository at this point in the history
[BUG/FIX] Azure RBAC fix refresh token logic
  • Loading branch information
weinong authored Dec 4, 2024
2 parents 49ce54b + 209b4cd commit b512fe6
Show file tree
Hide file tree
Showing 12 changed files with 59 additions and 32 deletions.
2 changes: 1 addition & 1 deletion auth/providers/azure/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ var jsonLib = jsoniter.ConfigCompatibleWithStandardLibrary
const (
username = "nahid"
objectID = "abc-123d4"
loginResp = `{ "token_type": "Bearer", "expires_in": 8459, "access_token": "%v"}`
loginResp = `{ "token_type": "Bearer", "expires_on": 1732881796, "access_token": "%v"}`
accessToken = `{ "aud": "client_id", "iss" : "%v", "upn": "nahid", "groups": [ "1", "2", "3"] }`
accessTokenWithOid = `{ "aud": "client_id","iss" : "%v", "oid": "abc-123d4", "groups": [ "1", "2", "3"] }`
accessTokenWithUpnAndOid = `{ "aud": "client_id","iss" : "%v", "upn": "nahid", "oid": "abc-123d4", "groups": [ "1", "2", "3"] }`
Expand Down
2 changes: 1 addition & 1 deletion auth/providers/azure/graph/aks_tokenprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestAKSTokenProvider(t *testing.T) {
inputAccessToken = "inputAccessToken"
oboAccessToken = "oboAccessToken"
tenantID = "tenantID"
oboResponse = `{"token_type":"Bearer","expires_in":3599,"access_token":"%s"}`
oboResponse = `{"token_type":"Bearer","expires_on":1732881796,"access_token":"%s"}`
expectedContentType = "application/json"
expectedTokneType = "Bearer"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestClientCredentialTokenProvider(t *testing.T) {
clientID = "fakeID"
clientSecret = "fakeSecret"
scope = "https://graph.microsoft.com/.default"
oboResponse = `{"token_type":"Bearer","expires_in":3599,"access_token":"%s"}`
oboResponse = `{"token_type":"Bearer","expires_on":1732881796,"access_token":"%s"}`
expectedContentType = "application/x-www-form-urlencoded"
expectedGrantType = "client_credentials"
expectedTokneType = "Bearer"
Expand Down
11 changes: 7 additions & 4 deletions auth/providers/azure/graph/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ var (
)

const (
expiryDelta = 60 * time.Second
// Time delta to refresh token before expiry
tokenExpiryDelta = 300 * time.Second
getMemberGroupsTimeout = 23 * time.Second
getterName = "ms-graph"
arcAuthMode = "arc"
Expand Down Expand Up @@ -327,11 +328,13 @@ func (u *UserInfo) RefreshToken(ctx context.Context, token string) error {
if err != nil {
return errors.Errorf("%s: failed to refresh token: %s", u.tokenProvider.Name(), err)
}
klog.Infof("Token received, expires_at %d", resp.ExpiresOn)
// Set the authorization headers for future requests
u.headers.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token))
expIn := time.Duration(resp.Expires) * time.Second
u.expires = time.Now().Add(expIn - expiryDelta)
klog.Infof("Token refreshed successfully on %s. Expire at:%s", time.Now(), u.expires)
// Use ExpiresOn to set the expiration time
expOn := time.Unix(int64(resp.ExpiresOn), 0)
u.expires = expOn.Add(-tokenExpiryDelta)
klog.Infof("Token refreshed successfully at %s. Expire at set to: %s", time.Now(), u.expires)
}

return nil
Expand Down
16 changes: 11 additions & 5 deletions auth/providers/azure/graph/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ func TestLogin(t *testing.T) {
validToken := "blackbriar"
validBody := `{
"token_type": "Bearer",
"expires_in": 3599,
"access_token": "%s"
"access_token": "%s",
"expires_on": %d
}`
ts, u := getAuthServerAndUserInfo(http.StatusOK, fmt.Sprintf(validBody, validToken), "jason", "bourne")
expiresOn := time.Now().Add(time.Second * 3599)
ts, u := getAuthServerAndUserInfo(http.StatusOK, fmt.Sprintf(validBody, validToken, expiresOn.Unix()), "jason", "bourne")
defer ts.Close()

err := u.RefreshToken(ctx, "")
if err != nil {
t.Errorf("Error when trying to log in: %s", err)
Expand All @@ -122,6 +122,12 @@ func TestLogin(t *testing.T) {
if !time.Now().Before(u.expires) {
t.Errorf("Expiry not set properly. Expected it to be after the current time. Actual: %v", u.expires)
}
// Normalize to second precision for comparison
expectedExpiresOn := expiresOn.Add(-tokenExpiryDelta).Truncate(time.Second)
actualExpires := u.expires.Truncate(time.Second)
if !expectedExpiresOn.Equal(actualExpires) {
t.Errorf("Expiry not set properly. Expected it to be %v equal to expiresOn. Actual: %v", expectedExpiresOn, actualExpires)
}
})

t.Run("unsuccessful login", func(t *testing.T) {
Expand Down Expand Up @@ -494,7 +500,7 @@ func TestGetGroups(t *testing.T) {
mux := http.NewServeMux()
mux.Handle("/login", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
_, _ = w.Write([]byte(`{ "token_type": "Bearer", "expires_in": 8459, "access_token": "secret"}`))
_, _ = w.Write([]byte(`{ "token_type": "Bearer", "expires_on": 1732881796, "access_token": "secret"}`))
}))
mux.Handle("/users/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
Expand Down
5 changes: 3 additions & 2 deletions auth/providers/azure/graph/msi_tokenprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,10 @@ func (u *msiTokenProvider) Acquire(ctx context.Context, token string) (AuthRespo
}

authResp.TokenType = tokenResp.TokenType
authResp.Expires, err = strconv.Atoi(tokenResp.ExpiresIn)
// This is the actual time the token expires in Unix time
authResp.ExpiresOn, err = strconv.Atoi(tokenResp.ExpiresOn)
if err != nil {
return authResp, errors.Wrapf(err, "Failed to decode expiry date")
return authResp, errors.Wrapf(err, "Failed to decode expires_on field for token")
}
authResp.Token = tokenResp.AccessToken

Expand Down
9 changes: 7 additions & 2 deletions auth/providers/azure/graph/msi_tokenprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ func TestMSITokenProvider(t *testing.T) {
const (
inputAccessToken = "inputAccessToken"
msiAccessToken = "msiAccessToken"
tokenResponse = `{"token_type":"Bearer","expires_in":"3599","access_token":"%s"}`
tokenResponse = `{"access_token":"%s","expires_in":"86700","refresh_token":"","expires_on":"%d","not_before":"1732795096","resource":"https://management.azure.com","token_type":"Bearer"}`
expectedContentType = "application/json"
expectedTokenType = "Bearer"
expectedExpiresOn = 1732881796
)

t.Run("Upon Success Response", func(t *testing.T) {
Expand All @@ -41,7 +42,7 @@ func TestMSITokenProvider(t *testing.T) {
if req.Header.Get("Content-Type") != expectedContentType {
t.Errorf("expected content type: %s, actual: %s", expectedContentType, req.Header.Get("Content-Type"))
}
_, _ = rw.Write([]byte(fmt.Sprintf(tokenResponse, msiAccessToken)))
_, _ = rw.Write([]byte(fmt.Sprintf(tokenResponse, msiAccessToken, expectedExpiresOn)))
})

defer stopMSITestServer(t, s)
Expand All @@ -59,6 +60,10 @@ func TestMSITokenProvider(t *testing.T) {
if resp.TokenType != expectedTokenType {
t.Errorf("expected token type: Bearer, actual: %s", resp.TokenType)
}

if resp.ExpiresOn != expectedExpiresOn {
t.Errorf("expected expires on: %d, actual: %d", expectedExpiresOn, resp.ExpiresOn)
}
})

t.Run("Upon Error Response", func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion auth/providers/azure/graph/obo_tokenprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestOBOTokenProvider(t *testing.T) {
clientID = "fakeID"
clientSecret = "fakeSecret"
scope = "https://graph.microsoft.com/.default"
oboResponse = `{"token_type":"Bearer","expires_in":3599,"access_token":"%s"}`
oboResponse = `{"token_type":"Bearer","expires_on":1732881796,"access_token":"%s"}`
expectedContentType = "application/x-www-form-urlencoded"
expectedTokenUse = "on_behalf_of"
expectedGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
Expand Down
3 changes: 2 additions & 1 deletion auth/providers/azure/graph/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ package graph
// AuthResponse represents a response from the MS Graph auth API
type AuthResponse struct {
TokenType string `json:"token_type"`
Expires int `json:"expires_in"`
Token string `json:"access_token"`
// This is the actual time the token expires on in Unix time
ExpiresOn int `json:"expires_on"`
}

// NOTE: These below are partial implementations of the API objects containing
Expand Down
2 changes: 1 addition & 1 deletion authz/providers/azure/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import (
)

const (
loginResp = `{ "token_type": "Bearer", "expires_in": 8459, "access_token": "%v"}`
loginResp = `{ "token_type": "Bearer", "expires_on": 1732881796, "access_token": "%v"}`
httpClientRetryCount = 2
)

Expand Down
23 changes: 13 additions & 10 deletions authz/providers/azure/rbac/rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,14 @@ import (
)

const (
managedClusters = "Microsoft.ContainerService/managedClusters"
fleets = "Microsoft.ContainerService/fleets"
connectedClusters = "Microsoft.Kubernetes/connectedClusters"
checkAccessPath = "/providers/Microsoft.Authorization/checkaccess"
checkAccessAPIVersion = "2018-09-01-preview"
remainingSubReadARMHeader = "x-ms-ratelimit-remaining-subscription-reads"
expiryDelta = 60 * time.Second
managedClusters = "Microsoft.ContainerService/managedClusters"
fleets = "Microsoft.ContainerService/fleets"
connectedClusters = "Microsoft.Kubernetes/connectedClusters"
checkAccessPath = "/providers/Microsoft.Authorization/checkaccess"
checkAccessAPIVersion = "2018-09-01-preview"
remainingSubReadARMHeader = "x-ms-ratelimit-remaining-subscription-reads"
// Time delta to refresh token before expiry
tokenExpiryDelta = 300 * time.Second
checkaccessContextTimeout = 23 * time.Second
correlationRequestIDHeader = "x-ms-correlation-request-id"
)
Expand Down Expand Up @@ -224,9 +225,11 @@ func (a *AccessInfo) RefreshToken(ctx context.Context) error {

// Set the authorization headers for future requests
a.headers.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token))
expIn := time.Duration(resp.Expires) * time.Second
a.expiresAt = time.Now().Add(expIn - expiryDelta)
klog.Infof("Token refreshed successfully on %s. Expire at:%s", time.Now(), a.expiresAt)

// Use ExpiresOn to set the expiration time
expOn := time.Unix(int64(resp.ExpiresOn), 0)
a.expiresAt = expOn.Add(-tokenExpiryDelta)
klog.Infof("Token refreshed successfully at %s. Expire at set to: %s", time.Now(), a.expiresAt)
}

return nil
Expand Down
14 changes: 11 additions & 3 deletions authz/providers/azure/rbac/rbac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,11 @@ func TestLogin(t *testing.T) {
validToken := "blackbriar"
validBody := `{
"token_type": "Bearer",
"expires_in": 3599,
"access_token": "%s"
"access_token": "%s",
"expires_on": %d
}`
ts, u := getAuthServerAndAccessInfo(http.StatusOK, fmt.Sprintf(validBody, validToken), "jason", "bourne")
expiresOn := time.Now().Add(time.Second * 3599)
ts, u := getAuthServerAndAccessInfo(http.StatusOK, fmt.Sprintf(validBody, validToken, expiresOn.Unix()), "jason", "bourne")
defer ts.Close()

ctx := context.Background()
Expand All @@ -194,6 +195,13 @@ func TestLogin(t *testing.T) {
if !time.Now().Before(u.expiresAt) {
t.Errorf("Expiry not set properly. Expected it to be after the current time. Actual: %v", u.expiresAt)
}

// Normalize to second precision for comparison
expectedExpiresOn := expiresOn.Add(-tokenExpiryDelta).Truncate(time.Second)
actualExpires := u.expiresAt.Truncate(time.Second)
if !expectedExpiresOn.Equal(actualExpires) {
t.Errorf("Expiry not set properly. Expected it to be %v equal to expiresOn. Actual: %v", expectedExpiresOn, actualExpires)
}
})

t.Run("unsuccessful login", func(t *testing.T) {
Expand Down

0 comments on commit b512fe6

Please sign in to comment.