diff --git a/apps/internal/base/internal/storage/storage.go b/apps/internal/base/internal/storage/storage.go index 11263822..ed82397f 100644 --- a/apps/internal/base/internal/storage/storage.go +++ b/apps/internal/base/internal/storage/storage.go @@ -82,6 +82,39 @@ func isMatchingScopes(scopesOne []string, scopesTwo string) bool { return scopeCounter == len(scopesOne) } +// needsUpgrade returns true if the given key follows the v1.0 schema i.e., +// it contains an uppercase character (v1.1+ keys are all lowercase) +func needsUpgrade(key string) bool { + for _, r := range key { + if 'A' <= r && r <= 'Z' { + return true + } + } + return false +} + +// upgrade a v1.0 cache item by adding a v1.1+ item having the same value and deleting +// the v1.0 item. Callers must hold an exclusive lock on m. +func upgrade[T any](m map[string]T, k string) T { + v1_1Key := strings.ToLower(k) + v, ok := m[k] + if !ok { + // another goroutine did the upgrade while this one was waiting for the write lock + return m[v1_1Key] + } + if v2, ok := m[v1_1Key]; ok { + // cache has an equivalent v1.1+ item, which we prefer because we know it was added + // by a newer version of the module and is therefore more likely to remain valid. + // The v1.0 item may have expired because only v1.0 or earlier would update it. + v = v2 + } else { + // add an equivalent item according to the v1.1 schema + m[v1_1Key] = v + } + delete(m, k) + return v +} + // Read reads a storage token from the cache if it exists. func (m *Manager) Read(ctx context.Context, authParameters authority.AuthParams) (TokenResponse, error) { tr := TokenResponse{} @@ -255,21 +288,25 @@ func (m *Manager) aadMetadata(ctx context.Context, authorityInfo authority.Info) func (m *Manager) readAccessToken(homeID string, envAliases []string, realm, clientID string, scopes []string, tokenType, authnSchemeKeyID string) AccessToken { m.contractMu.RLock() - defer m.contractMu.RUnlock() // TODO: linear search (over a map no less) is slow for a large number (thousands) of tokens. // this shows up as the dominating node in a profile. for real-world scenarios this likely isn't // an issue, however if it does become a problem then we know where to look. - for _, at := range m.contract.AccessTokens { + for k, at := range m.contract.AccessTokens { if at.HomeAccountID == homeID && at.Realm == realm && at.ClientID == clientID { if (at.TokenType == tokenType && at.AuthnSchemeKeyID == authnSchemeKeyID) || (at.TokenType == "" && (tokenType == "" || tokenType == "Bearer")) { - if checkAlias(at.Environment, envAliases) { - if isMatchingScopes(scopes, at.Scopes) { - return at + if checkAlias(at.Environment, envAliases) && isMatchingScopes(scopes, at.Scopes) { + m.contractMu.RUnlock() + if needsUpgrade(k) { + m.contractMu.Lock() + defer m.contractMu.Unlock() + at = upgrade(m.contract.AccessTokens, k) } + return at } } } } + m.contractMu.RUnlock() return AccessToken{} } @@ -310,15 +347,21 @@ func (m *Manager) readRefreshToken(homeID string, envAliases []string, familyID, // If app is part of the family or if we DO NOT KNOW if it's part of the family, search by family ID, then by client_id (we will know if an app is part of the family after the first token response). // https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/311fe8b16e7c293462806f397e189a6aa1159769/src/client/Microsoft.Identity.Client/Internal/Requests/Silent/CacheSilentStrategy.cs#L95 m.contractMu.RLock() - defer m.contractMu.RUnlock() for _, matcher := range matchers { - for _, rt := range m.contract.RefreshTokens { + for k, rt := range m.contract.RefreshTokens { if matcher(rt) { + m.contractMu.RUnlock() + if needsUpgrade(k) { + m.contractMu.Lock() + defer m.contractMu.Unlock() + rt = upgrade(m.contract.RefreshTokens, k) + } return rt, nil } } } + m.contractMu.RUnlock() return accesstokens.RefreshToken{}, fmt.Errorf("refresh token not found") } @@ -340,14 +383,20 @@ func (m *Manager) writeRefreshToken(refreshToken accesstokens.RefreshToken) erro func (m *Manager) readIDToken(homeID string, envAliases []string, realm, clientID string) (IDToken, error) { m.contractMu.RLock() - defer m.contractMu.RUnlock() - for _, idt := range m.contract.IDTokens { + for k, idt := range m.contract.IDTokens { if idt.HomeAccountID == homeID && idt.Realm == realm && idt.ClientID == clientID { if checkAlias(idt.Environment, envAliases) { + m.contractMu.RUnlock() + if needsUpgrade(k) { + m.contractMu.Lock() + defer m.contractMu.Unlock() + idt = upgrade(m.contract.IDTokens, k) + } return idt, nil } } } + m.contractMu.RUnlock() return IDToken{}, fmt.Errorf("token not found") } @@ -386,7 +435,6 @@ func (m *Manager) Account(homeAccountID string) shared.Account { func (m *Manager) readAccount(homeAccountID string, envAliases []string, realm string) (shared.Account, error) { m.contractMu.RLock() - defer m.contractMu.RUnlock() // You might ask why, if cache.Accounts is a map, we would loop through all of these instead of using a key. // We only use a map because the storage contract shared between all language implementations says use a map. @@ -394,11 +442,18 @@ func (m *Manager) readAccount(homeAccountID string, envAliases []string, realm s // a match in multiple envs (envAlias). That means we either need to hash each possible keyand do the lookup // or just statically check. Since the design is to have a storage.Manager per user, the amount of keys stored // is really low (say 2). Each hash is more expensive than the entire iteration. - for _, acc := range m.contract.Accounts { + for k, acc := range m.contract.Accounts { if acc.HomeAccountID == homeAccountID && checkAlias(acc.Environment, envAliases) && acc.Realm == realm { + m.contractMu.RUnlock() + if needsUpgrade(k) { + m.contractMu.Lock() + defer m.contractMu.Unlock() + acc = upgrade(m.contract.Accounts, k) + } return acc, nil } } + m.contractMu.RUnlock() return shared.Account{}, fmt.Errorf("account not found") } @@ -412,13 +467,18 @@ func (m *Manager) writeAccount(account shared.Account) error { func (m *Manager) readAppMetaData(envAliases []string, clientID string) (AppMetaData, error) { m.contractMu.RLock() - defer m.contractMu.RUnlock() - - for _, app := range m.contract.AppMetaData { + for k, app := range m.contract.AppMetaData { if checkAlias(app.Environment, envAliases) && app.ClientID == clientID { + m.contractMu.RUnlock() + if needsUpgrade(k) { + m.contractMu.Lock() + defer m.contractMu.Unlock() + app = upgrade(m.contract.AppMetaData, k) + } return app, nil } } + m.contractMu.RUnlock() return AppMetaData{}, fmt.Errorf("not found") } diff --git a/apps/internal/base/internal/storage/storage_test.go b/apps/internal/base/internal/storage/storage_test.go index 320adafc..0570115c 100644 --- a/apps/internal/base/internal/storage/storage_test.go +++ b/apps/internal/base/internal/storage/storage_test.go @@ -9,6 +9,7 @@ import ( "os" "reflect" "sort" + "strings" "testing" "time" @@ -115,6 +116,106 @@ func TestAllAccounts(t *testing.T) { } } +func TestSchemaUpgrade(t *testing.T) { + countV1Keys := func(mgr *Manager) int { + v1Keys := 0 + for k := range mgr.contract.AccessTokens { + if strings.ToLower(k) != k { + v1Keys++ + } + } + for k := range mgr.contract.AppMetaData { + if strings.ToLower(k) != k { + v1Keys++ + } + } + for k := range mgr.contract.IDTokens { + if strings.ToLower(k) != k { + v1Keys++ + } + } + for k := range mgr.contract.RefreshTokens { + if strings.ToLower(k) != k { + v1Keys++ + } + } + return v1Keys + } + + for _, test := range []struct { + desc, file string + shouldRemoveAllV1Keys bool + }{ + { + desc: "v1.0 cache", + file: "testdata/v1.0_cache.json", + shouldRemoveAllV1Keys: true, + }, + { + desc: "cache shared by v1.0 and v1.1", + file: "testdata/v1.0_v1.1_cache.json", + }, + } { + t.Run(test.desc, func(t *testing.T) { + mgr := newForTest(&fakeDiscoveryResponser{ + ret: authority.InstanceDiscoveryResponse{ + Metadata: []authority.InstanceDiscoveryMetadata{ + {Aliases: []string{defaultEnvironment}}, + }, + }, + }) + b, err := os.ReadFile(test.file) + if err != nil { + t.Fatal(err) + } + err = mgr.Unmarshal(b) + if err != nil { + t.Fatal(err) + } + + before := countV1Keys(mgr) + if before == 0 { + t.Fatal("test bug: expected to have some v1.0 (mixed case) keys before Read") + } + tr, err := mgr.Read(context.Background(), authority.AuthParams{ + AuthnScheme: &authority.BearerAuthenticationScheme{}, + AuthorityInfo: authority.Info{ + Host: defaultEnvironment, + Tenant: defaultRealm, + }, + ClientID: defaultClientID, + HomeAccountID: defaultHID, + Scopes: strings.Split(defaultScopes, " "), + }) + if err != nil { + t.Fatal(err) + } + if tr.Account.LocalAccountID != accLID { + t.Errorf("expected local ID %q, got %q", accLID, tr.Account.LocalAccountID) + } + if tr.AccessToken.Secret != accessTokenSecret { + t.Errorf("expected access token %q, got %q", accessTokenSecret, tr.AccessToken.Secret) + } + if tr.IDToken.Secret != idSecret { + t.Errorf("expected ID token %q, got %q", idSecret, tr.IDToken.Secret) + } + if tr.RefreshToken.Secret != rtSecret { + t.Errorf("expected refresh token %q, got %q", rtSecret, tr.RefreshToken.Secret) + } + after := countV1Keys(mgr) + if test.shouldRemoveAllV1Keys { + if after != 0 { + t.Fatal("expected to have no v1.0 (mixed case) keys after Read") + } + } else if after >= before { + // we can't predict how many keys will be removed because this depends on the + // iteration order of the Manager's maps, which isn't specified or stabale + t.Fatal("Read should have removed some v1.0 (mixed case) keys") + } + }) + } +} + func TestReadAccessToken(t *testing.T) { now := time.Now() // Tokeb with token type diff --git a/apps/internal/base/internal/storage/testdata/v1.0_cache.json b/apps/internal/base/internal/storage/testdata/v1.0_cache.json new file mode 100644 index 00000000..ae54c935 --- /dev/null +++ b/apps/internal/base/internal/storage/testdata/v1.0_cache.json @@ -0,0 +1,52 @@ +{ + "Account": { + "uid.utid-login.windows.net-Contoso": { + "username": "John Doe", + "local_account_id": "object1234", + "realm": "contoso", + "environment": "login.windows.net", + "home_account_id": "uid.utid", + "authority_type": "MSSTS" + } + }, + "RefreshToken": { + "uid.utid-login.windows.net-RefreshToken-my_client_id--s2 s1 s3": { + "target": "s2 s1 s3", + "environment": "login.windows.net", + "credential_type": "RefreshToken", + "secret": "a refresh token", + "client_id": "my_client_id", + "home_account_id": "uid.utid" + } + }, + "AccessToken": { + "uid.utid-login.windows.net-AccessToken-my_client_id-contoso-s2 s1 s3": { + "environment": "login.windows.net", + "credential_type": "AccessToken", + "secret": "an access token", + "realm": "contoso", + "target": "s2 s1 s3", + "client_id": "my_client_id", + "cached_at": "1000", + "home_account_id": "uid.utid", + "extended_expires_on": "4600", + "expires_on": "4600" + } + }, + "IdToken": { + "uid.utid-login.windows.net-IdToken-my_client_id-contoso-": { + "realm": "contoso", + "environment": "login.windows.net", + "credential_type": "IdToken", + "secret": "header.eyJvaWQiOiAib2JqZWN0MTIzNCIsICJwcmVmZXJyZWRfdXNlcm5hbWUiOiAiSm9obiBEb2UiLCAic3ViIjogInN1YiJ9.signature", + "client_id": "my_client_id", + "home_account_id": "uid.utid" + } + }, + "AppMetadata": { + "AppMetadata-login.windows.net-my_client_id": { + "environment": "login.windows.net", + "client_id": "my_client_id" + } + } +} diff --git a/apps/internal/base/internal/storage/testdata/v1.0_v1.1_cache.json b/apps/internal/base/internal/storage/testdata/v1.0_v1.1_cache.json new file mode 100644 index 00000000..f031d42b --- /dev/null +++ b/apps/internal/base/internal/storage/testdata/v1.0_v1.1_cache.json @@ -0,0 +1,92 @@ +{ + "Account": { + "uid.utid-login.windows.net-Contoso": { + "username": "John Doe", + "local_account_id": "wrong value", + "realm": "contoso", + "environment": "login.windows.net", + "home_account_id": "uid.utid", + "authority_type": "MSSTS" + }, + "uid.utid-login.windows.net-contoso": { + "username": "John Doe", + "local_account_id": "object1234", + "realm": "contoso", + "environment": "login.windows.net", + "home_account_id": "uid.utid", + "authority_type": "MSSTS" + } + }, + "RefreshToken": { + "uid.utid-login.windows.net-RefreshToken-my_client_id--s2 s1 s3": { + "target": "s2 s1 s3", + "environment": "login.windows.net", + "credential_type": "RefreshToken", + "secret": "wrong value", + "client_id": "my_client_id", + "home_account_id": "uid.utid" + }, + "uid.utid-login.windows.net-refreshtoken-my_client_id--s2 s1 s3": { + "target": "s2 s1 s3", + "environment": "login.windows.net", + "credential_type": "RefreshToken", + "secret": "a refresh token", + "client_id": "my_client_id", + "home_account_id": "uid.utid" + } + }, + "AccessToken": { + "uid.utid-login.windows.net-accesstoken-my_client_id-contoso-s2 s1 s3": { + "environment": "login.windows.net", + "credential_type": "AccessToken", + "secret": "an access token", + "realm": "contoso", + "target": "s2 s1 s3", + "client_id": "my_client_id", + "cached_at": "1000", + "home_account_id": "uid.utid", + "extended_expires_on": "4600", + "expires_on": "4600" + }, + "uid.utid-login.windows.net-AccessToken-my_client_id-contoso-s2 s1 s3": { + "environment": "login.windows.net", + "credential_type": "AccessToken", + "secret": "wrong value", + "realm": "contoso", + "target": "s2 s1 s3", + "client_id": "my_client_id", + "cached_at": "1000", + "home_account_id": "uid.utid", + "extended_expires_on": "4600", + "expires_on": "4600" + } + }, + "IdToken": { + "uid.utid-login.windows.net-IdToken-my_client_id-contoso-": { + "realm": "contoso", + "environment": "login.windows.net", + "credential_type": "IdToken", + "secret": "wrong value", + "client_id": "my_client_id", + "home_account_id": "uid.utid" + }, + "uid.utid-login.windows.net-idtoken-my_client_id-contoso-": { + "realm": "contoso", + "environment": "login.windows.net", + "credential_type": "IdToken", + "secret": "header.eyJvaWQiOiAib2JqZWN0MTIzNCIsICJwcmVmZXJyZWRfdXNlcm5hbWUiOiAiSm9obiBEb2UiLCAic3ViIjogInN1YiJ9.signature", + "client_id": "my_client_id", + "home_account_id": "uid.utid" + } + }, + "AppMetadata": { + "AppMetadata-login.windows.net-my_client_id": { + "environment": "login.windows.net", + "client_id": "my_client_id" + }, + "appmetadata-login.windows.net-my_client_id": { + "environment": "login.windows.net", + "client_id": "my_client_id" + } + } +}