Skip to content

Commit

Permalink
Upgrade v1.0 cache items (#454)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Sep 6, 2023
1 parent 0b07e1c commit c3591af
Show file tree
Hide file tree
Showing 4 changed files with 319 additions and 14 deletions.
88 changes: 74 additions & 14 deletions apps/internal/base/internal/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,39 @@ func isMatchingScopes(scopesOne []string, scopesTwo string) bool {
return scopeCounter == len(scopesOne)
}

// needsUpgrade returns true if the given key follows the v1.0 schema i.e.,
// it contains an uppercase character (v1.1+ keys are all lowercase)
func needsUpgrade(key string) bool {
for _, r := range key {
if 'A' <= r && r <= 'Z' {
return true
}
}
return false
}

// upgrade a v1.0 cache item by adding a v1.1+ item having the same value and deleting
// the v1.0 item. Callers must hold an exclusive lock on m.
func upgrade[T any](m map[string]T, k string) T {
v1_1Key := strings.ToLower(k)
v, ok := m[k]
if !ok {
// another goroutine did the upgrade while this one was waiting for the write lock
return m[v1_1Key]
}
if v2, ok := m[v1_1Key]; ok {
// cache has an equivalent v1.1+ item, which we prefer because we know it was added
// by a newer version of the module and is therefore more likely to remain valid.
// The v1.0 item may have expired because only v1.0 or earlier would update it.
v = v2
} else {
// add an equivalent item according to the v1.1 schema
m[v1_1Key] = v
}
delete(m, k)
return v
}

// Read reads a storage token from the cache if it exists.
func (m *Manager) Read(ctx context.Context, authParameters authority.AuthParams) (TokenResponse, error) {
tr := TokenResponse{}
Expand Down Expand Up @@ -255,21 +288,25 @@ func (m *Manager) aadMetadata(ctx context.Context, authorityInfo authority.Info)

func (m *Manager) readAccessToken(homeID string, envAliases []string, realm, clientID string, scopes []string, tokenType, authnSchemeKeyID string) AccessToken {
m.contractMu.RLock()
defer m.contractMu.RUnlock()
// TODO: linear search (over a map no less) is slow for a large number (thousands) of tokens.
// this shows up as the dominating node in a profile. for real-world scenarios this likely isn't
// an issue, however if it does become a problem then we know where to look.
for _, at := range m.contract.AccessTokens {
for k, at := range m.contract.AccessTokens {
if at.HomeAccountID == homeID && at.Realm == realm && at.ClientID == clientID {
if (at.TokenType == tokenType && at.AuthnSchemeKeyID == authnSchemeKeyID) || (at.TokenType == "" && (tokenType == "" || tokenType == "Bearer")) {
if checkAlias(at.Environment, envAliases) {
if isMatchingScopes(scopes, at.Scopes) {
return at
if checkAlias(at.Environment, envAliases) && isMatchingScopes(scopes, at.Scopes) {
m.contractMu.RUnlock()
if needsUpgrade(k) {
m.contractMu.Lock()
defer m.contractMu.Unlock()
at = upgrade(m.contract.AccessTokens, k)
}
return at
}
}
}
}
m.contractMu.RUnlock()
return AccessToken{}
}

Expand Down Expand Up @@ -310,15 +347,21 @@ func (m *Manager) readRefreshToken(homeID string, envAliases []string, familyID,
// If app is part of the family or if we DO NOT KNOW if it's part of the family, search by family ID, then by client_id (we will know if an app is part of the family after the first token response).
// https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/311fe8b16e7c293462806f397e189a6aa1159769/src/client/Microsoft.Identity.Client/Internal/Requests/Silent/CacheSilentStrategy.cs#L95
m.contractMu.RLock()
defer m.contractMu.RUnlock()
for _, matcher := range matchers {
for _, rt := range m.contract.RefreshTokens {
for k, rt := range m.contract.RefreshTokens {
if matcher(rt) {
m.contractMu.RUnlock()
if needsUpgrade(k) {
m.contractMu.Lock()
defer m.contractMu.Unlock()
rt = upgrade(m.contract.RefreshTokens, k)
}
return rt, nil
}
}
}

m.contractMu.RUnlock()
return accesstokens.RefreshToken{}, fmt.Errorf("refresh token not found")
}

Expand All @@ -340,14 +383,20 @@ func (m *Manager) writeRefreshToken(refreshToken accesstokens.RefreshToken) erro

func (m *Manager) readIDToken(homeID string, envAliases []string, realm, clientID string) (IDToken, error) {
m.contractMu.RLock()
defer m.contractMu.RUnlock()
for _, idt := range m.contract.IDTokens {
for k, idt := range m.contract.IDTokens {
if idt.HomeAccountID == homeID && idt.Realm == realm && idt.ClientID == clientID {
if checkAlias(idt.Environment, envAliases) {
m.contractMu.RUnlock()
if needsUpgrade(k) {
m.contractMu.Lock()
defer m.contractMu.Unlock()
idt = upgrade(m.contract.IDTokens, k)
}
return idt, nil
}
}
}
m.contractMu.RUnlock()
return IDToken{}, fmt.Errorf("token not found")
}

Expand Down Expand Up @@ -386,19 +435,25 @@ func (m *Manager) Account(homeAccountID string) shared.Account {

func (m *Manager) readAccount(homeAccountID string, envAliases []string, realm string) (shared.Account, error) {
m.contractMu.RLock()
defer m.contractMu.RUnlock()

// You might ask why, if cache.Accounts is a map, we would loop through all of these instead of using a key.
// We only use a map because the storage contract shared between all language implementations says use a map.
// We can't change that. The other is because the keys are made using a specific "env", but here we are allowing
// a match in multiple envs (envAlias). That means we either need to hash each possible keyand do the lookup
// or just statically check. Since the design is to have a storage.Manager per user, the amount of keys stored
// is really low (say 2). Each hash is more expensive than the entire iteration.
for _, acc := range m.contract.Accounts {
for k, acc := range m.contract.Accounts {
if acc.HomeAccountID == homeAccountID && checkAlias(acc.Environment, envAliases) && acc.Realm == realm {
m.contractMu.RUnlock()
if needsUpgrade(k) {
m.contractMu.Lock()
defer m.contractMu.Unlock()
acc = upgrade(m.contract.Accounts, k)
}
return acc, nil
}
}
m.contractMu.RUnlock()
return shared.Account{}, fmt.Errorf("account not found")
}

Expand All @@ -412,13 +467,18 @@ func (m *Manager) writeAccount(account shared.Account) error {

func (m *Manager) readAppMetaData(envAliases []string, clientID string) (AppMetaData, error) {
m.contractMu.RLock()
defer m.contractMu.RUnlock()

for _, app := range m.contract.AppMetaData {
for k, app := range m.contract.AppMetaData {
if checkAlias(app.Environment, envAliases) && app.ClientID == clientID {
m.contractMu.RUnlock()
if needsUpgrade(k) {
m.contractMu.Lock()
defer m.contractMu.Unlock()
app = upgrade(m.contract.AppMetaData, k)
}
return app, nil
}
}
m.contractMu.RUnlock()
return AppMetaData{}, fmt.Errorf("not found")
}

Expand Down
101 changes: 101 additions & 0 deletions apps/internal/base/internal/storage/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"reflect"
"sort"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -115,6 +116,106 @@ func TestAllAccounts(t *testing.T) {
}
}

func TestSchemaUpgrade(t *testing.T) {
countV1Keys := func(mgr *Manager) int {
v1Keys := 0
for k := range mgr.contract.AccessTokens {
if strings.ToLower(k) != k {
v1Keys++
}
}
for k := range mgr.contract.AppMetaData {
if strings.ToLower(k) != k {
v1Keys++
}
}
for k := range mgr.contract.IDTokens {
if strings.ToLower(k) != k {
v1Keys++
}
}
for k := range mgr.contract.RefreshTokens {
if strings.ToLower(k) != k {
v1Keys++
}
}
return v1Keys
}

for _, test := range []struct {
desc, file string
shouldRemoveAllV1Keys bool
}{
{
desc: "v1.0 cache",
file: "testdata/v1.0_cache.json",
shouldRemoveAllV1Keys: true,
},
{
desc: "cache shared by v1.0 and v1.1",
file: "testdata/v1.0_v1.1_cache.json",
},
} {
t.Run(test.desc, func(t *testing.T) {
mgr := newForTest(&fakeDiscoveryResponser{
ret: authority.InstanceDiscoveryResponse{
Metadata: []authority.InstanceDiscoveryMetadata{
{Aliases: []string{defaultEnvironment}},
},
},
})
b, err := os.ReadFile(test.file)
if err != nil {
t.Fatal(err)
}
err = mgr.Unmarshal(b)
if err != nil {
t.Fatal(err)
}

before := countV1Keys(mgr)
if before == 0 {
t.Fatal("test bug: expected to have some v1.0 (mixed case) keys before Read")
}
tr, err := mgr.Read(context.Background(), authority.AuthParams{
AuthnScheme: &authority.BearerAuthenticationScheme{},
AuthorityInfo: authority.Info{
Host: defaultEnvironment,
Tenant: defaultRealm,
},
ClientID: defaultClientID,
HomeAccountID: defaultHID,
Scopes: strings.Split(defaultScopes, " "),
})
if err != nil {
t.Fatal(err)
}
if tr.Account.LocalAccountID != accLID {
t.Errorf("expected local ID %q, got %q", accLID, tr.Account.LocalAccountID)
}
if tr.AccessToken.Secret != accessTokenSecret {
t.Errorf("expected access token %q, got %q", accessTokenSecret, tr.AccessToken.Secret)
}
if tr.IDToken.Secret != idSecret {
t.Errorf("expected ID token %q, got %q", idSecret, tr.IDToken.Secret)
}
if tr.RefreshToken.Secret != rtSecret {
t.Errorf("expected refresh token %q, got %q", rtSecret, tr.RefreshToken.Secret)
}
after := countV1Keys(mgr)
if test.shouldRemoveAllV1Keys {
if after != 0 {
t.Fatal("expected to have no v1.0 (mixed case) keys after Read")
}
} else if after >= before {
// we can't predict how many keys will be removed because this depends on the
// iteration order of the Manager's maps, which isn't specified or stabale
t.Fatal("Read should have removed some v1.0 (mixed case) keys")
}
})
}
}

func TestReadAccessToken(t *testing.T) {
now := time.Now()
// Tokeb with token type
Expand Down
52 changes: 52 additions & 0 deletions apps/internal/base/internal/storage/testdata/v1.0_cache.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
{
"Account": {
"uid.utid-login.windows.net-Contoso": {
"username": "John Doe",
"local_account_id": "object1234",
"realm": "contoso",
"environment": "login.windows.net",
"home_account_id": "uid.utid",
"authority_type": "MSSTS"
}
},
"RefreshToken": {
"uid.utid-login.windows.net-RefreshToken-my_client_id--s2 s1 s3": {
"target": "s2 s1 s3",
"environment": "login.windows.net",
"credential_type": "RefreshToken",
"secret": "a refresh token",
"client_id": "my_client_id",
"home_account_id": "uid.utid"
}
},
"AccessToken": {
"uid.utid-login.windows.net-AccessToken-my_client_id-contoso-s2 s1 s3": {
"environment": "login.windows.net",
"credential_type": "AccessToken",
"secret": "an access token",
"realm": "contoso",
"target": "s2 s1 s3",
"client_id": "my_client_id",
"cached_at": "1000",
"home_account_id": "uid.utid",
"extended_expires_on": "4600",
"expires_on": "4600"
}
},
"IdToken": {
"uid.utid-login.windows.net-IdToken-my_client_id-contoso-": {
"realm": "contoso",
"environment": "login.windows.net",
"credential_type": "IdToken",
"secret": "header.eyJvaWQiOiAib2JqZWN0MTIzNCIsICJwcmVmZXJyZWRfdXNlcm5hbWUiOiAiSm9obiBEb2UiLCAic3ViIjogInN1YiJ9.signature",
"client_id": "my_client_id",
"home_account_id": "uid.utid"
}
},
"AppMetadata": {
"AppMetadata-login.windows.net-my_client_id": {
"environment": "login.windows.net",
"client_id": "my_client_id"
}
}
}
Loading

0 comments on commit c3591af

Please sign in to comment.