From 5009bfb13549d37d003b09267f6b2e0f46ab306c Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Wed, 31 Jan 2024 19:40:43 +0000 Subject: [PATCH 1/7] introduce credentialUnavailableError interface --- sdk/azidentity/azure_cli_credential_test.go | 2 +- .../azure_developer_cli_credential_test.go | 2 +- sdk/azidentity/chained_token_credential.go | 2 +- .../chained_token_credential_test.go | 4 ++-- sdk/azidentity/confidential_client.go | 2 +- sdk/azidentity/default_azure_credential.go | 2 +- .../default_azure_credential_test.go | 2 +- sdk/azidentity/developer_credential_util.go | 2 +- sdk/azidentity/errors.go | 22 ++++++++++++------- .../managed_identity_client_test.go | 2 +- .../managed_identity_credential_test.go | 2 +- 11 files changed, 25 insertions(+), 19 deletions(-) diff --git a/sdk/azidentity/azure_cli_credential_test.go b/sdk/azidentity/azure_cli_credential_test.go index 82f09735246c..2e50978babb9 100644 --- a/sdk/azidentity/azure_cli_credential_test.go +++ b/sdk/azidentity/azure_cli_credential_test.go @@ -53,7 +53,7 @@ func TestAzureCLICredential_DefaultChainError(t *testing.T) { t.Fatal(err) } _, err = cred.GetToken(context.Background(), testTRO) - var ue *credentialUnavailableError + var ue credentialUnavailableError if !errors.As(err, &ue) { t.Fatalf("expected credentialUnavailableError, got %T: %q", err, err) } diff --git a/sdk/azidentity/azure_developer_cli_credential_test.go b/sdk/azidentity/azure_developer_cli_credential_test.go index 9aeff5529523..fb6af170f991 100644 --- a/sdk/azidentity/azure_developer_cli_credential_test.go +++ b/sdk/azidentity/azure_developer_cli_credential_test.go @@ -35,7 +35,7 @@ func TestAzureDeveloperCLICredential_DefaultChainError(t *testing.T) { t.Fatal(err) } _, err = cred.GetToken(context.Background(), testTRO) - var ue *credentialUnavailableError + var ue credentialUnavailableError if !errors.As(err, &ue) { t.Fatalf("expected credentialUnavailableError, got %T: %q", err, err) } diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index dc855edf7868..f9a7d76daa3b 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -86,7 +86,7 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token errs []error successfulCredential azcore.TokenCredential token azcore.AccessToken - unavailableErr *credentialUnavailableError + unavailableErr credentialUnavailableError ) for _, cred := range c.sources { token, err = cred.GetToken(ctx, opts) diff --git a/sdk/azidentity/chained_token_credential_test.go b/sdk/azidentity/chained_token_credential_test.go index c05c21bcc41a..3332ff0ff5c5 100644 --- a/sdk/azidentity/chained_token_credential_test.go +++ b/sdk/azidentity/chained_token_credential_test.go @@ -139,7 +139,7 @@ func TestChainedTokenCredential_MultipleCredentialsGetTokenUnavailable(t *testin t.Fatal(err) } _, err = cred.GetToken(context.Background(), testTRO) - if _, ok := err.(*credentialUnavailableError); !ok { + if _, ok := err.(credentialUnavailableError); !ok { t.Fatalf("expected credentialUnavailableError, received %T", err) } expectedError := `ChainedTokenCredential: failed to acquire a token. @@ -186,7 +186,7 @@ func TestChainedTokenCredential_MultipleCredentialsGetTokenCustomName(t *testing } cred.name = "CustomNameCredential" _, err = cred.GetToken(context.Background(), testTRO) - if _, ok := err.(*credentialUnavailableError); !ok { + if _, ok := err.(credentialUnavailableError); !ok { t.Fatalf("expected credentialUnavailableError, received %T", err) } expectedError := `CustomNameCredential: failed to acquire a token. diff --git a/sdk/azidentity/confidential_client.go b/sdk/azidentity/confidential_client.go index 5a88e740fee2..857a6c14192f 100644 --- a/sdk/azidentity/confidential_client.go +++ b/sdk/azidentity/confidential_client.go @@ -109,7 +109,7 @@ func (c *confidentialClient) GetToken(ctx context.Context, tro policy.TokenReque if err != nil { // We could get a credentialUnavailableError from managed identity authentication because in that case the error comes from our code. // We return it directly because it affects the behavior of credential chains. Otherwise, we return AuthenticationFailedError. - var unavailableErr *credentialUnavailableError + var unavailableErr credentialUnavailableError if !errors.As(err, &unavailableErr) { res := getResponseFromError(err) err = newAuthenticationFailedError(c.name, err.Error(), res, err) diff --git a/sdk/azidentity/default_azure_credential.go b/sdk/azidentity/default_azure_credential.go index 35aeef867478..5270939a21eb 100644 --- a/sdk/azidentity/default_azure_credential.go +++ b/sdk/azidentity/default_azure_credential.go @@ -158,7 +158,7 @@ type defaultCredentialErrorReporter struct { } func (d *defaultCredentialErrorReporter) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - if _, ok := d.err.(*credentialUnavailableError); ok { + if _, ok := d.err.(credentialUnavailableError); ok { return azcore.AccessToken{}, d.err } return azcore.AccessToken{}, newCredentialUnavailableError(d.credType, d.err.Error()) diff --git a/sdk/azidentity/default_azure_credential_test.go b/sdk/azidentity/default_azure_credential_test.go index f9bbb5f80ff8..0f78371e93dd 100644 --- a/sdk/azidentity/default_azure_credential_test.go +++ b/sdk/azidentity/default_azure_credential_test.go @@ -277,7 +277,7 @@ func TestDefaultAzureCredential_timeoutWrapper(t *testing.T) { for i := 0; i < 2; i++ { // expecting credentialUnavailableError because delay exceeds the wrapper's timeout _, err = chain.GetToken(context.Background(), testTRO) - if _, ok := err.(*credentialUnavailableError); !ok { + if _, ok := err.(credentialUnavailableError); !ok { t.Fatalf("expected credentialUnavailableError, got %T: %v", err, err) } } diff --git a/sdk/azidentity/developer_credential_util.go b/sdk/azidentity/developer_credential_util.go index d8b952f532ee..f0bae35918b6 100644 --- a/sdk/azidentity/developer_credential_util.go +++ b/sdk/azidentity/developer_credential_util.go @@ -19,7 +19,7 @@ const cliTimeout = 10 * time.Second // the next credential in its chain (another developer credential). func unavailableIfInChain(err error, inDefaultChain bool) error { if err != nil && inDefaultChain { - var unavailableErr *credentialUnavailableError + var unavailableErr credentialUnavailableError if !errors.As(err, &unavailableErr) { err = newCredentialUnavailableError(credNameAzureDeveloperCLI, err.Error()) } diff --git a/sdk/azidentity/errors.go b/sdk/azidentity/errors.go index 9cc4531ee156..1f3a988c6f27 100644 --- a/sdk/azidentity/errors.go +++ b/sdk/azidentity/errors.go @@ -110,31 +110,37 @@ func (*AuthenticationFailedError) NonRetriable() { var _ errorinfo.NonRetriable = (*AuthenticationFailedError)(nil) -// credentialUnavailableError indicates a credential can't attempt authentication because it lacks required -// data or state -type credentialUnavailableError struct { + +type credentialUnavailableError interface { + error + credentialUnavailable() +} + +type credUnavailableError struct { message string } // newCredentialUnavailableError is an internal helper that ensures consistent error message formatting func newCredentialUnavailableError(credType, message string) error { msg := fmt.Sprintf("%s: %s", credType, message) - return &credentialUnavailableError{msg} + return &credUnavailableError{msg} } // NewCredentialUnavailableError constructs an error indicating a credential can't attempt authentication // because it lacks required data or state. When [ChainedTokenCredential] receives this error it will try // its next credential, if any. func NewCredentialUnavailableError(message string) error { - return &credentialUnavailableError{message} + return &credUnavailableError{message} } // Error implements the error interface. Note that the message contents are not contractual and can change over time. -func (e *credentialUnavailableError) Error() string { +func (e *credUnavailableError) Error() string { return e.message } // NonRetriable is a marker method indicating this error should not be retried. It has no implementation. -func (e *credentialUnavailableError) NonRetriable() {} +func (*credUnavailableError) NonRetriable() {} + +func (*credUnavailableError) credentialUnavailable() {} -var _ errorinfo.NonRetriable = (*credentialUnavailableError)(nil) +var _ errorinfo.NonRetriable = (*credUnavailableError)(nil) diff --git a/sdk/azidentity/managed_identity_client_test.go b/sdk/azidentity/managed_identity_client_test.go index 56ef2340941f..e9db0b0b60e3 100644 --- a/sdk/azidentity/managed_identity_client_test.go +++ b/sdk/azidentity/managed_identity_client_test.go @@ -116,7 +116,7 @@ func TestManagedIdentityClient_IMDSErrors(t *testing.T) { if actual := err.Error(); !strings.Contains(actual, test.body) { t.Fatalf("expected response body in error, got %q", actual) } - var unavailableErr *credentialUnavailableError + var unavailableErr credentialUnavailableError if !errors.As(err, &unavailableErr) { t.Fatalf("expected %T, got %T", unavailableErr, err) } diff --git a/sdk/azidentity/managed_identity_credential_test.go b/sdk/azidentity/managed_identity_credential_test.go index c22aaafccf17..c71c444d96d5 100644 --- a/sdk/azidentity/managed_identity_credential_test.go +++ b/sdk/azidentity/managed_identity_credential_test.go @@ -300,7 +300,7 @@ func TestManagedIdentityCredential_GetTokenIMDS400(t *testing.T) { // cred should return credentialUnavailableError when IMDS responds 400 to a token request for i := 0; i < 3; i++ { _, err = cred.GetToken(context.Background(), testTRO) - if _, ok := err.(*credentialUnavailableError); !ok { + if _, ok := err.(credentialUnavailableError); !ok { t.Fatalf("expected credentialUnavailableError, received %T", err) } } From 19950aebe86f09e9e3b7b551c509971aa706462b Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Wed, 31 Jan 2024 20:47:54 +0000 Subject: [PATCH 2/7] replace ErrAuthenticationRequired with AuthenticationRequiredError --- sdk/azidentity/CHANGELOG.md | 4 +++ sdk/azidentity/azidentity_test.go | 34 +++++++++++++++++-- sdk/azidentity/device_code_credential.go | 4 +-- sdk/azidentity/errors.go | 24 ++++++++++--- .../interactive_browser_credential.go | 4 +-- sdk/azidentity/public_client.go | 2 +- 6 files changed, 60 insertions(+), 12 deletions(-) diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index e960660d69a8..7aa88b7ae89d 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -5,6 +5,10 @@ ### Features Added ### Breaking Changes +> These changes affect only code written against a beta version such as v1.6.0-beta.1 +* Replaced `ErrAuthenticationRequired` with `AuthenticationRequiredError`, a struct + type that carries the `TokenRequestOptions` passed to the `GetToken` call which + returned the error. ### Bugs Fixed * Fixed more cases in which credential chains like `DefaultAzureCredential` diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index 7239d0aadf48..b103aa36a494 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -21,6 +21,7 @@ import ( "runtime" "strings" "testing" + "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" @@ -114,6 +115,7 @@ func TestUserAuthentication(t *testing.T) { name: credNameBrowser, new: func(tcpo *TokenCachePersistenceOptions, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) { return NewInteractiveBrowserCredential(&InteractiveBrowserCredentialOptions{ + AdditionallyAllowedTenants: []string{"*"}, AuthenticationRecord: ar, ClientOptions: co, DisableAutomaticAuthentication: disableAutoAuth, @@ -126,6 +128,7 @@ func TestUserAuthentication(t *testing.T) { name: credNameDeviceCode, new: func(tcpo *TokenCachePersistenceOptions, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) { o := DeviceCodeCredentialOptions{ + AdditionallyAllowedTenants: []string{"*"}, AuthenticationRecord: ar, ClientOptions: co, DisableAutomaticAuthentication: disableAutoAuth, @@ -143,6 +146,7 @@ func TestUserAuthentication(t *testing.T) { name: credNameUserPassword, new: func(tcpo *TokenCachePersistenceOptions, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) { opts := UsernamePasswordCredentialOptions{ + AdditionallyAllowedTenants: []string{"*"}, AuthenticationRecord: ar, ClientOptions: co, TokenCachePersistenceOptions: tcpo, @@ -264,8 +268,19 @@ func TestUserAuthentication(t *testing.T) { t.Run("DisableAutomaticAuthentication/"+credential.name, func(t *testing.T) { cred, err := credential.new(nil, policy.ClientOptions{Transport: &mockSTS{}}, AuthenticationRecord{}, true) require.NoError(t, err) - _, err = cred.GetToken(context.Background(), testTRO) - require.ErrorIs(t, err, ErrAuthenticationRequired) + expected := policy.TokenRequestOptions{ + Claims: "claims", + EnableCAE: true, + Scopes: []string{"scope"}, + TenantID: "tenant", + } + _, err = cred.GetToken(context.Background(), expected) + require.Contains(t, err.Error(), credential.name) + require.Contains(t, err.Error(), "Call Authenticate") + var actual *AuthenticationRequiredError + require.ErrorAs(t, err, &actual) + require.Equal(t, expected, actual.TokenRequestOptions) + if credential.name != credNameBrowser || runManualTests { _, err = cred.Authenticate(context.Background(), &testTRO) require.NoError(t, err) @@ -274,6 +289,18 @@ func TestUserAuthentication(t *testing.T) { require.NoError(t, err) } }) + t.Run("DisableAutomaticAuthentication/ChainedTokenCredential/"+credential.name, func(t *testing.T) { + cred, err := credential.new(nil, policy.ClientOptions{}, AuthenticationRecord{}, true) + require.NoError(t, err) + expected := azcore.AccessToken{ExpiresOn: time.Now().UTC(), Token: tokenValue} + fake := NewFakeCredential() + fake.SetResponse(expected, nil) + chain, err := NewChainedTokenCredential([]azcore.TokenCredential{cred, fake}, nil) + require.NoError(t, err) + actual, err := chain.GetToken(context.Background(), testTRO) + require.NoError(t, err) + require.Equal(t, expected, actual) + }) } } } @@ -635,7 +662,8 @@ func TestAdditionallyAllowedTenants(t *testing.T) { // tenant resolution should have succeeded because the specified tenant is allowed, // however the credential should have returned a different error because automatic // authentication is disabled - require.ErrorIs(t, ErrAuthenticationRequired, err) + var e *AuthenticationRequiredError + require.ErrorAs(t, err, &e) } }) diff --git a/sdk/azidentity/device_code_credential.go b/sdk/azidentity/device_code_credential.go index 65390b9492bb..29a73e96e842 100644 --- a/sdk/azidentity/device_code_credential.go +++ b/sdk/azidentity/device_code_credential.go @@ -34,8 +34,8 @@ type DeviceCodeCredentialOptions struct { ClientID string // DisableAutomaticAuthentication prevents the credential from automatically prompting the user to authenticate. - // When this option is true, [DeviceCodeCredential.GetToken] will return [ErrAuthenticationRequired] when user - // interaction is necessary to acquire a token. + // When this option is true, GetToken will return AuthenticationRequiredError when user interaction is necessary + // to acquire a token. DisableAutomaticAuthentication bool // DisableInstanceDiscovery should be set true only by applications authenticating in disconnected clouds, or diff --git a/sdk/azidentity/errors.go b/sdk/azidentity/errors.go index 1f3a988c6f27..6debc0909a70 100644 --- a/sdk/azidentity/errors.go +++ b/sdk/azidentity/errors.go @@ -13,15 +13,12 @@ import ( "fmt" "net/http" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" msal "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors" ) -// ErrAuthenticationRequired indicates a credential's Authenticate method must be called to acquire a token -// because user interaction is required and the credential is configured not to automatically prompt the user. -var ErrAuthenticationRequired error = &credentialUnavailableError{"can't acquire a token without user interaction. Call Authenticate to interactively authenticate a user"} - // getResponseFromError retrieves the response carried by // an AuthenticationFailedError or MSAL CallErr, if any func getResponseFromError(err error) *http.Response { @@ -110,6 +107,25 @@ func (*AuthenticationFailedError) NonRetriable() { var _ errorinfo.NonRetriable = (*AuthenticationFailedError)(nil) +// AuthenticationRequiredError indicates a credential's Authenticate method must be called to acquire a token +// because the credential requires user interaction and is configured not to request it automatically. +type AuthenticationRequiredError struct { + credUnavailableError + + // TokenRequestOptions for the required token. Pass this to the credential's Authenticate method. + TokenRequestOptions policy.TokenRequestOptions +} + +func newAuthenticationRequiredError(credType string, tro policy.TokenRequestOptions) error { + return &AuthenticationRequiredError{ + credUnavailableError: credUnavailableError{ + credType + " can't acquire a token without user interaction. Call Authenticate to authenticate a user interactively", + }, + TokenRequestOptions: tro, + } +} + +var _ errorinfo.NonRetriable = (*AuthenticationRequiredError)(nil) type credentialUnavailableError interface { error diff --git a/sdk/azidentity/interactive_browser_credential.go b/sdk/azidentity/interactive_browser_credential.go index 4e04ecdab097..ad6bdaf69189 100644 --- a/sdk/azidentity/interactive_browser_credential.go +++ b/sdk/azidentity/interactive_browser_credential.go @@ -33,8 +33,8 @@ type InteractiveBrowserCredentialOptions struct { ClientID string // DisableAutomaticAuthentication prevents the credential from automatically prompting the user to authenticate. - // When this option is true, [InteractiveBrowserCredential.GetToken] will return [ErrAuthenticationRequired] when - // user interaction is necessary to acquire a token. + // When this option is true, GetToken will return AuthenticationRequiredError when user interaction is necessary + // to acquire a token. DisableAutomaticAuthentication bool // DisableInstanceDiscovery should be set true only by applications authenticating in disconnected clouds, or diff --git a/sdk/azidentity/public_client.go b/sdk/azidentity/public_client.go index 81d8632fa96e..e76cb3bab4e6 100644 --- a/sdk/azidentity/public_client.go +++ b/sdk/azidentity/public_client.go @@ -152,7 +152,7 @@ func (p *publicClient) GetToken(ctx context.Context, tro policy.TokenRequestOpti return p.token(ar, err) } if p.opts.DisableAutomaticAuthentication { - return azcore.AccessToken{}, ErrAuthenticationRequired + return azcore.AccessToken{}, newAuthenticationRequiredError(p.name, tro) } at, err := p.reqToken(ctx, client, tro) if err == nil { From bca87b5cdd52253933f8200e6825f5e62743f098 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 1 Feb 2024 00:12:32 +0000 Subject: [PATCH 3/7] rewrite examples --- sdk/azidentity/example_cache_test.go | 32 --------- sdk/azidentity/example_shared_test.go | 19 +++-- sdk/azidentity/example_test.go | 38 ++++++++++ sdk/azidentity/example_user_auth_test.go | 89 ++++++++++++------------ 4 files changed, 97 insertions(+), 81 deletions(-) delete mode 100644 sdk/azidentity/example_cache_test.go diff --git a/sdk/azidentity/example_cache_test.go b/sdk/azidentity/example_cache_test.go deleted file mode 100644 index 4108aace412d..000000000000 --- a/sdk/azidentity/example_cache_test.go +++ /dev/null @@ -1,32 +0,0 @@ -//go:build go1.18 -// +build go1.18 - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -package azidentity_test - -import ( - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - - // importing the cache module registers the cache implementation for the current platform - _ "github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache" -) - -// Credentials, excepting those that authenticate via external tools like [AzureCLICredential], -// cache authentication data in memory by default. Most of these credentials also support optional -// persistent caching. This example shows how to enable and configure that for a credential. It -// shows only [InteractiveBrowserCredential], however all credentials that support persistent caching have -// the same [TokenCachePersistenceOptions] API. -func Example_persistentCache() { - cred, err := azidentity.NewInteractiveBrowserCredential(&azidentity.InteractiveBrowserCredentialOptions{ - // Non-nil TokenCachePersistenceOptions enables persistent caching with default options. - // See TokenCachePersistenceOptions documentation for details of the supported options. - TokenCachePersistenceOptions: &azidentity.TokenCachePersistenceOptions{}, - }) - if err != nil { - // TODO: handle error - } - // TODO: use credential - _ = cred -} diff --git a/sdk/azidentity/example_shared_test.go b/sdk/azidentity/example_shared_test.go index c876d5fbf2fe..e81bf98fc7c6 100644 --- a/sdk/azidentity/example_shared_test.go +++ b/sdk/azidentity/example_shared_test.go @@ -12,12 +12,13 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" ) -// Helpers and variables to keep the examples tidy +// Helpers, variables, fakes to keep the examples tidy const ( - certPath = "testdata/certificate.pem" - clientID = "fake-client-id" - tenantID = "fake-tenant" + authRecordPath = "fake/path" + certPath = "testdata/certificate.pem" + clientID = "fake-client-id" + tenantID = "fake-tenant" ) func handleError(err error) { @@ -28,3 +29,13 @@ func handleError(err error) { var cred azcore.TokenCredential var err error + +type exampleServiceClient struct{} + +func newServiceClient(azcore.TokenCredential) (exampleServiceClient, error) { + return exampleServiceClient{}, nil +} + +func (exampleServiceClient) Method() error { + return nil +} diff --git a/sdk/azidentity/example_test.go b/sdk/azidentity/example_test.go index 8e941493e4cf..d304cd74705d 100644 --- a/sdk/azidentity/example_test.go +++ b/sdk/azidentity/example_test.go @@ -7,11 +7,49 @@ package azidentity_test import ( + "context" + "errors" "os" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" ) +// Credentials that require user interaction such as [InteractiveBrowserCredential] and [DeviceCodeCredential] +// can optionally return this error instead of automatically prompting for user interaction. This allows applications +// to decide when to request user interaction. This example shows how to handle the error and authenticate a user +// interactively. It shows [InteractiveBrowserCredential] but the same pattern applies to [DeviceCodeCredential]. +func ExampleAuthenticationRequiredError() { + cred, err := azidentity.NewInteractiveBrowserCredential( + &azidentity.InteractiveBrowserCredentialOptions{ + // This option is useful only for applications that need to control when to prompt users to + // authenticate. If the timing of user interaction isn't important, don't set this option. + DisableAutomaticAuthentication: true, + }, + ) + if err != nil { + // TODO: handle error + } + // this could be any client that authenticates with an azidentity credential + client, err := newServiceClient(cred) + if err != nil { + // TODO: handle error + } + err = client.Method() + if err != nil { + var are *azidentity.AuthenticationRequiredError + if errors.As(err, &are) { + // The client requested a token and the credential requires user interaction. Whenever it's convenient + // for the application, call Authenticate to prompt the user. Pass the error's TokenRequestOptions to + // request a token with the parameters the client specified. + _, err = cred.Authenticate(context.TODO(), &are.TokenRequestOptions) + if err != nil { + // TODO: handle error + } + // TODO: retry the client method; it should succeed because the credential now has the necessary token + } + } +} + func ExampleNewOnBehalfOfCredentialWithCertificate() { data, err := os.ReadFile(certPath) if err != nil { diff --git a/sdk/azidentity/example_user_auth_test.go b/sdk/azidentity/example_user_auth_test.go index 26c7086ee7eb..4f29e3c62583 100644 --- a/sdk/azidentity/example_user_auth_test.go +++ b/sdk/azidentity/example_user_auth_test.go @@ -9,65 +9,64 @@ package azidentity_test import ( "context" "encoding/json" + "os" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" -) - -// This example shows how to authenticate a user with [InteractiveBrowserCredential], enabling persistent -// caching so that the user doesn't need to authenticate interactively the next time the application runs. -func Example_userAuthentication() { - cred, err := azidentity.NewInteractiveBrowserCredential(&azidentity.InteractiveBrowserCredentialOptions{ - // By default, credentials begin interactive authentication whenever necessary. To instead control when - // a credential prompts for user interaction, set this option true. The credential will then return - // azidentity.ErrAuthenticationRequired instead of prompting for authentication. The application - // can then call the credential's Authenticate method when it's convenient to prompt the user. - DisableAutomaticAuthentication: true, - // By default, credentials cache in memory. Set TokenCachePersistenceOptions to enable persistent caching. - TokenCachePersistenceOptions: &azidentity.TokenCachePersistenceOptions{ - // optionally set Name to isolate this credential's cache from other applications - Name: "myapp", - }, - }) - if err != nil { - // TODO: handle error - } + // importing the cache module registers the cache implementation for the current platform + _ "github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache" +) - // The Authenticate method begins interactive authentication. Call it whenever it's convenient for - // your application to authenticate a user. If Authenticate succeeds, the credential is ready for - // use with a client. - record, err := cred.Authenticate(context.TODO(), nil) - if err != nil { - // TODO: handle error +// this example shows file storage but any form of byte storage would work +func retrieveRecord() (azidentity.AuthenticationRecord, error) { + record := azidentity.AuthenticationRecord{} + b, err := os.ReadFile(authRecordPath) + if err == nil { + err = json.Unmarshal(b, &record) } + return record, err +} - // The record contains no authentication secrets. You can marshal it for storage. +func storeRecord(record azidentity.AuthenticationRecord) error { b, err := json.Marshal(record) - if err != nil { - // TODO: handle error + if err == nil { + err = os.WriteFile(authRecordPath, b, 0600) } - // TODO: store bytes - _ = b + return err +} - // An authentication record stored by your application enables other credentials to access data from - // past authentications. If the cache contains sufficient data, your application won't need to prompt - // for authentication. - var unmarshaled azidentity.AuthenticationRecord - err = json.Unmarshal(b, &unmarshaled) +// This example shows how to cache authentication data persistently so a user doesn't need to authenticate +// interactively every time the application runs. The example uses [InteractiveBrowserCredential], however +// [DeviceCodeCredential] has the same API. The key steps are: +// +// 1. Enable persistent caching by importing "github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache" and +// setting [TokenCachePersistenceOptions] +// 2. Call Authenticate to acquire an [AuthenticationRecord] and store that for future use. An [AuthenticationRecord] +// enables credentials to access data in the persistent cache. The record contains no authentication secrets. +// 3. Add the [AuthenticationRecord] to the credential's options +func Example_persistentUserAuthentication() { + record, err := retrieveRecord() if err != nil { // TODO: handle error } - - // this credential will be able to access authentication data cached by cred above, even in another process - newCred, err := azidentity.NewInteractiveBrowserCredential(&azidentity.InteractiveBrowserCredentialOptions{ - AuthenticationRecord: unmarshaled, - DisableAutomaticAuthentication: true, - TokenCachePersistenceOptions: &azidentity.TokenCachePersistenceOptions{ - Name: "myapp", - }, + cred, err := azidentity.NewInteractiveBrowserCredential(&azidentity.InteractiveBrowserCredentialOptions{ + AuthenticationRecord: record, + // Credentials cache in memory by default. Set TokenCachePersistenceOptions to enable persistent caching. + TokenCachePersistenceOptions: &azidentity.TokenCachePersistenceOptions{}, }) if err != nil { // TODO: handle error } - _ = newCred + + if record == (azidentity.AuthenticationRecord{}) { + // No stored record; call Authenticate to acquire one + record, err = cred.Authenticate(context.TODO(), nil) + if err != nil { + // TODO: handle error + } + err = storeRecord(record) + if err != nil { + // TODO: handle error + } + } } From a28f70974efb16b8b5377bcd302f71f16be8cc0c Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 1 Feb 2024 16:23:23 +0000 Subject: [PATCH 4/7] second thoughts re naming --- sdk/azidentity/azure_cli_credential_test.go | 2 +- .../azure_developer_cli_credential_test.go | 2 +- sdk/azidentity/chained_token_credential.go | 2 +- .../chained_token_credential_test.go | 8 ++++---- sdk/azidentity/confidential_client.go | 2 +- sdk/azidentity/default_azure_credential.go | 2 +- .../default_azure_credential_test.go | 2 +- sdk/azidentity/developer_credential_util.go | 2 +- sdk/azidentity/errors.go | 20 +++++++++---------- .../managed_identity_client_test.go | 2 +- .../managed_identity_credential_test.go | 2 +- 11 files changed, 23 insertions(+), 23 deletions(-) diff --git a/sdk/azidentity/azure_cli_credential_test.go b/sdk/azidentity/azure_cli_credential_test.go index 2e50978babb9..0de9325c4d06 100644 --- a/sdk/azidentity/azure_cli_credential_test.go +++ b/sdk/azidentity/azure_cli_credential_test.go @@ -53,7 +53,7 @@ func TestAzureCLICredential_DefaultChainError(t *testing.T) { t.Fatal(err) } _, err = cred.GetToken(context.Background(), testTRO) - var ue credentialUnavailableError + var ue credentialUnavailable if !errors.As(err, &ue) { t.Fatalf("expected credentialUnavailableError, got %T: %q", err, err) } diff --git a/sdk/azidentity/azure_developer_cli_credential_test.go b/sdk/azidentity/azure_developer_cli_credential_test.go index fb6af170f991..0f09211078e2 100644 --- a/sdk/azidentity/azure_developer_cli_credential_test.go +++ b/sdk/azidentity/azure_developer_cli_credential_test.go @@ -35,7 +35,7 @@ func TestAzureDeveloperCLICredential_DefaultChainError(t *testing.T) { t.Fatal(err) } _, err = cred.GetToken(context.Background(), testTRO) - var ue credentialUnavailableError + var ue credentialUnavailable if !errors.As(err, &ue) { t.Fatalf("expected credentialUnavailableError, got %T: %q", err, err) } diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index f9a7d76daa3b..6c35a941b976 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -86,7 +86,7 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token errs []error successfulCredential azcore.TokenCredential token azcore.AccessToken - unavailableErr credentialUnavailableError + unavailableErr credentialUnavailable ) for _, cred := range c.sources { token, err = cred.GetToken(ctx, opts) diff --git a/sdk/azidentity/chained_token_credential_test.go b/sdk/azidentity/chained_token_credential_test.go index 3332ff0ff5c5..04cb1eb74ad2 100644 --- a/sdk/azidentity/chained_token_credential_test.go +++ b/sdk/azidentity/chained_token_credential_test.go @@ -139,8 +139,8 @@ func TestChainedTokenCredential_MultipleCredentialsGetTokenUnavailable(t *testin t.Fatal(err) } _, err = cred.GetToken(context.Background(), testTRO) - if _, ok := err.(credentialUnavailableError); !ok { - t.Fatalf("expected credentialUnavailableError, received %T", err) + if _, ok := err.(credentialUnavailable); !ok { + t.Fatalf("expected credentialUnavailable, received %T", err) } expectedError := `ChainedTokenCredential: failed to acquire a token. Attempted credentials: @@ -186,8 +186,8 @@ func TestChainedTokenCredential_MultipleCredentialsGetTokenCustomName(t *testing } cred.name = "CustomNameCredential" _, err = cred.GetToken(context.Background(), testTRO) - if _, ok := err.(credentialUnavailableError); !ok { - t.Fatalf("expected credentialUnavailableError, received %T", err) + if _, ok := err.(credentialUnavailable); !ok { + t.Fatalf("expected credentialUnavailable, received %T", err) } expectedError := `CustomNameCredential: failed to acquire a token. Attempted credentials: diff --git a/sdk/azidentity/confidential_client.go b/sdk/azidentity/confidential_client.go index 857a6c14192f..01446a7242a4 100644 --- a/sdk/azidentity/confidential_client.go +++ b/sdk/azidentity/confidential_client.go @@ -109,7 +109,7 @@ func (c *confidentialClient) GetToken(ctx context.Context, tro policy.TokenReque if err != nil { // We could get a credentialUnavailableError from managed identity authentication because in that case the error comes from our code. // We return it directly because it affects the behavior of credential chains. Otherwise, we return AuthenticationFailedError. - var unavailableErr credentialUnavailableError + var unavailableErr credentialUnavailable if !errors.As(err, &unavailableErr) { res := getResponseFromError(err) err = newAuthenticationFailedError(c.name, err.Error(), res, err) diff --git a/sdk/azidentity/default_azure_credential.go b/sdk/azidentity/default_azure_credential.go index 5270939a21eb..2385c986f9cf 100644 --- a/sdk/azidentity/default_azure_credential.go +++ b/sdk/azidentity/default_azure_credential.go @@ -158,7 +158,7 @@ type defaultCredentialErrorReporter struct { } func (d *defaultCredentialErrorReporter) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { - if _, ok := d.err.(credentialUnavailableError); ok { + if _, ok := d.err.(credentialUnavailable); ok { return azcore.AccessToken{}, d.err } return azcore.AccessToken{}, newCredentialUnavailableError(d.credType, d.err.Error()) diff --git a/sdk/azidentity/default_azure_credential_test.go b/sdk/azidentity/default_azure_credential_test.go index 0f78371e93dd..45506762cab2 100644 --- a/sdk/azidentity/default_azure_credential_test.go +++ b/sdk/azidentity/default_azure_credential_test.go @@ -277,7 +277,7 @@ func TestDefaultAzureCredential_timeoutWrapper(t *testing.T) { for i := 0; i < 2; i++ { // expecting credentialUnavailableError because delay exceeds the wrapper's timeout _, err = chain.GetToken(context.Background(), testTRO) - if _, ok := err.(credentialUnavailableError); !ok { + if _, ok := err.(credentialUnavailable); !ok { t.Fatalf("expected credentialUnavailableError, got %T: %v", err, err) } } diff --git a/sdk/azidentity/developer_credential_util.go b/sdk/azidentity/developer_credential_util.go index f0bae35918b6..be963d3a2af0 100644 --- a/sdk/azidentity/developer_credential_util.go +++ b/sdk/azidentity/developer_credential_util.go @@ -19,7 +19,7 @@ const cliTimeout = 10 * time.Second // the next credential in its chain (another developer credential). func unavailableIfInChain(err error, inDefaultChain bool) error { if err != nil && inDefaultChain { - var unavailableErr credentialUnavailableError + var unavailableErr credentialUnavailable if !errors.As(err, &unavailableErr) { err = newCredentialUnavailableError(credNameAzureDeveloperCLI, err.Error()) } diff --git a/sdk/azidentity/errors.go b/sdk/azidentity/errors.go index 6debc0909a70..19a6c260027b 100644 --- a/sdk/azidentity/errors.go +++ b/sdk/azidentity/errors.go @@ -110,7 +110,7 @@ var _ errorinfo.NonRetriable = (*AuthenticationFailedError)(nil) // AuthenticationRequiredError indicates a credential's Authenticate method must be called to acquire a token // because the credential requires user interaction and is configured not to request it automatically. type AuthenticationRequiredError struct { - credUnavailableError + credentialUnavailableError // TokenRequestOptions for the required token. Pass this to the credential's Authenticate method. TokenRequestOptions policy.TokenRequestOptions @@ -118,7 +118,7 @@ type AuthenticationRequiredError struct { func newAuthenticationRequiredError(credType string, tro policy.TokenRequestOptions) error { return &AuthenticationRequiredError{ - credUnavailableError: credUnavailableError{ + credentialUnavailableError: credentialUnavailableError{ credType + " can't acquire a token without user interaction. Call Authenticate to authenticate a user interactively", }, TokenRequestOptions: tro, @@ -127,36 +127,36 @@ func newAuthenticationRequiredError(credType string, tro policy.TokenRequestOpti var _ errorinfo.NonRetriable = (*AuthenticationRequiredError)(nil) -type credentialUnavailableError interface { +type credentialUnavailable interface { error credentialUnavailable() } -type credUnavailableError struct { +type credentialUnavailableError struct { message string } // newCredentialUnavailableError is an internal helper that ensures consistent error message formatting func newCredentialUnavailableError(credType, message string) error { msg := fmt.Sprintf("%s: %s", credType, message) - return &credUnavailableError{msg} + return &credentialUnavailableError{msg} } // NewCredentialUnavailableError constructs an error indicating a credential can't attempt authentication // because it lacks required data or state. When [ChainedTokenCredential] receives this error it will try // its next credential, if any. func NewCredentialUnavailableError(message string) error { - return &credUnavailableError{message} + return &credentialUnavailableError{message} } // Error implements the error interface. Note that the message contents are not contractual and can change over time. -func (e *credUnavailableError) Error() string { +func (e *credentialUnavailableError) Error() string { return e.message } // NonRetriable is a marker method indicating this error should not be retried. It has no implementation. -func (*credUnavailableError) NonRetriable() {} +func (*credentialUnavailableError) NonRetriable() {} -func (*credUnavailableError) credentialUnavailable() {} +func (*credentialUnavailableError) credentialUnavailable() {} -var _ errorinfo.NonRetriable = (*credUnavailableError)(nil) +var _ errorinfo.NonRetriable = (*credentialUnavailableError)(nil) diff --git a/sdk/azidentity/managed_identity_client_test.go b/sdk/azidentity/managed_identity_client_test.go index e9db0b0b60e3..78564a261cfe 100644 --- a/sdk/azidentity/managed_identity_client_test.go +++ b/sdk/azidentity/managed_identity_client_test.go @@ -116,7 +116,7 @@ func TestManagedIdentityClient_IMDSErrors(t *testing.T) { if actual := err.Error(); !strings.Contains(actual, test.body) { t.Fatalf("expected response body in error, got %q", actual) } - var unavailableErr credentialUnavailableError + var unavailableErr credentialUnavailable if !errors.As(err, &unavailableErr) { t.Fatalf("expected %T, got %T", unavailableErr, err) } diff --git a/sdk/azidentity/managed_identity_credential_test.go b/sdk/azidentity/managed_identity_credential_test.go index c71c444d96d5..d5af4e6bb81b 100644 --- a/sdk/azidentity/managed_identity_credential_test.go +++ b/sdk/azidentity/managed_identity_credential_test.go @@ -300,7 +300,7 @@ func TestManagedIdentityCredential_GetTokenIMDS400(t *testing.T) { // cred should return credentialUnavailableError when IMDS responds 400 to a token request for i := 0; i < 3; i++ { _, err = cred.GetToken(context.Background(), testTRO) - if _, ok := err.(credentialUnavailableError); !ok { + if _, ok := err.(credentialUnavailable); !ok { t.Fatalf("expected credentialUnavailableError, received %T", err) } } From b4f1e7191b0e3b5fccb73ff9275d79826e6559bf Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 1 Feb 2024 18:10:44 +0000 Subject: [PATCH 5/7] fix up some test strings --- sdk/azidentity/azure_cli_credential_test.go | 6 +++--- sdk/azidentity/azure_developer_cli_credential_test.go | 6 +++--- sdk/azidentity/default_azure_credential_test.go | 2 +- sdk/azidentity/example_test.go | 2 +- sdk/azidentity/managed_identity_credential_test.go | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sdk/azidentity/azure_cli_credential_test.go b/sdk/azidentity/azure_cli_credential_test.go index 0de9325c4d06..7f026a9b678a 100644 --- a/sdk/azidentity/azure_cli_credential_test.go +++ b/sdk/azidentity/azure_cli_credential_test.go @@ -53,9 +53,9 @@ func TestAzureCLICredential_DefaultChainError(t *testing.T) { t.Fatal(err) } _, err = cred.GetToken(context.Background(), testTRO) - var ue credentialUnavailable - if !errors.As(err, &ue) { - t.Fatalf("expected credentialUnavailableError, got %T: %q", err, err) + var cu credentialUnavailable + if !errors.As(err, &cu) { + t.Fatalf("expected %T, got %T: %q", cu, err, err) } } diff --git a/sdk/azidentity/azure_developer_cli_credential_test.go b/sdk/azidentity/azure_developer_cli_credential_test.go index 0f09211078e2..f452ccfed55c 100644 --- a/sdk/azidentity/azure_developer_cli_credential_test.go +++ b/sdk/azidentity/azure_developer_cli_credential_test.go @@ -35,9 +35,9 @@ func TestAzureDeveloperCLICredential_DefaultChainError(t *testing.T) { t.Fatal(err) } _, err = cred.GetToken(context.Background(), testTRO) - var ue credentialUnavailable - if !errors.As(err, &ue) { - t.Fatalf("expected credentialUnavailableError, got %T: %q", err, err) + var cu credentialUnavailable + if !errors.As(err, &cu) { + t.Fatalf("expected %T, got %T: %q", cu, err, err) } } diff --git a/sdk/azidentity/default_azure_credential_test.go b/sdk/azidentity/default_azure_credential_test.go index 45506762cab2..7f6001b2076a 100644 --- a/sdk/azidentity/default_azure_credential_test.go +++ b/sdk/azidentity/default_azure_credential_test.go @@ -278,7 +278,7 @@ func TestDefaultAzureCredential_timeoutWrapper(t *testing.T) { // expecting credentialUnavailableError because delay exceeds the wrapper's timeout _, err = chain.GetToken(context.Background(), testTRO) if _, ok := err.(credentialUnavailable); !ok { - t.Fatalf("expected credentialUnavailableError, got %T: %v", err, err) + t.Fatalf("expected credentialUnavailable, got %T: %v", err, err) } } diff --git a/sdk/azidentity/example_test.go b/sdk/azidentity/example_test.go index d304cd74705d..878876928b6e 100644 --- a/sdk/azidentity/example_test.go +++ b/sdk/azidentity/example_test.go @@ -45,7 +45,7 @@ func ExampleAuthenticationRequiredError() { if err != nil { // TODO: handle error } - // TODO: retry the client method; it should succeed because the credential now has the necessary token + // TODO: retry the client method; it should succeed because the credential now has the required token } } } diff --git a/sdk/azidentity/managed_identity_credential_test.go b/sdk/azidentity/managed_identity_credential_test.go index d5af4e6bb81b..6390775b800c 100644 --- a/sdk/azidentity/managed_identity_credential_test.go +++ b/sdk/azidentity/managed_identity_credential_test.go @@ -301,7 +301,7 @@ func TestManagedIdentityCredential_GetTokenIMDS400(t *testing.T) { for i := 0; i < 3; i++ { _, err = cred.GetToken(context.Background(), testTRO) if _, ok := err.(credentialUnavailable); !ok { - t.Fatalf("expected credentialUnavailableError, received %T", err) + t.Fatalf("expected credentialUnavailable, received %T", err) } } } From 5dbda1cebc9cabf1fba64c1f672abf4fb05aae17 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 1 Feb 2024 11:23:06 -0800 Subject: [PATCH 6/7] interface checks --- sdk/azidentity/errors.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sdk/azidentity/errors.go b/sdk/azidentity/errors.go index 19a6c260027b..19db7ffa6317 100644 --- a/sdk/azidentity/errors.go +++ b/sdk/azidentity/errors.go @@ -125,7 +125,10 @@ func newAuthenticationRequiredError(credType string, tro policy.TokenRequestOpti } } -var _ errorinfo.NonRetriable = (*AuthenticationRequiredError)(nil) +var ( + _ credentialUnavailable = (*AuthenticationRequiredError)(nil) + _ errorinfo.NonRetriable = (*AuthenticationRequiredError)(nil) +) type credentialUnavailable interface { error @@ -159,4 +162,7 @@ func (*credentialUnavailableError) NonRetriable() {} func (*credentialUnavailableError) credentialUnavailable() {} -var _ errorinfo.NonRetriable = (*credentialUnavailableError)(nil) +var ( + _ credentialUnavailable = (*credentialUnavailableError)(nil) + _ errorinfo.NonRetriable = (*credentialUnavailableError)(nil) +) From f0207dbde02fe48dfe1f2d05e7f63af6b6ddcc99 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 1 Feb 2024 20:55:55 +0000 Subject: [PATCH 7/7] add a comment for future generations --- sdk/azidentity/azidentity_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index b103aa36a494..eae8af424eb6 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -297,6 +297,8 @@ func TestUserAuthentication(t *testing.T) { fake.SetResponse(expected, nil) chain, err := NewChainedTokenCredential([]azcore.TokenCredential{cred, fake}, nil) require.NoError(t, err) + // ChainedTokenCredential should continue iterating when a credential returns + // AuthenticationRequiredError i.e., it should call fake.GetToken() and return the expected token actual, err := chain.GetToken(context.Background(), testTRO) require.NoError(t, err) require.Equal(t, expected, actual)