From 4c397f807f29ed0b976824a107966bfa095dd3d8 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 13 Apr 2023 15:11:11 -0700 Subject: [PATCH] Prevent persistent cache data races (#402) --- apps/internal/base/base.go | 234 +++++++----------- apps/internal/base/base_test.go | 68 ++++- .../internal/base/internal/storage/storage.go | 9 +- .../base/internal/storage/storage_test.go | 2 +- 4 files changed, 154 insertions(+), 159 deletions(-) diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index 00617abf..5f68384f 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -10,6 +10,7 @@ import ( "net/url" "reflect" "strings" + "sync" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" @@ -27,31 +28,21 @@ const ( ) // manager provides an internal cache. It is defined to allow faking the cache in tests. -// In all production use it is a *storage.Manager. +// In production it's a *storage.Manager or *storage.PartitionedManager. type manager interface { - Read(ctx context.Context, authParameters authority.AuthParams, account shared.Account) (storage.TokenResponse, error) - Write(authParameters authority.AuthParams, tokenResponse accesstokens.TokenResponse) (shared.Account, error) + cache.Serializer + Read(context.Context, authority.AuthParams) (storage.TokenResponse, error) + Write(authority.AuthParams, accesstokens.TokenResponse) (shared.Account, error) +} + +// accountManager is a manager that also caches accounts. In production it's a *storage.Manager. +type accountManager interface { + manager AllAccounts() []shared.Account Account(homeAccountID string) shared.Account RemoveAccount(account shared.Account, clientID string) } -// partitionedManager provides an internal cache. It is defined to allow faking the cache in tests. -// In all production use it is a *storage.PartitionedManager. -type partitionedManager interface { - Read(ctx context.Context, authParameters authority.AuthParams) (storage.TokenResponse, error) - Write(authParameters authority.AuthParams, tokenResponse accesstokens.TokenResponse) (shared.Account, error) -} - -type noopCacheAccessor struct{} - -func (n noopCacheAccessor) Replace(ctx context.Context, u cache.Unmarshaler, h cache.ReplaceHints) error { - return nil -} -func (n noopCacheAccessor) Export(ctx context.Context, m cache.Marshaler, h cache.ExportHints) error { - return nil -} - // AcquireTokenSilentParameters contains the parameters to acquire a token silently (from cache). type AcquireTokenSilentParameters struct { Scopes []string @@ -137,12 +128,14 @@ func NewAuthResult(tokenResponse accesstokens.TokenResponse, account shared.Acco // Client is a base client that provides access to common methods and primatives that // can be used by multiple clients. type Client struct { - Token *oauth.Client - manager manager // *storage.Manager or fakeManager in tests - pmanager partitionedManager // *storage.PartitionedManager or fakeManager in tests - - AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New(). - cacheAccessor cache.ExportReplace + Token *oauth.Client + manager accountManager // *storage.Manager or fakeManager in tests + // pmanager is a partitioned cache for OBO authentication. *storage.PartitionedManager or fakeManager in tests + pmanager manager + + AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New(). + cacheAccessor cache.ExportReplace + cacheAccessorMu *sync.RWMutex } // Option is an optional argument to the New constructor. @@ -214,11 +207,11 @@ func New(clientID string, authorityURI string, token *oauth.Client, options ...O } authParams := authority.NewAuthParams(clientID, authInfo) client := Client{ // Note: Hey, don't even THINK about making Base into *Base. See "design notes" in public.go and confidential.go - Token: token, - AuthParams: authParams, - cacheAccessor: noopCacheAccessor{}, - manager: storage.New(token), - pmanager: storage.NewPartitionedManager(token), + Token: token, + AuthParams: authParams, + cacheAccessorMu: &sync.RWMutex{}, + manager: storage.New(token), + pmanager: storage.NewPartitionedManager(token), } for _, o := range options { if err = o(&client); err != nil { @@ -283,8 +276,9 @@ func (b Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, s return baseURL.String(), nil } -func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilentParameters) (ar AuthResult, err error) { - // when tenant == "", the caller didn't specify a tenant and WithTenant will use the client's configured tenant +func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilentParameters) (AuthResult, error) { + ar := AuthResult{} + // when tenant == "", the caller didn't specify a tenant and WithTenant will choose the client's configured tenant tenant := silent.TenantID authParams, err := b.AuthParams.WithTenant(tenant) if err != nil { @@ -296,38 +290,23 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen authParams.Claims = silent.Claims authParams.UserAssertion = silent.UserAssertion - var storageTokenResponse storage.TokenResponse - if authParams.AuthorizationType == authority.ATOnBehalfOf { - if s, ok := b.pmanager.(cache.Serializer); ok { - suggestedCacheKey := authParams.CacheKey(silent.IsAppCache) - err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) - if err != nil { - return ar, err - } - defer func() { - err = b.export(ctx, s, suggestedCacheKey, err) - }() - } - storageTokenResponse, err = b.pmanager.Read(ctx, authParams) - if err != nil { - return ar, err - } - } else { - if s, ok := b.manager.(cache.Serializer); ok { - suggestedCacheKey := authParams.CacheKey(silent.IsAppCache) - err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) - if err != nil { - return ar, err - } - defer func() { - err = b.export(ctx, s, suggestedCacheKey, err) - }() - } + m := b.pmanager + if authParams.AuthorizationType != authority.ATOnBehalfOf { authParams.AuthorizationType = authority.ATRefreshToken - storageTokenResponse, err = b.manager.Read(ctx, authParams, silent.Account) - if err != nil { - return ar, err - } + m = b.manager + } + if b.cacheAccessor != nil { + key := authParams.CacheKey(silent.IsAppCache) + b.cacheAccessorMu.RLock() + err = b.cacheAccessor.Replace(ctx, m, cache.ReplaceHints{PartitionKey: key}) + b.cacheAccessorMu.RUnlock() + } + if err != nil { + return ar, err + } + storageTokenResponse, err := m.Read(ctx, authParams) + if err != nil { + return ar, err } // ignore cached access tokens when given claims @@ -340,21 +319,17 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen // redeem a cached refresh token, if available if reflect.ValueOf(storageTokenResponse.RefreshToken).IsZero() { - err = errors.New("no token found") - return ar, err + return ar, errors.New("no token found") } var cc *accesstokens.Credential if silent.RequestType == accesstokens.ATConfidential { cc = silent.Credential } - token, err := b.Token.Refresh(ctx, silent.RequestType, authParams, cc, storageTokenResponse.RefreshToken) if err != nil { return ar, err } - - ar, err = b.AuthResultFromToken(ctx, authParams, token, true) - return ar, err + return b.AuthResultFromToken(ctx, authParams, token, true) } func (b Client) AcquireTokenByAuthCode(ctx context.Context, authCodeParams AcquireTokenAuthCodeParameters) (AuthResult, error) { @@ -417,103 +392,76 @@ func (b Client) AcquireTokenOnBehalfOf(ctx context.Context, onBehalfOfParams Acq return ar, err } -func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse, cacheWrite bool) (ar AuthResult, err error) { +func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse, cacheWrite bool) (AuthResult, error) { if !cacheWrite { return NewAuthResult(token, shared.Account{}) } - - var account shared.Account + var m manager = b.manager if authParams.AuthorizationType == authority.ATOnBehalfOf { - if s, ok := b.pmanager.(cache.Serializer); ok { - suggestedCacheKey := token.CacheKey(authParams) - err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) - if err != nil { - return ar, err - } - defer func() { - err = b.export(ctx, s, suggestedCacheKey, err) - }() - } - account, err = b.pmanager.Write(authParams, token) + m = b.pmanager + } + key := token.CacheKey(authParams) + if b.cacheAccessor != nil { + b.cacheAccessorMu.Lock() + defer b.cacheAccessorMu.Unlock() + err := b.cacheAccessor.Replace(ctx, m, cache.ReplaceHints{PartitionKey: key}) if err != nil { - return ar, err - } - } else { - if s, ok := b.manager.(cache.Serializer); ok { - suggestedCacheKey := token.CacheKey(authParams) - err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) - if err != nil { - return ar, err - } - defer func() { - err = b.export(ctx, s, suggestedCacheKey, err) - }() - } - account, err = b.manager.Write(authParams, token) - if err != nil { - return ar, err + return AuthResult{}, err } } - ar, err = NewAuthResult(token, account) + account, err := m.Write(authParams, token) + if err != nil { + return AuthResult{}, err + } + ar, err := NewAuthResult(token, account) + if err == nil && b.cacheAccessor != nil { + err = b.cacheAccessor.Export(ctx, b.manager, cache.ExportHints{PartitionKey: key}) + } return ar, err } -func (b Client) AllAccounts(ctx context.Context) (accts []shared.Account, err error) { - if s, ok := b.manager.(cache.Serializer); ok { - suggestedCacheKey := b.AuthParams.CacheKey(false) - err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) +func (b Client) AllAccounts(ctx context.Context) ([]shared.Account, error) { + if b.cacheAccessor != nil { + b.cacheAccessorMu.RLock() + defer b.cacheAccessorMu.RUnlock() + key := b.AuthParams.CacheKey(false) + err := b.cacheAccessor.Replace(ctx, b.manager, cache.ReplaceHints{PartitionKey: key}) if err != nil { - return accts, err + return nil, err } - defer func() { - err = b.export(ctx, s, suggestedCacheKey, err) - }() } - - accts = b.manager.AllAccounts() - return accts, err + return b.manager.AllAccounts(), nil } -func (b Client) Account(ctx context.Context, homeAccountID string) (acct shared.Account, err error) { - authParams := b.AuthParams // This is a copy, as we dont' have a pointer receiver and .AuthParams is not a pointer. - authParams.AuthorizationType = authority.AccountByID - authParams.HomeAccountID = homeAccountID - if s, ok := b.manager.(cache.Serializer); ok { - suggestedCacheKey := b.AuthParams.CacheKey(false) - err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) +func (b Client) Account(ctx context.Context, homeAccountID string) (shared.Account, error) { + if b.cacheAccessor != nil { + b.cacheAccessorMu.RLock() + defer b.cacheAccessorMu.RUnlock() + authParams := b.AuthParams // This is a copy, as we don't have a pointer receiver and .AuthParams is not a pointer. + authParams.AuthorizationType = authority.AccountByID + authParams.HomeAccountID = homeAccountID + key := b.AuthParams.CacheKey(false) + err := b.cacheAccessor.Replace(ctx, b.manager, cache.ReplaceHints{PartitionKey: key}) if err != nil { - return acct, err + return shared.Account{}, err } - defer func() { - err = b.export(ctx, s, suggestedCacheKey, err) - }() } - acct = b.manager.Account(homeAccountID) - return acct, err + return b.manager.Account(homeAccountID), nil } // RemoveAccount removes all the ATs, RTs and IDTs from the cache associated with this account. -func (b Client) RemoveAccount(ctx context.Context, account shared.Account) (err error) { - if s, ok := b.manager.(cache.Serializer); ok { - suggestedCacheKey := b.AuthParams.CacheKey(false) - err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) - if err != nil { - return err - } - defer func() { - err = b.export(ctx, s, suggestedCacheKey, err) - }() +func (b Client) RemoveAccount(ctx context.Context, account shared.Account) error { + if b.cacheAccessor == nil { + b.manager.RemoveAccount(account, b.AuthParams.ClientID) + return nil } - b.manager.RemoveAccount(account, b.AuthParams.ClientID) - return err -} - -// export helps other methods defer exporting the cache after possibly updating its in-memory content. -// err is the error the calling method will return. If err isn't nil, export returns it without -// exporting the cache. -func (b Client) export(ctx context.Context, marshal cache.Marshaler, key string, err error) error { + b.cacheAccessorMu.Lock() + defer b.cacheAccessorMu.Unlock() + key := b.AuthParams.CacheKey(false) + err := b.cacheAccessor.Replace(ctx, b.manager, cache.ReplaceHints{PartitionKey: key}) if err != nil { return err } - return b.cacheAccessor.Export(ctx, marshal, cache.ExportHints{PartitionKey: key}) + b.manager.RemoveAccount(account, b.AuthParams.ClientID) + return b.cacheAccessor.Export(ctx, b.manager, cache.ExportHints{PartitionKey: key}) } diff --git a/apps/internal/base/base_test.go b/apps/internal/base/base_test.go index 7aac8102..9f114f0d 100644 --- a/apps/internal/base/base_test.go +++ b/apps/internal/base/base_test.go @@ -133,6 +133,7 @@ func TestAcquireTokenSilentScopes(t *testing.T) { }, accesstokens.TokenResponse{ AccessToken: fakeAccessToken, + ClientInfo: accesstokens.ClientInfo{UID: "uid", UTID: "utid"}, ExpiresOn: internalTime.DurationTime{T: time.Now().Add(-time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: test.cachedTokenScopes}, IDToken: fakeIDToken, @@ -223,11 +224,23 @@ func TestCacheIOErrors(t *testing.T) { } t.Run(name, func(t *testing.T) { client := fakeClient(t, WithCacheAccessor(&cache)) - _, actual := client.Account(ctx, "...") - if !errors.Is(actual, expected) { - t.Fatalf(`expected "%v", got "%v"`, expected, actual) + if !export { + // Account and AllAccounts don't export the cache, and AcquireTokenSilent does so + // only after redeeming a refresh token, so we test them only for replace errors + _, actual := client.Account(ctx, "...") + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + _, actual = client.AllAccounts(ctx) + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + _, actual = client.AcquireTokenSilent(ctx, AcquireTokenSilentParameters{}) + if cache.replaceErr != nil && !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } } - _, actual = client.AcquireTokenByAuthCode(ctx, AcquireTokenAuthCodeParameters{AppType: accesstokens.ATConfidential}) + _, actual := client.AcquireTokenByAuthCode(ctx, AcquireTokenAuthCodeParameters{AppType: accesstokens.ATConfidential}) if !errors.Is(actual, expected) { t.Fatalf(`expected "%v", got "%v"`, expected, actual) } @@ -235,14 +248,6 @@ func TestCacheIOErrors(t *testing.T) { if !errors.Is(actual, expected) { t.Fatalf(`expected "%v", got "%v"`, expected, actual) } - _, actual = client.AcquireTokenSilent(ctx, AcquireTokenSilentParameters{}) - if !errors.Is(actual, expected) { - t.Fatalf(`expected "%v", got "%v"`, expected, actual) - } - _, actual = client.AllAccounts(ctx) - if !errors.Is(actual, expected) { - t.Fatalf(`expected "%v", got "%v"`, expected, actual) - } _, actual = client.AuthResultFromToken(ctx, authority.AuthParams{}, accesstokens.TokenResponse{}, true) if !errors.Is(actual, expected) { t.Fatalf(`expected "%v", got "%v"`, expected, actual) @@ -254,6 +259,45 @@ func TestCacheIOErrors(t *testing.T) { }) } + // testing that AcquireTokenSilent propagates errors from Export requires more elaborate + // setup because that method exports the cache only after acquiring a new access token + t.Run("silent auth export error", func(t *testing.T) { + cache := failCache{} + hid := "uid.utid" + client := fakeClient(t, WithCacheAccessor(&cache)) + // cache fake tokens and app metadata + _, err := client.AuthResultFromToken(ctx, + authority.AuthParams{ + AuthorityInfo: authority.Info{Host: fakeAuthority}, + ClientID: fakeClientID, + HomeAccountID: hid, + Scopes: testScopes, + }, + accesstokens.TokenResponse{ + AccessToken: "at", + ClientInfo: accesstokens.ClientInfo{UID: "uid", UTID: "utid"}, + GrantedScopes: accesstokens.Scopes{Slice: testScopes}, + IDToken: fakeIDToken, + RefreshToken: "rt", + }, + true, + ) + if err != nil { + t.Fatal(err) + } + // AcquireTokenSilent should return this error after redeeming a refresh token + cache.exportErr = expected + _, actual := client.AcquireTokenSilent(ctx, + AcquireTokenSilentParameters{ + Account: shared.NewAccount(hid, fakeAuthority, "realm", "id", authority.AAD, "upn"), + Scopes: []string{"not-" + testScopes[0]}, + }, + ) + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + }) + // when the client fails to acquire a token, it should return an error instead of exporting the cache t.Run("auth error", func(t *testing.T) { cache := failCache{} diff --git a/apps/internal/base/internal/storage/storage.go b/apps/internal/base/internal/storage/storage.go index 1c0471bb..add75192 100644 --- a/apps/internal/base/internal/storage/storage.go +++ b/apps/internal/base/internal/storage/storage.go @@ -83,7 +83,7 @@ func isMatchingScopes(scopesOne []string, scopesTwo string) bool { } // Read reads a storage token from the cache if it exists. -func (m *Manager) Read(ctx context.Context, authParameters authority.AuthParams, account shared.Account) (TokenResponse, error) { +func (m *Manager) Read(ctx context.Context, authParameters authority.AuthParams) (TokenResponse, error) { tr := TokenResponse{} homeAccountID := authParameters.HomeAccountID realm := authParameters.AuthorityInfo.Tenant @@ -103,7 +103,8 @@ func (m *Manager) Read(ctx context.Context, authParameters authority.AuthParams, accessToken := m.readAccessToken(homeAccountID, aliases, realm, clientID, scopes) tr.AccessToken = accessToken - if account.IsZero() { + if homeAccountID == "" { + // caller didn't specify a user, so there's no reason to search for an ID or refresh token return tr, nil } // errors returned by read* methods indicate a cache miss and are therefore non-fatal. We continue populating @@ -122,7 +123,7 @@ func (m *Manager) Read(ctx context.Context, authParameters authority.AuthParams, } } - account, err = m.readAccount(homeAccountID, aliases, realm) + account, err := m.readAccount(homeAccountID, aliases, realm) if err == nil { tr.Account = account } @@ -493,6 +494,8 @@ func (m *Manager) update(cache *Contract) { // Marshal implements cache.Marshaler. func (m *Manager) Marshal() ([]byte, error) { + m.contractMu.RLock() + defer m.contractMu.RUnlock() return json.Marshal(m.contract) } diff --git a/apps/internal/base/internal/storage/storage_test.go b/apps/internal/base/internal/storage/storage_test.go index 0f1f1dd5..cc39c3bc 100644 --- a/apps/internal/base/internal/storage/storage_test.go +++ b/apps/internal/base/internal/storage/storage_test.go @@ -792,7 +792,7 @@ func TestRead(t *testing.T) { manager := newForTest(responder) manager.update(contract) - got, err := manager.Read(context.Background(), authParameters, testAccount) + got, err := manager.Read(context.Background(), authParameters) switch { case err == nil && test.err: t.Errorf("TestRead(%s): got err == nil, want err != nil", test.desc)