Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce GetToken calls on TokenStoreCredentials #2503

Merged
merged 1 commit into from
Dec 15, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 48 additions & 5 deletions common/oauthTokenManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"path/filepath"
"strconv"
"strings"
"sync"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
Expand Down Expand Up @@ -522,10 +523,39 @@ func getAuthorityURL(tenantID, activeDirectoryEndpoint string) (*url.URL, error)
return u.Parse(tenantID)
}

const minimumTokenValidDuration = time.Minute * 5
type TokenStoreCredential struct {
token *azcore.AccessToken
lock sync.RWMutex
}

// globalTokenStoreCredential is created to make sure that all
// service clients share same cred object. This is required so that
// we do not make repeated GetToken calls.
// This is a temporary fix for issue where we would request a
// new token from Stg Exp even while they've not yet populated the
// tokenstore.
//
// This is okay because we use same credential on both source and
// destination. If we move to a case where the credentials are
// different, this should be removed.
//
// We should move to a method where the token is always read from
// tokenstore, and azcopy is invoked after tokenstore is populated.
//
var globalTokenStoreCredential *TokenStoreCredential
var globalTsc sync.Once

func (tsc *TokenStoreCredential) GetToken(_ context.Context, _ policy.TokenRequestOptions) (azcore.AccessToken, error) {
// if the token we've has not expired, return the same.
tsc.lock.RLock()
if time.Until(tsc.token.ExpiresOn) > minimumTokenValidDuration {
return *tsc.token, nil
}
tsc.lock.RUnlock()

tsc.lock.Lock()
defer tsc.lock.Unlock()
hasToken, err := tokenStoreCredCache.HasCachedToken()
if err != nil || !hasToken {
return azcore.AccessToken{}, fmt.Errorf("no cached token found in Token Store Mode(SE), %v", err)
Expand All @@ -536,19 +566,32 @@ func (tsc *TokenStoreCredential) GetToken(_ context.Context, _ policy.TokenReque
return azcore.AccessToken{}, fmt.Errorf("get cached token failed in Token Store Mode(SE), %v", err)
}

return azcore.AccessToken{
tsc.token = &azcore.AccessToken{
Token: tokenInfo.AccessToken,
ExpiresOn: tokenInfo.Expires(),
}, nil
}

return *tsc.token, nil

}

// GetNewTokenFromTokenStore gets token from token store. (Credential Manager in Windows, keyring in Linux and keychain in MacOS.)
// Note: This approach should only be used in internal integrations.
func GetTokenStoreCredential(accessToken string, expiresOn time.Time) (azcore.TokenCredential) {
globalTsc.Do(func() {
globalTokenStoreCredential = &TokenStoreCredential{
token: &azcore.AccessToken{
Token: accessToken,
ExpiresOn: expiresOn,
},
}
})
return globalTokenStoreCredential
}

func (credInfo *OAuthTokenInfo) GetTokenStoreCredential() (azcore.TokenCredential, error) {
tc := &TokenStoreCredential{}
credInfo.TokenCredential = tc
return tc, nil
credInfo.TokenCredential = GetTokenStoreCredential(credInfo.AccessToken, credInfo.Expires())
return credInfo.TokenCredential, nil
}

func (credInfo *OAuthTokenInfo) GetManagedIdentityCredential() (azcore.TokenCredential, error) {
Expand Down
Loading