Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AzureCLICredential and OnBehalfOfCredential return errors immediately on failure #21219

Merged
merged 4 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
### Breaking Changes

### Bugs Fixed
* One invocation of `AzureCLICredential.GetToken()` and `OnBehalfOfCredential.GetToken()`
can no longer make two authentication attempts

### Other Changes

Expand Down
6 changes: 4 additions & 2 deletions sdk/azidentity/azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,10 +462,11 @@ func TestAdditionallyAllowedTenants(t *testing.T) {
}
sts := mockSTS{
tenant: test.tenant,
tokenRequestCallback: func(r *http.Request) {
tokenRequestCallback: func(r *http.Request) *http.Response {
if actual := strings.Split(r.URL.Path, "/")[1]; actual != test.expected {
t.Fatalf("expected tenant %q, got %q", test.expected, actual)
}
return nil
},
}
c, err := subtest.ctor(policy.ClientOptions{Transport: &sts})
Expand Down Expand Up @@ -598,7 +599,7 @@ func TestClaims(t *testing.T) {
disableCP1 = d
reqs := 0
sts := mockSTS{
tokenRequestCallback: func(r *http.Request) {
tokenRequestCallback: func(r *http.Request) *http.Response {
if err := r.ParseForm(); err != nil {
t.Error(err)
}
Expand All @@ -615,6 +616,7 @@ func TestClaims(t *testing.T) {
t.Fatalf(`unexpected claims "%v"`, actual)
}
}
return nil
},
}
o := azcore.ClientOptions{Transport: &sts}
Expand Down
8 changes: 7 additions & 1 deletion sdk/azidentity/azure_cli_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,13 @@ func NewAzureCLICredential(options *AzureCLICredentialOptions) (*AzureCLICredent
}
cp.init()
c := AzureCLICredential{tokenProvider: cp.tokenProvider}
c.s = newSyncer(credNameAzureCLI, cp.TenantID, cp.AdditionallyAllowedTenants, c.requestToken, c.requestToken)
c.s = newSyncer(
credNameAzureCLI,
cp.TenantID,
c.requestToken,
nil, // this credential doesn't have a silent auth method because the CLI handles caching
syncerOptions{AdditionallyAllowedTenants: cp.AdditionallyAllowedTenants},
)
return &c, nil
}

Expand Down
26 changes: 26 additions & 0 deletions sdk/azidentity/azure_cli_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,32 @@ var (
}
)

func TestAzureCLICredential_Error(t *testing.T) {
// GetToken shouldn't invoke the CLI a second time after a failure
authNs := 0
expected := newCredentialUnavailableError(credNameAzureCLI, "it didn't work")
o := AzureCLICredentialOptions{
tokenProvider: func(context.Context, string, string) ([]byte, error) {
authNs++
return nil, expected
},
}
cred, err := NewAzureCLICredential(&o)
if err != nil {
t.Fatal(err)
chlowell marked this conversation as resolved.
Show resolved Hide resolved
}
_, err = cred.GetToken(context.Background(), testTRO)
if err == nil {
t.Fatal("expected an error")
}
if err != expected {
t.Fatalf("expected %v, got %v", expected, err)
}
if authNs != 1 {
t.Fatalf("expected 1 authN, got %d", authNs)
}
}

func TestAzureCLICredential_GetTokenSuccess(t *testing.T) {
options := AzureCLICredentialOptions{}
options.tokenProvider = mockCLITokenProviderSuccess
Expand Down
8 changes: 7 additions & 1 deletion sdk/azidentity/client_assertion_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,13 @@ func NewClientAssertionCredential(tenantID, clientID string, getAssertion func(c
return nil, err
}
cac := ClientAssertionCredential{client: c}
cac.s = newSyncer(credNameAssertion, tenantID, options.AdditionallyAllowedTenants, cac.requestToken, cac.silentAuth)
cac.s = newSyncer(
credNameAssertion,
tenantID,
cac.requestToken,
cac.silentAuth,
syncerOptions{AdditionallyAllowedTenants: options.AdditionallyAllowedTenants},
)
return &cac, nil
}

Expand Down
8 changes: 7 additions & 1 deletion sdk/azidentity/client_certificate_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,13 @@ func NewClientCertificateCredential(tenantID string, clientID string, certs []*x
return nil, err
}
cc := ClientCertificateCredential{client: c}
cc.s = newSyncer(credNameCert, tenantID, options.AdditionallyAllowedTenants, cc.requestToken, cc.silentAuth)
cc.s = newSyncer(
credNameCert,
tenantID,
cc.requestToken,
cc.silentAuth,
syncerOptions{AdditionallyAllowedTenants: options.AdditionallyAllowedTenants},
)
return &cc, nil
}

Expand Down
8 changes: 7 additions & 1 deletion sdk/azidentity/client_secret_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,13 @@ func NewClientSecretCredential(tenantID string, clientID string, clientSecret st
return nil, err
}
csc := ClientSecretCredential{client: c}
csc.s = newSyncer(credNameSecret, tenantID, options.AdditionallyAllowedTenants, csc.requestToken, csc.silentAuth)
csc.s = newSyncer(
credNameSecret,
tenantID,
csc.requestToken,
csc.silentAuth,
syncerOptions{AdditionallyAllowedTenants: options.AdditionallyAllowedTenants},
)
return &csc, nil
}

Expand Down
3 changes: 2 additions & 1 deletion sdk/azidentity/default_azure_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,11 @@ func TestDefaultAzureCredential_TenantID(t *testing.T) {
ClientOptions: policy.ClientOptions{
Transport: &mockSTS{
tenant: expected,
tokenRequestCallback: func(r *http.Request) {
tokenRequestCallback: func(r *http.Request) *http.Response {
if actual := strings.Split(r.URL.Path, "/")[1]; actual != expected {
t.Fatalf("expected tenant %q, got %q", expected, actual)
}
return nil
},
},
},
Expand Down
10 changes: 9 additions & 1 deletion sdk/azidentity/device_code_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,15 @@ func NewDeviceCodeCredential(options *DeviceCodeCredentialOptions) (*DeviceCodeC
return nil, err
}
cred := DeviceCodeCredential{client: c, prompt: cp.UserPrompt}
cred.s = newSyncer(credNameDeviceCode, cp.TenantID, cp.AdditionallyAllowedTenants, cred.requestToken, cred.silentAuth)
cred.s = newSyncer(
credNameDeviceCode,
cp.TenantID,
cred.requestToken,
cred.silentAuth,
syncerOptions{
AdditionallyAllowedTenants: cp.AdditionallyAllowedTenants,
},
)
return &cred, nil
}

Expand Down
10 changes: 9 additions & 1 deletion sdk/azidentity/interactive_browser_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,15 @@ func NewInteractiveBrowserCredential(options *InteractiveBrowserCredentialOption
return nil, err
}
ibc := InteractiveBrowserCredential{client: c, options: cp}
ibc.s = newSyncer(credNameBrowser, cp.TenantID, cp.AdditionallyAllowedTenants, ibc.requestToken, ibc.silentAuth)
ibc.s = newSyncer(
credNameBrowser,
cp.TenantID,
ibc.requestToken,
ibc.silentAuth,
syncerOptions{
AdditionallyAllowedTenants: cp.AdditionallyAllowedTenants,
},
)
return &ibc, nil
}

Expand Down
2 changes: 1 addition & 1 deletion sdk/azidentity/managed_identity_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func NewManagedIdentityCredential(options *ManagedIdentityCredentialOptions) (*M
return nil, err
}
m := ManagedIdentityCredential{client: c, mic: mic}
m.s = newSyncer(credNameManagedIdentity, "", nil, m.requestToken, m.silentAuth)
m.s = newSyncer(credNameManagedIdentity, "", m.requestToken, m.silentAuth, syncerOptions{})
return &m, nil
}

Expand Down
20 changes: 12 additions & 8 deletions sdk/azidentity/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,21 @@ import (
)

// mockSTS returns mock Azure AD responses so tests don't have to account for
// MSAL metadata requests. All responses are success responses. Mock access
// tokens expire in 1 hour and have the value of the "tokenValue" constant.
// MSAL metadata requests. By default, all responses are success responses
// having a token which expires in 1 hour and whose value is the "tokenValue"
// constant. Set tokenRequestCallback to return a different *http.Response.
type mockSTS struct {
// tenant to include in metadata responses. This value must match a test's
// expected tenant because metadata tells MSAL where to send token requests.
// Defaults to the "fakeTenantID" constant.
tenant string
// tokenRequestCallback is called for every token request
tokenRequestCallback func(*http.Request)
// tokenRequestCallback is called for every token request. Return nil to
// send a generic success response.
tokenRequestCallback func(*http.Request) *http.Response
}

func (m *mockSTS) Do(req *http.Request) (*http.Response, error) {
res := http.Response{StatusCode: http.StatusOK}
res := &http.Response{StatusCode: http.StatusOK}
tenant := m.tenant
if tenant == "" {
tenant = fakeTenantID
Expand All @@ -39,10 +41,12 @@ func (m *mockSTS) Do(req *http.Request) (*http.Response, error) {
case "devicecode":
res.Body = io.NopCloser(strings.NewReader(`{"device_code":"...","expires_in":600,"interval":60}`))
case "token":
res.Body = io.NopCloser(bytes.NewReader(accessTokenRespSuccess))
if m.tokenRequestCallback != nil {
m.tokenRequestCallback(req)
if r := m.tokenRequestCallback(req); r != nil {
res = r
}
}
res.Body = io.NopCloser(bytes.NewReader(accessTokenRespSuccess))
default:
// User realm metadata request paths look like "/common/UserRealm/user@domain".
// Matching on the UserRealm segment avoids having to know the UPN.
Expand All @@ -54,5 +58,5 @@ func (m *mockSTS) Do(req *http.Request) (*http.Response, error) {
panic("unexpected request " + req.URL.String())
}
}
return &res, nil
return res, nil
}
3 changes: 2 additions & 1 deletion sdk/azidentity/on_behalf_of_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ func newOnBehalfOfCredential(tenantID, clientID, userAssertion string, cred conf
return nil, err
}
obo := OnBehalfOfCredential{assertion: userAssertion, client: c}
obo.s = newSyncer(credNameOBO, tenantID, options.AdditionallyAllowedTenants, obo.requestToken, obo.requestToken)
// this credential doesn't have a silent auth method because MSAL implements that in AcquireTokenOnBehalfOf; GetToken should just call that method, once
obo.s = newSyncer(credNameOBO, tenantID, obo.requestToken, nil, syncerOptions{AdditionallyAllowedTenants: options.AdditionallyAllowedTenants})
return &obo, nil
}

Expand Down
29 changes: 28 additions & 1 deletion sdk/azidentity/on_behalf_of_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package azidentity

import (
"context"
"io"
"net/http"
"strings"
"testing"
Expand Down Expand Up @@ -52,7 +53,7 @@ func TestOnBehalfOfCredential(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
key := struct{}{}
ctx := context.WithValue(context.Background(), key, true)
srv := mockSTS{tokenRequestCallback: func(r *http.Request) {
srv := mockSTS{tokenRequestCallback: func(r *http.Request) *http.Response {
if c := r.Context(); c == nil {
t.Fatal("AcquireTokenOnBehalfOf received no Context")
} else if v := c.Value(key); v == nil || !v.(bool) {
Expand All @@ -70,6 +71,7 @@ func TestOnBehalfOfCredential(t *testing.T) {
if test.sendX5C {
validateX5C(t, certs)(r)
}
return nil
}}
cred, err := test.ctor(&srv)
if err != nil {
Expand All @@ -91,3 +93,28 @@ func TestOnBehalfOfCredential(t *testing.T) {
})
}
}

func TestOnBehalfOfCredential_Error(t *testing.T) {
// GetToken shouldn't send a second token request after the first fails
tokenReqs := 0
cred, err := NewOnBehalfOfCredentialWithSecret("tenant", "clientID", "assertion", "secret", &OnBehalfOfCredentialOptions{
ClientOptions: policy.ClientOptions{
Transport: &mockSTS{
tokenRequestCallback: func(*http.Request) *http.Response {
tokenReqs++
return &http.Response{Body: io.NopCloser(strings.NewReader("")), StatusCode: 400}
},
},
},
})
if err != nil {
t.Fatal(err)
}
_, err = cred.GetToken(context.Background(), testTRO)
if err == nil {
t.Fatal("expected an error")
}
if tokenReqs != 1 {
t.Fatalf("expected 1 token request, got %d", tokenReqs)
}
}
24 changes: 16 additions & 8 deletions sdk/azidentity/syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,14 @@ type syncer struct {
name, tenant string
}

func newSyncer(name, tenant string, additionalTenants []string, reqToken, silentAuth authFn) *syncer {
type syncerOptions struct {
// AdditionallyAllowedTenants syncer may authenticate to
AdditionallyAllowedTenants []string
}

func newSyncer(name, tenant string, reqToken, silentAuth authFn, opts syncerOptions) *syncer {
return &syncer{
addlTenants: resolveAdditionalTenants(additionalTenants),
addlTenants: resolveAdditionalTenants(opts.AdditionallyAllowedTenants),
mu: &sync.Mutex{},
name: name,
reqToken: reqToken,
Expand All @@ -41,21 +46,24 @@ func newSyncer(name, tenant string, additionalTenants []string, reqToken, silent

// GetToken ensures that only one goroutine authenticates at a time
func (s *syncer) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
at := azcore.AccessToken{}
if len(opts.Scopes) == 0 {
return azcore.AccessToken{}, errors.New(s.name + ".GetToken() requires at least one scope")
return at, errors.New(s.name + ".GetToken() requires at least one scope")
}
// we don't resolve the tenant for managed identities because they can acquire tokens only from their home tenants
if s.name != credNameManagedIdentity {
tenant, err := s.resolveTenant(opts.TenantID)
if err != nil {
return azcore.AccessToken{}, err
return at, err
}
opts.TenantID = tenant
}
at, err := s.silent(ctx, opts)
if err != nil {
var err error
s.mu.Lock()
defer s.mu.Unlock()
if s.silent == nil {
jhendrixMSFT marked this conversation as resolved.
Show resolved Hide resolved
at, err = s.reqToken(ctx, opts)
} else if at, err = s.silent(ctx, opts); err != nil {
// cache miss; request a new token
at, err = s.reqToken(ctx, opts)
}
Expand Down
5 changes: 3 additions & 2 deletions sdk/azidentity/syncer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestResolveTenant(t *testing.T) {
{tenant: "invalid:tenant", expectError: true},
} {
t.Run("", func(t *testing.T) {
s := newSyncer("", defaultTenant, test.allowed, nil, nil)
s := newSyncer("", defaultTenant, nil, nil, syncerOptions{AdditionallyAllowedTenants: test.allowed})
tenant, err := s.resolveTenant(test.tenant)
if err != nil {
if test.expectError {
Expand All @@ -68,7 +68,7 @@ func TestResolveTenant(t *testing.T) {

func TestSyncer(t *testing.T) {
silentAuths, tokenRequests := 0, 0
s := newSyncer("", "tenant", nil,
s := newSyncer("", "tenant",
func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
tokenRequests++
return azcore.AccessToken{}, nil
Expand All @@ -81,6 +81,7 @@ func TestSyncer(t *testing.T) {
silentAuths++
return azcore.AccessToken{}, err
},
syncerOptions{},
)
goroutines := 50
wg := sync.WaitGroup{}
Expand Down
10 changes: 9 additions & 1 deletion sdk/azidentity/username_password_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,15 @@ func NewUsernamePasswordCredential(tenantID string, clientID string, username st
return nil, err
}
upc := UsernamePasswordCredential{client: c, password: password, username: username}
upc.s = newSyncer(credNameUserPassword, tenantID, options.AdditionallyAllowedTenants, upc.requestToken, upc.silentAuth)
upc.s = newSyncer(
credNameUserPassword,
tenantID,
upc.requestToken,
upc.silentAuth,
syncerOptions{
AdditionallyAllowedTenants: options.AdditionallyAllowedTenants,
},
)
return &upc, nil
}

Expand Down
Loading