Skip to content

Commit

Permalink
ExportReplace supports Context (#378)
Browse files Browse the repository at this point in the history
  • Loading branch information
element-of-surprise committed Feb 28, 2023
1 parent 79fd37d commit ec469b9
Show file tree
Hide file tree
Showing 13 changed files with 249 additions and 87 deletions.
17 changes: 12 additions & 5 deletions apps/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ implementers on the format being passed.
*/
package cache

import "context"

// Marshaler marshals data from an internal cache to bytes that can be stored.
type Marshaler interface {
Marshal() ([]byte, error)
Expand All @@ -27,13 +29,18 @@ type Serializer interface {
Unmarshaler
}

// ExportReplace is used export or replace what is in the cache.
// ExportReplace exports and replaces in-memory cache data. It doesn't support nil Context or
// define the outcome of passing one. A Context without a timeout must receive a default timeout
// specified by the implementor. Retries must be implemented inside the implementation.
type ExportReplace interface {
// Replace replaces the cache with what is in external storage.
// key is the suggested key which can be used for partioning the cache
Replace(cache Unmarshaler, key string)
// key is the suggested key which can be used for partitioning the cache.
// Implementors should honor Context cancellations and return a context.Canceled or
// context.DeadlineExceeded in those cases.
Replace(ctx context.Context, cache Unmarshaler, key string) error
// Export writes the binary representation of the cache (cache.Marshal()) to
// external storage. This is considered opaque.
// key is the suggested key which can be used for partioning the cache
Export(cache Marshaler, key string)
// key is the suggested key which can be used for partitioning the cache.
// Context cancellations should be honored as in Replace.
Export(ctx context.Context, cache Marshaler, key string) error
}
12 changes: 5 additions & 7 deletions apps/confidential/confidential.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ duplication.
.Net People, Take note on X509:
This uses x509.Certificates and private keys. x509 does not store private keys. .Net
has some x509.Certificate2 thing that has private keys, but that is just some bullcrap that .Net
added, it doesn't exist in real life. Seriously, "x509.Certificate2", bahahahaha. As such I've
put a PEM decoder into here.
added, it doesn't exist in real life. As such I've put a PEM decoder into here.
*/

// TODO(msal): This should have example code for each method on client using Go's example doc framework.
Expand Down Expand Up @@ -702,12 +701,11 @@ func (cca Client) AcquireTokenOnBehalfOf(ctx context.Context, userAssertion stri
}

// Account gets the account in the token cache with the specified homeAccountID.
func (cca Client) Account(homeAccountID string) Account {
return cca.base.Account(homeAccountID)
func (cca Client) Account(ctx context.Context, accountID string) (Account, error) {
return cca.base.Account(ctx, accountID)
}

// RemoveAccount signs the account out and forgets account from token cache.
func (cca Client) RemoveAccount(account Account) error {
cca.base.RemoveAccount(account)
return nil
func (cca Client) RemoveAccount(ctx context.Context, account Account) error {
return cca.base.RemoveAccount(ctx, account)
}
17 changes: 12 additions & 5 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,10 @@ func TestAcquireTokenByAuthCode(t *testing.T) {
if tk.AccessToken != token {
t.Fatalf("unexpected access token %s", tk.AccessToken)
}
account := client.Account(tk.Account.HomeAccountID)
account, err := client.Account(context.Background(), tk.Account.HomeAccountID)
if err != nil {
t.Fatal(err)
}
if params.utid == "" {
if actual := account.HomeAccountID; actual != "123-456.123-456" {
t.Fatalf("expected %q, got %q", "123-456.123-456", actual)
Expand Down Expand Up @@ -630,16 +633,20 @@ func TestTokenProviderOptions(t *testing.T) {
// testCache is a simple in-memory cache.ExportReplace implementation
type testCache map[string][]byte

func (c testCache) Export(m cache.Marshaler, key string) {
if v, err := m.Marshal(); err == nil {
func (c testCache) Export(ctx context.Context, m cache.Marshaler, key string) error {
v, err := m.Marshal()
if err == nil {
c[key] = v
}
return err
}

func (c testCache) Replace(u cache.Unmarshaler, key string) {
func (c testCache) Replace(ctx context.Context, u cache.Unmarshaler, key string) error {
var err error
if v, has := c[key]; has {
_ = u.Unmarshal(v)
err = u.Unmarshal(v)
}
return err
}

func TestWithCache(t *testing.T) {
Expand Down
126 changes: 89 additions & 37 deletions apps/internal/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,12 @@ type partitionedManager interface {

type noopCacheAccessor struct{}

func (n noopCacheAccessor) Replace(cache cache.Unmarshaler, key string) {}
func (n noopCacheAccessor) Export(cache cache.Marshaler, key string) {}
func (n noopCacheAccessor) Replace(ctx context.Context, cache cache.Unmarshaler, key string) error {
return nil
}
func (n noopCacheAccessor) Export(ctx context.Context, cache cache.Marshaler, key string) error {
return nil
}

// AcquireTokenSilentParameters contains the parameters to acquire a token silently (from cache).
type AcquireTokenSilentParameters struct {
Expand Down Expand Up @@ -279,12 +283,12 @@ func (b Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, s
return baseURL.String(), nil
}

func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilentParameters) (AuthResult, error) {
func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilentParameters) (ar AuthResult, err error) {
// when tenant == "", the caller didn't specify a tenant and WithTenant will use the client's configured tenant
tenant := silent.TenantID
authParams, err := b.AuthParams.WithTenant(tenant)
if err != nil {
return AuthResult{}, err
return ar, err
}
authParams.Scopes = silent.Scopes
authParams.HomeAccountID = silent.Account.HomeAccountID
Expand All @@ -296,37 +300,48 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen
if authParams.AuthorizationType == authority.ATOnBehalfOf {
if s, ok := b.pmanager.(cache.Serializer); ok {
suggestedCacheKey := authParams.CacheKey(silent.IsAppCache)
b.cacheAccessor.Replace(s, suggestedCacheKey)
defer b.cacheAccessor.Export(s, suggestedCacheKey)
err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey)
if err != nil {
return ar, err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
}
storageTokenResponse, err = b.pmanager.Read(ctx, authParams)
if err != nil {
return AuthResult{}, err
return ar, err
}
} else {
if s, ok := b.manager.(cache.Serializer); ok {
suggestedCacheKey := authParams.CacheKey(silent.IsAppCache)
b.cacheAccessor.Replace(s, suggestedCacheKey)
defer b.cacheAccessor.Export(s, suggestedCacheKey)
err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey)
if err != nil {
return ar, err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
}
authParams.AuthorizationType = authority.ATRefreshToken
storageTokenResponse, err = b.manager.Read(ctx, authParams, silent.Account)
if err != nil {
return AuthResult{}, err
return ar, err
}
}

// ignore cached access tokens when given claims
if silent.Claims == "" {
result, err := AuthResultFromStorage(storageTokenResponse)
ar, err = AuthResultFromStorage(storageTokenResponse)
if err == nil {
return result, nil
return ar, err
}
}

// redeem a cached refresh token, if available
if reflect.ValueOf(storageTokenResponse.RefreshToken).IsZero() {
return AuthResult{}, errors.New("no token found")
err = errors.New("no token found")
return ar, err
}
var cc *accesstokens.Credential
if silent.RequestType == accesstokens.ATConfidential {
Expand All @@ -335,10 +350,11 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen

token, err := b.Token.Refresh(ctx, silent.RequestType, authParams, cc, storageTokenResponse.RefreshToken)
if err != nil {
return AuthResult{}, err
return ar, err
}

return b.AuthResultFromToken(ctx, authParams, token, true)
ar, err = b.AuthResultFromToken(ctx, authParams, token, true)
return ar, err
}

func (b Client) AcquireTokenByAuthCode(ctx context.Context, authCodeParams AcquireTokenAuthCodeParameters) (AuthResult, error) {
Expand Down Expand Up @@ -401,67 +417,103 @@ func (b Client) AcquireTokenOnBehalfOf(ctx context.Context, onBehalfOfParams Acq
return ar, err
}

func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse, cacheWrite bool) (AuthResult, error) {
func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse, cacheWrite bool) (ar AuthResult, err error) {
if !cacheWrite {
return NewAuthResult(token, shared.Account{})
}

var account shared.Account
var err error
if authParams.AuthorizationType == authority.ATOnBehalfOf {
if s, ok := b.pmanager.(cache.Serializer); ok {
suggestedCacheKey := token.CacheKey(authParams)
b.cacheAccessor.Replace(s, suggestedCacheKey)
defer b.cacheAccessor.Export(s, suggestedCacheKey)
err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey)
if err != nil {
return ar, err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
}
account, err = b.pmanager.Write(authParams, token)
if err != nil {
return AuthResult{}, err
return ar, err
}
} else {
if s, ok := b.manager.(cache.Serializer); ok {
suggestedCacheKey := token.CacheKey(authParams)
b.cacheAccessor.Replace(s, suggestedCacheKey)
defer b.cacheAccessor.Export(s, suggestedCacheKey)
err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey)
if err != nil {
return ar, err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
}
account, err = b.manager.Write(authParams, token)
if err != nil {
return AuthResult{}, err
return ar, err
}
}
return NewAuthResult(token, account)
ar, err = NewAuthResult(token, account)
return ar, err
}

func (b Client) AllAccounts() []shared.Account {
func (b Client) AllAccounts(ctx context.Context) (accts []shared.Account, err error) {
if s, ok := b.manager.(cache.Serializer); ok {
suggestedCacheKey := b.AuthParams.CacheKey(false)
b.cacheAccessor.Replace(s, suggestedCacheKey)
defer b.cacheAccessor.Export(s, suggestedCacheKey)
err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey)
if err != nil {
return accts, err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
}

accounts := b.manager.AllAccounts()
return accounts
accts = b.manager.AllAccounts()
return accts, err
}

func (b Client) Account(homeAccountID string) shared.Account {
func (b Client) Account(ctx context.Context, homeAccountID string) (acct shared.Account, err error) {
authParams := b.AuthParams // This is a copy, as we dont' have a pointer receiver and .AuthParams is not a pointer.
authParams.AuthorizationType = authority.AccountByID
authParams.HomeAccountID = homeAccountID
if s, ok := b.manager.(cache.Serializer); ok {
suggestedCacheKey := b.AuthParams.CacheKey(false)
b.cacheAccessor.Replace(s, suggestedCacheKey)
defer b.cacheAccessor.Export(s, suggestedCacheKey)
err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey)
if err != nil {
return acct, err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
}
account := b.manager.Account(homeAccountID)
return account
acct = b.manager.Account(homeAccountID)
return acct, err
}

// RemoveAccount removes all the ATs, RTs and IDTs from the cache associated with this account.
func (b Client) RemoveAccount(account shared.Account) {
func (b Client) RemoveAccount(ctx context.Context, account shared.Account) (err error) {
if s, ok := b.manager.(cache.Serializer); ok {
suggestedCacheKey := b.AuthParams.CacheKey(false)
b.cacheAccessor.Replace(s, suggestedCacheKey)
defer b.cacheAccessor.Export(s, suggestedCacheKey)
err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey)
if err != nil {
return err
}
defer func() {
err = b.export(ctx, s, suggestedCacheKey, err)
}()
}
b.manager.RemoveAccount(account, b.AuthParams.ClientID)
return err
}

// export helps other methods defer exporting the cache after possibly updating its in-memory content.
// err is the error the calling method will return. If err isn't nil, export returns it without
// exporting the cache.
func (b Client) export(ctx context.Context, marshal cache.Marshaler, key string, err error) error {
if err != nil {
return err
}
return b.cacheAccessor.Export(ctx, marshal, key)
}
Loading

0 comments on commit ec469b9

Please sign in to comment.