From 5771f54b2f4f8a1987b569a80017da0146c59e18 Mon Sep 17 00:00:00 2001 From: Bogdan Gavril Date: Mon, 6 Feb 2023 21:39:54 +0000 Subject: [PATCH 01/14] Logging audit and logging docs --- apps/design/design.md | 9 +++++++-- apps/errors/error_design.md | 2 +- apps/internal/oauth/oauth.go | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/apps/design/design.md b/apps/design/design.md index 6e1121e4..57a07ad3 100644 --- a/apps/design/design.md +++ b/apps/design/design.md @@ -5,6 +5,7 @@ Contributors: - Keegan Caruso(Keegan.Caruso@microsoft.com) - Joel Hendrix(jhendrix@microsoft.com) - Santiago Gonzalez(Santiago.Gonzalez@microsoft.com) +- Bogdan Gavril (bogavril@microsoft.com) ## History @@ -140,6 +141,10 @@ Since we have representation from the Go SDK team, we might have them go bridge the current implementation using some of that code so its possible for our users to store the cert in Keyvault. -## Notes: +## Logging -Do we need: AcquireTokenSilent()?? Seems like we could just bake this into other acquire calls automatically???? \ No newline at end of file +All logging is done using the `fmt` package. For errors, see [error design](../errors/error_design.md). + +This library does not log personal identifiable information (PII). For a definition of PII, see https://www.microsoft.com/en-us/trust-center/privacy/customer-data-definitions. MSAL Go does not log any of the 3 data categories listed there. + +The library may log information related to your organization, such as tenant id, authority, client id etc. as well as information that cannot be tied to a user such as request correlation id, HTTP status codes etc. diff --git a/apps/errors/error_design.md b/apps/errors/error_design.md index 34a699f4..7ef7862f 100644 --- a/apps/errors/error_design.md +++ b/apps/errors/error_design.md @@ -69,7 +69,7 @@ func (e CallErr) Error() string { // Verbose prints a versbose error message with the request or response. func (e CallErr) Verbose() string { - e.Resp.Request = nil // This brings in a bunch of TLS crap we don't need + e.Resp.Request = nil // This brings in a bunch of TLS stuff we don't need e.Resp.TLS = nil // Same return fmt.Sprintf("%s:\nRequest:\n%s\nResponse:\n%s", e.Err, prettyConf.Sprint(e.Req), prettyConf.Sprint(e.Resp)) } diff --git a/apps/internal/oauth/oauth.go b/apps/internal/oauth/oauth.go index f9108235..e559d18e 100644 --- a/apps/internal/oauth/oauth.go +++ b/apps/internal/oauth/oauth.go @@ -171,7 +171,7 @@ func (t *Client) UsernamePassword(ctx context.Context, authParams authority.Auth userRealm, err := t.Authority.UserRealm(ctx, authParams) if err != nil { - return accesstokens.TokenResponse{}, fmt.Errorf("problem getting user realm(user: %s) from authority: %w", authParams.Username, err) + return accesstokens.TokenResponse{}, fmt.Errorf("problem getting user realm from authority: %w", err) } switch userRealm.AccountType { From 02c91b6318073572206ba008fe1949142c1e1758 Mon Sep 17 00:00:00 2001 From: Bogdan Gavril Date: Tue, 14 Feb 2023 12:22:57 +0000 Subject: [PATCH 02/14] Update apps/design/design.md Co-authored-by: Charles Lowell <10964656+chlowell@users.noreply.github.com> --- apps/design/design.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/design/design.md b/apps/design/design.md index 57a07ad3..a3da2ae5 100644 --- a/apps/design/design.md +++ b/apps/design/design.md @@ -143,7 +143,7 @@ cert in Keyvault. ## Logging -All logging is done using the `fmt` package. For errors, see [error design](../errors/error_design.md). +For errors, see [error design](../errors/error_design.md). This library does not log personal identifiable information (PII). For a definition of PII, see https://www.microsoft.com/en-us/trust-center/privacy/customer-data-definitions. MSAL Go does not log any of the 3 data categories listed there. From 6230594e002866b5ba7690a0eac3ad652b72d154 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Wed, 15 Feb 2023 12:18:21 -0800 Subject: [PATCH 03/14] Delete vestigial option function types (#383) confidential.AcquireTokenByAuthCodeOption and AcquireTokenSilentOption; public.AcquireTokenByAuthCodeOption, AcquireTokenSilentOption, and InteractiveAuthOption --- apps/confidential/confidential.go | 10 ---------- apps/internal/version/version.go | 2 +- apps/public/public.go | 15 --------------- 3 files changed, 1 insertion(+), 26 deletions(-) diff --git a/apps/confidential/confidential.go b/apps/confidential/confidential.go index 85a1ba6d..7ae31515 100644 --- a/apps/confidential/confidential.go +++ b/apps/confidential/confidential.go @@ -581,11 +581,6 @@ type AcquireSilentOption interface { acquireSilentOption() } -// AcquireTokenSilentOption changes options inside AcquireTokenSilentOptions used in .AcquireTokenSilent(). -type AcquireTokenSilentOption func(a *AcquireTokenSilentOptions) - -func (AcquireTokenSilentOption) acquireSilentOption() {} - // WithSilentAccount uses the passed account during an AcquireTokenSilent() call. func WithSilentAccount(account Account) interface { AcquireSilentOption @@ -646,11 +641,6 @@ type AcquireByAuthCodeOption interface { acquireByAuthCodeOption() } -// AcquireTokenByAuthCodeOption changes options inside AcquireTokenByAuthCodeOptions used in .AcquireTokenByAuthCode(). -type AcquireTokenByAuthCodeOption func(a *AcquireTokenByAuthCodeOptions) - -func (AcquireTokenByAuthCodeOption) acquireByAuthCodeOption() {} - // WithChallenge allows you to provide a challenge for the .AcquireTokenByAuthCode() call. func WithChallenge(challenge string) interface { AcquireByAuthCodeOption diff --git a/apps/internal/version/version.go b/apps/internal/version/version.go index c3651e63..4a20fef3 100644 --- a/apps/internal/version/version.go +++ b/apps/internal/version/version.go @@ -5,4 +5,4 @@ package version // Version is the version of this client package that is communicated to the server. -const Version = "0.8.1" +const Version = "0.9.0" diff --git a/apps/public/public.go b/apps/public/public.go index 0a3ffaff..ae6adfc6 100644 --- a/apps/public/public.go +++ b/apps/public/public.go @@ -275,11 +275,6 @@ type AcquireSilentOption interface { acquireSilentOption() } -// AcquireTokenSilentOption changes options inside AcquireTokenSilentOptions used in .AcquireTokenSilent(). -type AcquireTokenSilentOption func(a *AcquireTokenSilentOptions) - -func (AcquireTokenSilentOption) acquireSilentOption() {} - // WithSilentAccount uses the passed account during an AcquireTokenSilent() call. func WithSilentAccount(account Account) interface { AcquireSilentOption @@ -432,11 +427,6 @@ type AcquireByAuthCodeOption interface { acquireByAuthCodeOption() } -// AcquireTokenByAuthCodeOption changes options inside AcquireTokenByAuthCodeOptions used in .AcquireTokenByAuthCode(). -type AcquireTokenByAuthCodeOption func(a *AcquireTokenByAuthCodeOptions) - -func (AcquireTokenByAuthCodeOption) acquireByAuthCodeOption() {} - // WithChallenge allows you to provide a code for the .AcquireTokenByAuthCode() call. func WithChallenge(challenge string) interface { AcquireByAuthCodeOption @@ -509,11 +499,6 @@ type AcquireInteractiveOption interface { acquireInteractiveOption() } -// InteractiveAuthOption changes options inside InteractiveAuthOptions used in .AcquireTokenInteractive(). -type InteractiveAuthOption func(*InteractiveAuthOptions) - -func (InteractiveAuthOption) acquireInteractiveOption() {} - // WithLoginHint pre-populates the login prompt with a username. func WithLoginHint(username string) interface { AcquireInteractiveOption From 0af1c206272388f2cf6b60a40588b97c76304a0b Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Wed, 15 Feb 2023 13:53:29 -0800 Subject: [PATCH 04/14] Rename public client CreateAuthCodeURL to AuthCodeURL (#384) --- apps/public/public.go | 40 +++++++++---------- apps/public/public_test.go | 12 +++--- .../devapps/authorization_code_sample.go | 2 +- .../devapps/confidential_auth_code_sample.go | 2 +- 4 files changed, 28 insertions(+), 28 deletions(-) diff --git a/apps/public/public.go b/apps/public/public.go index ae6adfc6..09ee0336 100644 --- a/apps/public/public.go +++ b/apps/public/public.go @@ -144,21 +144,21 @@ func New(clientID string, options ...Option) (Client, error) { return Client{base}, nil } -// createAuthCodeURLOptions contains options for CreateAuthCodeURL -type createAuthCodeURLOptions struct { +// authCodeURLOptions contains options for AuthCodeURL +type authCodeURLOptions struct { claims, loginHint, tenantID, domainHint string } -// CreateAuthCodeURLOption is implemented by options for CreateAuthCodeURL -type CreateAuthCodeURLOption interface { - createAuthCodeURLOption() +// AuthCodeURLOption is implemented by options for AuthCodeURL +type AuthCodeURLOption interface { + authCodeURLOption() } -// CreateAuthCodeURL creates a URL used to acquire an authorization code. +// AuthCodeURL creates a URL used to acquire an authorization code. // // Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID] -func (pca Client) CreateAuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...CreateAuthCodeURLOption) (string, error) { - o := createAuthCodeURLOptions{} +func (pca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...AuthCodeURLOption) (string, error) { + o := authCodeURLOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return "", err } @@ -181,7 +181,7 @@ func WithClaims(claims string) interface { AcquireByUsernamePasswordOption AcquireInteractiveOption AcquireSilentOption - CreateAuthCodeURLOption + AuthCodeURLOption options.CallOption } { return struct { @@ -190,7 +190,7 @@ func WithClaims(claims string) interface { AcquireByUsernamePasswordOption AcquireInteractiveOption AcquireSilentOption - CreateAuthCodeURLOption + AuthCodeURLOption options.CallOption }{ CallOption: options.NewCallOption( @@ -204,7 +204,7 @@ func WithClaims(claims string) interface { t.claims = claims case *AcquireTokenSilentOptions: t.claims = claims - case *createAuthCodeURLOptions: + case *authCodeURLOptions: t.claims = claims case *InteractiveAuthOptions: t.claims = claims @@ -225,7 +225,7 @@ func WithTenantID(tenantID string) interface { AcquireByUsernamePasswordOption AcquireInteractiveOption AcquireSilentOption - CreateAuthCodeURLOption + AuthCodeURLOption options.CallOption } { return struct { @@ -234,7 +234,7 @@ func WithTenantID(tenantID string) interface { AcquireByUsernamePasswordOption AcquireInteractiveOption AcquireSilentOption - CreateAuthCodeURLOption + AuthCodeURLOption options.CallOption }{ CallOption: options.NewCallOption( @@ -248,7 +248,7 @@ func WithTenantID(tenantID string) interface { t.tenantID = tenantID case *AcquireTokenSilentOptions: t.tenantID = tenantID - case *createAuthCodeURLOptions: + case *authCodeURLOptions: t.tenantID = tenantID case *InteractiveAuthOptions: t.tenantID = tenantID @@ -502,18 +502,18 @@ type AcquireInteractiveOption interface { // WithLoginHint pre-populates the login prompt with a username. func WithLoginHint(username string) interface { AcquireInteractiveOption - CreateAuthCodeURLOption + AuthCodeURLOption options.CallOption } { return struct { AcquireInteractiveOption - CreateAuthCodeURLOption + AuthCodeURLOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *createAuthCodeURLOptions: + case *authCodeURLOptions: t.loginHint = username case *InteractiveAuthOptions: t.loginHint = username @@ -529,18 +529,18 @@ func WithLoginHint(username string) interface { // WithDomainHint adds the IdP domain as domain_hint query parameter in the auth url. func WithDomainHint(domain string) interface { AcquireInteractiveOption - CreateAuthCodeURLOption + AuthCodeURLOption options.CallOption } { return struct { AcquireInteractiveOption - CreateAuthCodeURLOption + AuthCodeURLOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *createAuthCodeURLOptions: + case *authCodeURLOptions: t.domainHint = domain case *InteractiveAuthOptions: t.domainHint = domain diff --git a/apps/public/public_test.go b/apps/public/public_test.go index 9cf2d0ac..e7faf680 100644 --- a/apps/public/public_test.go +++ b/apps/public/public_test.go @@ -233,7 +233,7 @@ func TestAcquireTokenWithTenantID(t *testing.T) { case "authcode": ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", "https://localhost", tokenScope, WithTenantID(test.tenant)) case "authcodeURL": - URL, err = client.CreateAuthCodeURL(ctx, "client-id", "https://localhost", tokenScope, WithTenantID(test.tenant)) + URL, err = client.AuthCodeURL(ctx, "client-id", "https://localhost", tokenScope, WithTenantID(test.tenant)) case "devicecode": dc, err = client.AcquireTokenByDeviceCode(ctx, tokenScope, WithTenantID(test.tenant)) case "interactive": @@ -536,7 +536,7 @@ func TestWithClaims(t *testing.T) { ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", "https://localhost", tokenScope, WithClaims(test.claims)) case "authcodeURL": u := "" - if u, err = client.CreateAuthCodeURL(ctx, "client-id", "https://localhost", tokenScope, WithClaims(test.claims)); err == nil { + if u, err = client.AuthCodeURL(ctx, "client-id", "https://localhost", tokenScope, WithClaims(test.claims)); err == nil { var parsed *url.URL if parsed, err = url.Parse(u); err == nil { validate(t, parsed.Query()) @@ -693,7 +693,7 @@ func TestWithLoginHint(t *testing.T) { return fakeBrowserOpenURL(authURL) } acquireOpts := []AcquireInteractiveOption{} - urlOpts := []CreateAuthCodeURLOption{} + urlOpts := []AuthCodeURLOption{} if expectHint { acquireOpts = append(acquireOpts, WithLoginHint(upn)) urlOpts = append(urlOpts, WithLoginHint(upn)) @@ -705,7 +705,7 @@ func TestWithLoginHint(t *testing.T) { if !called { t.Fatal("browserOpenURL wasn't called") } - u, err := client.CreateAuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...) + u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...) if err == nil { var parsed *url.URL parsed, err = url.Parse(u) @@ -767,7 +767,7 @@ func TestWithDomainHint(t *testing.T) { return fakeBrowserOpenURL(authURL) } var acquireOpts []AcquireInteractiveOption - var urlOpts []CreateAuthCodeURLOption + var urlOpts []AuthCodeURLOption if expectHint { acquireOpts = append(acquireOpts, WithDomainHint(domain)) urlOpts = append(urlOpts, WithDomainHint(domain)) @@ -779,7 +779,7 @@ func TestWithDomainHint(t *testing.T) { if !called { t.Fatal("browserOpenURL wasn't called") } - u, err := client.CreateAuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...) + u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...) if err == nil { var parsed *url.URL parsed, err = url.Parse(u) diff --git a/apps/tests/devapps/authorization_code_sample.go b/apps/tests/devapps/authorization_code_sample.go index 46c4a269..dd3e792f 100644 --- a/apps/tests/devapps/authorization_code_sample.go +++ b/apps/tests/devapps/authorization_code_sample.go @@ -51,7 +51,7 @@ func redirectToURL(w http.ResponseWriter, r *http.Request) { authCodeURLParams := msal.CreateAuthorizationCodeURLParameters(config.ClientID, config.RedirectURI, config.Scopes) authCodeURLParams.CodeChallenge = config.CodeChallenge authCodeURLParams.State = config.State - authURL, err := publicClientApp.CreateAuthCodeURL(context.Background(), authCodeURLParams) + authURL, err := publicClientApp.AuthCodeURL(context.Background(), authCodeURLParams) if err != nil { log.Fatal(err) } diff --git a/apps/tests/devapps/confidential_auth_code_sample.go b/apps/tests/devapps/confidential_auth_code_sample.go index 35f31974..924a355a 100644 --- a/apps/tests/devapps/confidential_auth_code_sample.go +++ b/apps/tests/devapps/confidential_auth_code_sample.go @@ -17,7 +17,7 @@ func redirectToURLConfidential(w http.ResponseWriter, r *http.Request) { // Getting the URL to redirect to acquire the authorization code authCodeURLParams.CodeChallenge = confidentialConfig.CodeChallenge authCodeURLParams.State = confidentialConfig.State - authURL, err := app.CreateAuthCodeURL(context.Background(), confidentialConfig.ClientID, confidentialConfig.RedirectURI, confidentialConfig.Scopes) + authURL, err := app.AuthCodeURL(context.Background(), confidentialConfig.ClientID, confidentialConfig.RedirectURI, confidentialConfig.Scopes) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return From e83f3b9f0499a655a4802d83dfedf872e1884f7f Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Wed, 15 Feb 2023 13:59:02 -0800 Subject: [PATCH 05/14] Delete deprecated confidential.NewCredFromAssertion (#385) --- apps/confidential/confidential.go | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/apps/confidential/confidential.go b/apps/confidential/confidential.go index 7ae31515..4e263ceb 100644 --- a/apps/confidential/confidential.go +++ b/apps/confidential/confidential.go @@ -179,17 +179,6 @@ func NewCredFromSecret(secret string) (Credential, error) { return Credential{secret: secret}, nil } -// NewCredFromAssertion creates a Credential from a signed assertion. -// -// Deprecated: a Credential created by this function can't refresh the -// assertion when it expires. Use NewCredFromAssertionCallback instead. -func NewCredFromAssertion(assertion string) (Credential, error) { - if assertion == "" { - return Credential{}, errors.New("assertion can't be empty string") - } - return NewCredFromAssertionCallback(func(context.Context, AssertionRequestOptions) (string, error) { return assertion, nil }), nil -} - // NewCredFromAssertionCallback creates a Credential that invokes a callback to get assertions // authenticating the application. The callback must be thread safe. func NewCredFromAssertionCallback(callback func(context.Context, AssertionRequestOptions) (string, error)) Credential { From a382af46df305f111976b6f31e82aa88c98d0dc4 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Wed, 15 Feb 2023 15:12:13 -0800 Subject: [PATCH 06/14] Rename confidential.WithAccessor to WithCache (#386) --- apps/confidential/confidential.go | 7 +- apps/confidential/confidential_test.go | 80 +++++++++++++++++-- apps/public/public.go | 2 +- .../devapps/client_certificate_sample.go | 2 +- apps/tests/devapps/client_secret_sample.go | 2 +- 5 files changed, 80 insertions(+), 13 deletions(-) diff --git a/apps/confidential/confidential.go b/apps/confidential/confidential.go index 4e263ceb..a4c28774 100644 --- a/apps/confidential/confidential.go +++ b/apps/confidential/confidential.go @@ -256,7 +256,7 @@ type Client struct { // returning Option calls. type Options struct { // Accessor controls cache persistence. - // By default there is no cache persistence. This can be set using the WithAccessor() option. + // By default there is no cache persistence. This can be set using the WithCache() option. Accessor cache.ExportReplace // The host of the Azure Active Directory authority. @@ -300,9 +300,8 @@ func WithAuthority(authority string) Option { } } -// WithAccessor provides a cache accessor that will read and write to some externally managed cache -// that may or may not be shared with other applications. -func WithAccessor(accessor cache.ExportReplace) Option { +// WithCache provides an accessor that will read and write authentication data to an externally managed cache. +func WithCache(accessor cache.ExportReplace) Option { return func(o *Options) { o.Accessor = accessor } diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 65b1c6f0..d21619e2 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock" @@ -31,8 +32,6 @@ import ( "github.com/kylelemons/godebug/pretty" ) -const localhost = "http://localhost" - // errorClient is an HTTP client for tests that should fail when confidential.Client sends a request type errorClient struct{} @@ -65,10 +64,12 @@ func TestCertFromPEM(t *testing.T) { } const ( + authorityFmt = "https://%s/%s" fakeClientID = "fake_client_id" fakeTokenEndpoint = "https://fake_authority/fake/token" - token = "fake_token" + localhost = "http://localhost" refresh = "fake_refresh" + token = "fake_token" ) var tokenScope = []string{"the_scope"} @@ -613,7 +614,7 @@ func TestTokenProviderOptions(t *testing.T) { } return TokenProviderResult{AccessToken: accessToken, ExpiresInSeconds: 3600}, nil }) - client, err := New("id", cred, WithHTTPClient(&errorClient{})) + client, err := New(fakeClientID, cred, WithHTTPClient(&errorClient{})) if err != nil { t.Fatal(err) } @@ -626,6 +627,73 @@ func TestTokenProviderOptions(t *testing.T) { } } +// testCache is a simple in-memory cache.ExportReplace implementation +type testCache map[string][]byte + +func (c testCache) Export(m cache.Marshaler, key string) { + if v, err := m.Marshal(); err == nil { + c[key] = v + } +} + +func (c testCache) Replace(u cache.Unmarshaler, key string) { + if v, has := c[key]; has { + _ = u.Unmarshal(v) + } +} + +func TestWithCache(t *testing.T) { + cache := make(testCache) + accessToken := "*" + lmo := "login.microsoftonline.com" + tenantA, tenantB := "a", "b" + authorityA, authorityB := fmt.Sprintf(authorityFmt, lmo, tenantA), fmt.Sprintf(authorityFmt, lmo, tenantB) + mockClient := mock.Client{} + mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenantA))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenantA, authorityA), "", "", 3600))) + + cred, err := NewCredFromSecret("...") + if err != nil { + t.Fatal(err) + } + client, err := New(fakeClientID, cred, WithAuthority(authorityA), WithCache(&cache), WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + // The particular flow isn't important, we just need to populate the cache. Auth code is the simplest for this test + ar, err := client.AcquireTokenByAuthCode(context.Background(), "code", "https://localhost", tokenScope) + if err != nil { + t.Fatal(err) + } + if ar.AccessToken != accessToken { + t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) + } + account := ar.Account + if actual := account.Realm; actual != tenantA { + t.Fatalf(`unexpected realm "%s"`, actual) + } + + // a client configured for a different tenant should be able to authenticate silently with the shared cache's data + client, err = New(fakeClientID, cred, WithAuthority(authorityB), WithCache(&cache), WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + // this should succeed because the cache contains an access token from tenantA + mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenantA))) + ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account), WithTenantID(tenantA)) + if err != nil { + t.Fatal(err) + } + if ar.AccessToken != accessToken { + t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) + } + // this should fail because the cache doesn't contain an access token from tenantB + ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account)) + if err == nil { + t.Fatal("expected an error because the cache doesn't have an appropriate access token") + } +} + func TestWithClaims(t *testing.T) { cred, err := NewCredFromSecret("secret") if err != nil { @@ -715,7 +783,7 @@ func TestWithClaims(t *testing.T) { ar, err = client.AcquireTokenByAuthCode(ctx, "code", localhost, tokenScope, WithClaims(test.claims)) case "authcodeURL": u := "" - if u, err = client.AuthCodeURL(ctx, "client-id", localhost, tokenScope, WithClaims(test.claims)); err == nil { + if u, err = client.AuthCodeURL(ctx, fakeClientID, localhost, tokenScope, WithClaims(test.claims)); err == nil { var parsed *url.URL if parsed, err = url.Parse(u); err == nil { validate(t, parsed.Query()) @@ -831,7 +899,7 @@ func TestWithTenantID(t *testing.T) { case "authcode": ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", localhost, tokenScope, WithTenantID(test.tenant)) case "authcodeURL": - URL, err = client.AuthCodeURL(ctx, "client-id", localhost, tokenScope, WithTenantID(test.tenant)) + URL, err = client.AuthCodeURL(ctx, fakeClientID, localhost, tokenScope, WithTenantID(test.tenant)) case "credential": ar, err = client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(test.tenant)) case "obo": diff --git a/apps/public/public.go b/apps/public/public.go index 09ee0336..ea687a51 100644 --- a/apps/public/public.go +++ b/apps/public/public.go @@ -87,7 +87,7 @@ func WithAuthority(authority string) Option { } } -// WithCache allows you to set some type of cache for storing authentication tokens. +// WithCache provides an accessor that will read and write authentication data to an externally managed cache. func WithCache(accessor cache.ExportReplace) Option { return func(o *Options) { o.Accessor = accessor diff --git a/apps/tests/devapps/client_certificate_sample.go b/apps/tests/devapps/client_certificate_sample.go index 919efb37..a0dd9306 100644 --- a/apps/tests/devapps/client_certificate_sample.go +++ b/apps/tests/devapps/client_certificate_sample.go @@ -37,7 +37,7 @@ func acquireTokenClientCertificate() { if err != nil { log.Fatal(err) } - app, err := confidential.New(config.ClientID, cred, confidential.WithAuthority(config.Authority), confidential.WithAccessor(cacheAccessor)) + app, err := confidential.New(config.ClientID, cred, confidential.WithAuthority(config.Authority), confidential.WithCache(cacheAccessor)) if err != nil { log.Fatal(err) } diff --git a/apps/tests/devapps/client_secret_sample.go b/apps/tests/devapps/client_secret_sample.go index daa6d60a..03f8d697 100644 --- a/apps/tests/devapps/client_secret_sample.go +++ b/apps/tests/devapps/client_secret_sample.go @@ -17,7 +17,7 @@ func acquireTokenClientSecret() { if err != nil { log.Fatal(err) } - app, err := confidential.New(config.ClientID, cred, confidential.WithAuthority(config.Authority), confidential.WithAccessor(cacheAccessor)) + app, err := confidential.New(config.ClientID, cred, confidential.WithAuthority(config.Authority), confidential.WithCache(cacheAccessor)) if err != nil { log.Fatal(err) } From 10e70c9d853417ac27f4cfb4670d20504949c979 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 16 Feb 2023 10:45:31 -0800 Subject: [PATCH 07/14] Upgrade dependencies (#387) --- go.mod | 10 +++++----- go.sum | 21 +++++++++++---------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index c0067a21..8bb370bd 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,11 @@ module github.com/AzureAD/microsoft-authentication-library-for-go go 1.18 require ( - github.com/golang-jwt/jwt/v4 v4.4.2 - github.com/google/uuid v1.1.1 + github.com/golang-jwt/jwt/v4 v4.4.3 + github.com/google/uuid v1.3.0 github.com/kylelemons/godebug v1.1.0 - github.com/montanaflynn/stats v0.6.6 - github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4 + github.com/montanaflynn/stats v0.7.0 + github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 ) -require golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c // indirect +require golang.org/x/sys v0.5.0 // indirect diff --git a/go.sum b/go.sum index ac58e5e3..26fce0bf 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,13 @@ -github.com/golang-jwt/jwt/v4 v4.4.2 h1:rcc4lwaZgFMCZ5jxF9ABolDcIHdBytAFgqFPbSJQAYs= -github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= -github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU= +github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/montanaflynn/stats v0.6.6 h1:Duep6KMIDpY4Yo11iFsvyqJDyfzLF9+sndUKT+v64GQ= -github.com/montanaflynn/stats v0.6.6/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= -github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4 h1:Qj1ukM4GlMWXNdMBuXcXfz/Kw9s1qm0CLY32QxuSImI= -github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4/go.mod h1:N6UoU20jOqggOuDwUaBQpluzLNDqif3kq9z2wpdYEfQ= -golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c h1:VwygUrnw9jn88c4u8GD3rZQbqrP/tgas88tPUbBxQrk= -golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +github.com/montanaflynn/stats v0.7.0 h1:r3y12KyNxj/Sb/iOE46ws+3mS1+MZca1wlHQFPsY/JU= +github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= +github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= +github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= +golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From bddac96c0f7b085acc9d1a9f02727be718fde2c9 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 23 Feb 2023 09:39:31 -0800 Subject: [PATCH 08/14] Remove unused confidential client UserID (#389) --- apps/confidential/confidential.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/apps/confidential/confidential.go b/apps/confidential/confidential.go index a4c28774..e2216350 100644 --- a/apps/confidential/confidential.go +++ b/apps/confidential/confidential.go @@ -246,10 +246,6 @@ type Client struct { base base.Client cred *accesstokens.Credential - - // userID is some unique identifier for a user. It actually isn't used by us at all, it - // simply acts as another hint that a confidential.Client is for a single user. - userID string } // Options are optional settings for New(). These options are set using various functions @@ -395,11 +391,6 @@ func New(clientID string, cred Credential, options ...Option) (Client, error) { return Client{base: base, cred: internalCred}, nil } -// UserID is the unique user identifier this client if for. -func (cca Client) UserID() string { - return cca.userID -} - // authCodeURLOptions contains options for AuthCodeURL type authCodeURLOptions struct { claims, loginHint, tenantID, domainHint string From 7d2056cca47af5dcd1875b34d9c1f9a49688e514 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 23 Feb 2023 13:06:41 -0800 Subject: [PATCH 09/14] Replace NewCredFromCert with NewCredFromCertChain (#391) --- apps/confidential/confidential.go | 15 +++------- apps/confidential/confidential_test.go | 30 +++++++++---------- apps/confidential/examples_test.go | 29 ++---------------- .../devapps/client_certificate_sample.go | 9 +----- 4 files changed, 23 insertions(+), 60 deletions(-) diff --git a/apps/confidential/confidential.go b/apps/confidential/confidential.go index e2216350..8433499e 100644 --- a/apps/confidential/confidential.go +++ b/apps/confidential/confidential.go @@ -63,7 +63,7 @@ type AuthResult = base.AuthResult type Account = shared.Account -// CertFromPEM converts a PEM file (.pem or .key) for use with NewCredFromCert(). The file +// CertFromPEM converts a PEM file (.pem or .key) for use with [NewCredFromCert]. The file // must contain the public certificate and the private key. If a PEM block is encrypted and // password is not an empty string, it attempts to decrypt the PEM blocks using the password. // Multiple certs are due to certificate chaining for use cases like TLS that sign from root to leaf. @@ -185,16 +185,9 @@ func NewCredFromAssertionCallback(callback func(context.Context, AssertionReques return Credential{assertionCallback: callback} } -// NewCredFromCert creates a Credential from an x509.Certificate and an RSA private key. -// CertFromPEM() can be used to get these values from a PEM file. -func NewCredFromCert(cert *x509.Certificate, key crypto.PrivateKey) Credential { - cred, _ := NewCredFromCertChain([]*x509.Certificate{cert}, key) - return cred -} - -// NewCredFromCertChain creates a Credential from a chain of x509.Certificates and an RSA private key -// as returned by CertFromPEM(). -func NewCredFromCertChain(certs []*x509.Certificate, key crypto.PrivateKey) (Credential, error) { +// NewCredFromCert creates a Credential from a certificate or chain of certificates and an RSA private key +// as returned by [CertFromPEM]. +func NewCredFromCert(certs []*x509.Certificate, key crypto.PrivateKey) (Credential, error) { cred := Credential{key: key} k, ok := key.(*rsa.PrivateKey) if !ok { diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index d21619e2..a2f28622 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -366,20 +366,9 @@ func TestAcquireTokenSilentTenants(t *testing.T) { } func TestInvalidCredential(t *testing.T) { - data, err := os.ReadFile("../testdata/test-cert.pem") - if err != nil { - t.Fatal(err) - } - certs, key, err := CertFromPEM(data, "") - if err != nil { - t.Fatal(err) - } for _, cred := range []Credential{ {}, NewCredFromAssertionCallback(nil), - NewCredFromCert(nil, nil), - NewCredFromCert(certs[0], nil), - NewCredFromCert(nil, key), } { t.Run("", func(t *testing.T) { _, err := New(fakeClientID, cred) @@ -390,7 +379,7 @@ func TestInvalidCredential(t *testing.T) { } } -func TestNewCredFromCertChain(t *testing.T) { +func TestNewCredFromCert(t *testing.T) { for _, file := range []struct { path string numCerts int @@ -424,7 +413,7 @@ func TestNewCredFromCertChain(t *testing.T) { t.Fatal("expected an RSA private key") } verifyingKey := &k.PublicKey - cred, err := NewCredFromCertChain(certs, key) + cred, err := NewCredFromCert(certs, key) if err != nil { t.Fatal(err) } @@ -507,7 +496,7 @@ func TestNewCredFromCertChain(t *testing.T) { } } -func TestNewCredFromCertChainError(t *testing.T) { +func TestNewCredFromCertError(t *testing.T) { data, err := os.ReadFile("../testdata/test-cert.pem") if err != nil { t.Fatal(err) @@ -529,12 +518,23 @@ func TestNewCredFromCertChainError(t *testing.T) { {[]*x509.Certificate{nil}, key}, } { t.Run("", func(t *testing.T) { - _, err := NewCredFromCertChain(test.certs, test.key) + _, err := NewCredFromCert(test.certs, test.key) if err == nil { t.Fatal("expected an error") } }) } + + // the key in this file doesn't match the cert loaded above + if data, err = os.ReadFile("../testdata/test-cert-chain.pem"); err != nil { + t.Fatal(err) + } + if _, key, err = CertFromPEM(data, ""); err != nil { + t.Fatal(err) + } + if _, err = NewCredFromCert(certs, key); err == nil { + t.Fatal("expected an error because key doesn't match certs") + } } func TestNewCredFromTokenProvider(t *testing.T) { diff --git a/apps/confidential/examples_test.go b/apps/confidential/examples_test.go index e067e27a..e4c7180b 100644 --- a/apps/confidential/examples_test.go +++ b/apps/confidential/examples_test.go @@ -24,32 +24,9 @@ func ExampleNewCredFromCert_pem() { log.Fatal(err) } - // PEM files can have multiple certs. This is usually for certificate chaining where roots - // sign to leafs. Useful for TLS, not for this use case. - if len(certs) > 1 { - log.Fatal("too many certificates in PEM file") - } - - cred := confidential.NewCredFromCert(certs[0], priv) - fmt.Println(cred) // Simply here so cred is used, otherwise won't compile. -} - -func ExampleNewCredFromCertChain() { - b, err := os.ReadFile("key.pem") + cred, err := confidential.NewCredFromCert(certs, priv) if err != nil { - // TODO: handle error - } - - // CertFromPEM loads certificates and a private key from the PEM content. If - // the content is encrypted, the second argument must be the password. - certs, priv, err := confidential.CertFromPEM(b, "") - if err != nil { - // TODO: handle error - } - - cred, err := confidential.NewCredFromCertChain(certs, priv) - if err != nil { - // TODO: handle error + log.Fatal(err) } - _ = cred + fmt.Println(cred) // Simply here so cred is used, otherwise won't compile. } diff --git a/apps/tests/devapps/client_certificate_sample.go b/apps/tests/devapps/client_certificate_sample.go index a0dd9306..c9bc2e8e 100644 --- a/apps/tests/devapps/client_certificate_sample.go +++ b/apps/tests/devapps/client_certificate_sample.go @@ -26,14 +26,7 @@ func acquireTokenClientCertificate() { if err != nil { log.Fatal(err) } - - // PEM files can have multiple certs. This is usually for certificate chaining where roots - // sign to leafs. Useful for TLS, not for this use case. - if len(certs) > 1 { - log.Fatal("too many certificates in PEM file") - } - - cred := confidential.NewCredFromCert(certs[0], privateKey) + cred, err := confidential.NewCredFromCert(certs, privateKey) if err != nil { log.Fatal(err) } From cb956a2127973e3a4222f8aebc8ecdb83fa23ff8 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Fri, 24 Feb 2023 11:14:32 -0800 Subject: [PATCH 10/14] Remove unnecessary options structs (#390) --- apps/confidential/confidential.go | 122 ++++++++++++----------------- apps/public/public.go | 123 +++++++++++++----------------- 2 files changed, 103 insertions(+), 142 deletions(-) diff --git a/apps/confidential/confidential.go b/apps/confidential/confidential.go index 8433499e..49854786 100644 --- a/apps/confidential/confidential.go +++ b/apps/confidential/confidential.go @@ -237,68 +237,50 @@ func AutoDetectRegion() string { // For more information, visit https://docs.microsoft.com/azure/active-directory/develop/msal-client-applications type Client struct { base base.Client - cred *accesstokens.Credential } -// Options are optional settings for New(). These options are set using various functions +// clientOptions are optional settings for New(). These options are set using various functions // returning Option calls. -type Options struct { - // Accessor controls cache persistence. - // By default there is no cache persistence. This can be set using the WithCache() option. - Accessor cache.ExportReplace - - // The host of the Azure Active Directory authority. - // The default is https://login.microsoftonline.com/common. This can be changed using the - // WithAuthority() option. - Authority string - - // The HTTP client used for making requests. - // It defaults to a shared http.Client. - HTTPClient ops.HTTPClient - - // SendX5C specifies if x5c claim(public key of the certificate) should be sent to STS. - SendX5C bool - - // Instructs MSAL Go to use an Azure regional token service with sepcified AzureRegion. - AzureRegion string - - capabilities []string - - disableInstanceDiscovery bool +type clientOptions struct { + accessor cache.ExportReplace + authority, azureRegion string + capabilities []string + disableInstanceDiscovery, sendX5C bool + httpClient ops.HTTPClient } -func (o Options) validate() error { - u, err := url.Parse(o.Authority) +func (o clientOptions) validate() error { + u, err := url.Parse(o.authority) if err != nil { - return fmt.Errorf("the Authority(%s) does not parse as a valid URL", o.Authority) + return fmt.Errorf("the Authority(%s) does not parse as a valid URL", o.authority) } if u.Scheme != "https" { - return fmt.Errorf("the Authority(%s) does not appear to use https", o.Authority) + return fmt.Errorf("the Authority(%s) does not appear to use https", o.authority) } return nil } // Option is an optional argument to New(). -type Option func(o *Options) +type Option func(o *clientOptions) // WithAuthority allows you to provide a custom authority for use in the client. func WithAuthority(authority string) Option { - return func(o *Options) { - o.Authority = authority + return func(o *clientOptions) { + o.authority = authority } } // WithCache provides an accessor that will read and write authentication data to an externally managed cache. func WithCache(accessor cache.ExportReplace) Option { - return func(o *Options) { - o.Accessor = accessor + return func(o *clientOptions) { + o.accessor = accessor } } // WithClientCapabilities allows configuring one or more client capabilities such as "CP1" func WithClientCapabilities(capabilities []string) Option { - return func(o *Options) { + return func(o *clientOptions) { // there's no danger of sharing the slice's underlying memory with the application because // this slice is simply passed to base.WithClientCapabilities, which copies its data o.capabilities = capabilities @@ -307,21 +289,21 @@ func WithClientCapabilities(capabilities []string) Option { // WithHTTPClient allows for a custom HTTP client to be set. func WithHTTPClient(httpClient ops.HTTPClient) Option { - return func(o *Options) { - o.HTTPClient = httpClient + return func(o *clientOptions) { + o.httpClient = httpClient } } // WithX5C specifies if x5c claim(public key of the certificate) should be sent to STS to enable Subject Name Issuer Authentication. func WithX5C() Option { - return func(o *Options) { - o.SendX5C = true + return func(o *clientOptions) { + o.sendX5C = true } } // WithInstanceDiscovery set to false to disable authority validation (to support private cloud scenarios) func WithInstanceDiscovery(enabled bool) Option { - return func(o *Options) { + return func(o *clientOptions) { o.disableInstanceDiscovery = !enabled } } @@ -338,8 +320,8 @@ func WithInstanceDiscovery(enabled bool) Option { // If auto-detection fails, the non-regional endpoint will be used. // If an invalid region name is provided, the non-regional endpoint MIGHT be used or the token request MIGHT fail. func WithAzureRegion(val string) Option { - return func(o *Options) { - o.AzureRegion = val + return func(o *clientOptions) { + o.azureRegion = val } } @@ -352,9 +334,9 @@ func New(clientID string, cred Credential, options ...Option) (Client, error) { return Client{}, err } - opts := Options{ - Authority: base.AuthorityPublicCloud, - HTTPClient: shared.DefaultClient, + opts := clientOptions{ + authority: base.AuthorityPublicCloud, + httpClient: shared.DefaultClient, } for _, o := range options { @@ -365,17 +347,17 @@ func New(clientID string, cred Credential, options ...Option) (Client, error) { } baseOpts := []base.Option{ - base.WithCacheAccessor(opts.Accessor), + base.WithCacheAccessor(opts.accessor), base.WithClientCapabilities(opts.capabilities), - base.WithRegionDetection(opts.AzureRegion), - base.WithX5C(opts.SendX5C), + base.WithRegionDetection(opts.azureRegion), + base.WithX5C(opts.sendX5C), base.WithInstanceDiscovery(!opts.disableInstanceDiscovery), } if cred.tokenProvider != nil { // The caller will handle all details of authentication, using Client only as a token cache. baseOpts = append(baseOpts, base.WithInstanceDiscovery(false)) } - base, err := base.New(clientID, opts.Authority, oauth.New(opts.HTTPClient), baseOpts...) + base, err := base.New(clientID, opts.authority, oauth.New(opts.httpClient), baseOpts...) if err != nil { return Client{}, err } @@ -480,13 +462,13 @@ func WithClaims(claims string) interface { CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *AcquireTokenByAuthCodeOptions: + case *acquireTokenByAuthCodeOptions: t.claims = claims case *acquireTokenByCredentialOptions: t.claims = claims case *acquireTokenOnBehalfOfOptions: t.claims = claims - case *AcquireTokenSilentOptions: + case *acquireTokenSilentOptions: t.claims = claims case *authCodeURLOptions: t.claims = claims @@ -520,13 +502,13 @@ func WithTenantID(tenantID string) interface { CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *AcquireTokenByAuthCodeOptions: + case *acquireTokenByAuthCodeOptions: t.tenantID = tenantID case *acquireTokenByCredentialOptions: t.tenantID = tenantID case *acquireTokenOnBehalfOfOptions: t.tenantID = tenantID - case *AcquireTokenSilentOptions: + case *acquireTokenSilentOptions: t.tenantID = tenantID case *authCodeURLOptions: t.tenantID = tenantID @@ -539,12 +521,10 @@ func WithTenantID(tenantID string) interface { } } -// AcquireTokenSilentOptions are all the optional settings to an AcquireTokenSilent() call. +// acquireTokenSilentOptions are all the optional settings to an AcquireTokenSilent() call. // These are set by using various AcquireTokenSilentOption functions. -type AcquireTokenSilentOptions struct { - // Account represents the account to use. To set, use the WithSilentAccount() option. - Account Account - +type acquireTokenSilentOptions struct { + account Account claims, tenantID string } @@ -565,8 +545,8 @@ func WithSilentAccount(account Account) interface { CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *AcquireTokenSilentOptions: - t.Account = account + case *acquireTokenSilentOptions: + t.account = account default: return fmt.Errorf("unexpected options type %T", a) } @@ -580,7 +560,7 @@ func WithSilentAccount(account Account) interface { // // Options: [WithClaims], [WithSilentAccount], [WithTenantID] func (cca Client) AcquireTokenSilent(ctx context.Context, scopes []string, opts ...AcquireSilentOption) (AuthResult, error) { - o := AcquireTokenSilentOptions{} + o := acquireTokenSilentOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return AuthResult{}, err } @@ -591,21 +571,19 @@ func (cca Client) AcquireTokenSilent(ctx context.Context, scopes []string, opts silentParameters := base.AcquireTokenSilentParameters{ Scopes: scopes, - Account: o.Account, + Account: o.account, RequestType: accesstokens.ATConfidential, Credential: cca.cred, - IsAppCache: o.Account.IsZero(), + IsAppCache: o.account.IsZero(), TenantID: o.tenantID, } return cca.base.AcquireTokenSilent(ctx, silentParameters) } -// AcquireTokenByAuthCodeOptions contains the optional parameters used to acquire an access token using the authorization code flow. -type AcquireTokenByAuthCodeOptions struct { - Challenge string - - claims, tenantID string +// acquireTokenByAuthCodeOptions contains the optional parameters used to acquire an access token using the authorization code flow. +type acquireTokenByAuthCodeOptions struct { + challenge, claims, tenantID string } // AcquireByAuthCodeOption is implemented by options for AcquireTokenByAuthCode @@ -625,8 +603,8 @@ func WithChallenge(challenge string) interface { CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *AcquireTokenByAuthCodeOptions: - t.Challenge = challenge + case *acquireTokenByAuthCodeOptions: + t.challenge = challenge default: return fmt.Errorf("unexpected options type %T", a) } @@ -641,7 +619,7 @@ func WithChallenge(challenge string) interface { // // Options: [WithChallenge], [WithClaims], [WithTenantID] func (cca Client) AcquireTokenByAuthCode(ctx context.Context, code string, redirectURI string, scopes []string, opts ...AcquireByAuthCodeOption) (AuthResult, error) { - o := AcquireTokenByAuthCodeOptions{} + o := acquireTokenByAuthCodeOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return AuthResult{}, err } @@ -649,7 +627,7 @@ func (cca Client) AcquireTokenByAuthCode(ctx context.Context, code string, redir params := base.AcquireTokenAuthCodeParameters{ Scopes: scopes, Code: code, - Challenge: o.Challenge, + Challenge: o.challenge, Claims: o.claims, AppType: accesstokens.ATConfidential, Credential: cca.cred, // This setting differs from public.Client.AcquireTokenByAuthCode diff --git a/apps/public/public.go b/apps/public/public.go index ea687a51..348bc5fc 100644 --- a/apps/public/public.go +++ b/apps/public/public.go @@ -47,27 +47,17 @@ type AuthResult = base.AuthResult type Account = shared.Account -// Options configures the Client's behavior. -type Options struct { - // Accessor controls cache persistence. By default there is no cache persistence. - // This can be set with the WithCache() option. - Accessor cache.ExportReplace - - // The host of the Azure Active Directory authority. The default is https://login.microsoftonline.com/common. - // This can be changed with the WithAuthority() option. - Authority string - - // The HTTP client used for making requests. - // It defaults to a shared http.Client. - HTTPClient ops.HTTPClient - - capabilities []string - +// clientOptions configures the Client's behavior. +type clientOptions struct { + accessor cache.ExportReplace + authority string + capabilities []string disableInstanceDiscovery bool + httpClient ops.HTTPClient } -func (p *Options) validate() error { - u, err := url.Parse(p.Authority) +func (p *clientOptions) validate() error { + u, err := url.Parse(p.authority) if err != nil { return fmt.Errorf("Authority options cannot be URL parsed: %w", err) } @@ -78,25 +68,25 @@ func (p *Options) validate() error { } // Option is an optional argument to the New constructor. -type Option func(o *Options) +type Option func(o *clientOptions) // WithAuthority allows for a custom authority to be set. This must be a valid https url. func WithAuthority(authority string) Option { - return func(o *Options) { - o.Authority = authority + return func(o *clientOptions) { + o.authority = authority } } // WithCache provides an accessor that will read and write authentication data to an externally managed cache. func WithCache(accessor cache.ExportReplace) Option { - return func(o *Options) { - o.Accessor = accessor + return func(o *clientOptions) { + o.accessor = accessor } } // WithClientCapabilities allows configuring one or more client capabilities such as "CP1" func WithClientCapabilities(capabilities []string) Option { - return func(o *Options) { + return func(o *clientOptions) { // there's no danger of sharing the slice's underlying memory with the application because // this slice is simply passed to base.WithClientCapabilities, which copies its data o.capabilities = capabilities @@ -105,14 +95,14 @@ func WithClientCapabilities(capabilities []string) Option { // WithHTTPClient allows for a custom HTTP client to be set. func WithHTTPClient(httpClient ops.HTTPClient) Option { - return func(o *Options) { - o.HTTPClient = httpClient + return func(o *clientOptions) { + o.httpClient = httpClient } } // WithInstanceDiscovery set to false to disable authority validation (to support private cloud scenarios) func WithInstanceDiscovery(enabled bool) Option { - return func(o *Options) { + return func(o *clientOptions) { o.disableInstanceDiscovery = !enabled } } @@ -125,9 +115,9 @@ type Client struct { // New is the constructor for Client. func New(clientID string, options ...Option) (Client, error) { - opts := Options{ - Authority: base.AuthorityPublicCloud, - HTTPClient: shared.DefaultClient, + opts := clientOptions{ + authority: base.AuthorityPublicCloud, + httpClient: shared.DefaultClient, } for _, o := range options { @@ -137,7 +127,7 @@ func New(clientID string, options ...Option) (Client, error) { return Client{}, err } - base, err := base.New(clientID, opts.Authority, oauth.New(opts.HTTPClient), base.WithCacheAccessor(opts.Accessor), base.WithClientCapabilities(opts.capabilities), base.WithInstanceDiscovery(!opts.disableInstanceDiscovery)) + base, err := base.New(clientID, opts.authority, oauth.New(opts.httpClient), base.WithCacheAccessor(opts.accessor), base.WithClientCapabilities(opts.capabilities), base.WithInstanceDiscovery(!opts.disableInstanceDiscovery)) if err != nil { return Client{}, err } @@ -196,17 +186,17 @@ func WithClaims(claims string) interface { CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *AcquireTokenByAuthCodeOptions: + case *acquireTokenByAuthCodeOptions: t.claims = claims case *acquireTokenByDeviceCodeOptions: t.claims = claims case *acquireTokenByUsernamePasswordOptions: t.claims = claims - case *AcquireTokenSilentOptions: + case *acquireTokenSilentOptions: t.claims = claims case *authCodeURLOptions: t.claims = claims - case *InteractiveAuthOptions: + case *interactiveAuthOptions: t.claims = claims default: return fmt.Errorf("unexpected options type %T", a) @@ -240,17 +230,17 @@ func WithTenantID(tenantID string) interface { CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *AcquireTokenByAuthCodeOptions: + case *acquireTokenByAuthCodeOptions: t.tenantID = tenantID case *acquireTokenByDeviceCodeOptions: t.tenantID = tenantID case *acquireTokenByUsernamePasswordOptions: t.tenantID = tenantID - case *AcquireTokenSilentOptions: + case *acquireTokenSilentOptions: t.tenantID = tenantID case *authCodeURLOptions: t.tenantID = tenantID - case *InteractiveAuthOptions: + case *interactiveAuthOptions: t.tenantID = tenantID default: return fmt.Errorf("unexpected options type %T", a) @@ -261,12 +251,10 @@ func WithTenantID(tenantID string) interface { } } -// AcquireTokenSilentOptions are all the optional settings to an AcquireTokenSilent() call. +// acquireTokenSilentOptions are all the optional settings to an AcquireTokenSilent() call. // These are set by using various AcquireTokenSilentOption functions. -type AcquireTokenSilentOptions struct { - // Account represents the account to use. To set, use the WithSilentAccount() option. - Account Account - +type acquireTokenSilentOptions struct { + account Account claims, tenantID string } @@ -287,8 +275,8 @@ func WithSilentAccount(account Account) interface { CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *AcquireTokenSilentOptions: - t.Account = account + case *acquireTokenSilentOptions: + t.account = account default: return fmt.Errorf("unexpected options type %T", a) } @@ -302,14 +290,14 @@ func WithSilentAccount(account Account) interface { // // Options: [WithClaims], [WithSilentAccount], [WithTenantID] func (pca Client) AcquireTokenSilent(ctx context.Context, scopes []string, opts ...AcquireSilentOption) (AuthResult, error) { - o := AcquireTokenSilentOptions{} + o := acquireTokenSilentOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return AuthResult{}, err } silentParameters := base.AcquireTokenSilentParameters{ Scopes: scopes, - Account: o.Account, + Account: o.account, Claims: o.claims, RequestType: accesstokens.ATPublic, IsAppCache: false, @@ -415,11 +403,9 @@ func (pca Client) AcquireTokenByDeviceCode(ctx context.Context, scopes []string, return DeviceCode{Result: dc.Result, authParams: authParams, client: pca, dc: dc}, nil } -// AcquireTokenByAuthCodeOptions contains the optional parameters used to acquire an access token using the authorization code flow. -type AcquireTokenByAuthCodeOptions struct { - Challenge string - - claims, tenantID string +// acquireTokenByAuthCodeOptions contains the optional parameters used to acquire an access token using the authorization code flow. +type acquireTokenByAuthCodeOptions struct { + challenge, claims, tenantID string } // AcquireByAuthCodeOption is implemented by options for AcquireTokenByAuthCode @@ -439,8 +425,8 @@ func WithChallenge(challenge string) interface { CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *AcquireTokenByAuthCodeOptions: - t.Challenge = challenge + case *acquireTokenByAuthCodeOptions: + t.challenge = challenge default: return fmt.Errorf("unexpected options type %T", a) } @@ -455,7 +441,7 @@ func WithChallenge(challenge string) interface { // // Options: [WithChallenge], [WithClaims], [WithTenantID] func (pca Client) AcquireTokenByAuthCode(ctx context.Context, code string, redirectURI string, scopes []string, opts ...AcquireByAuthCodeOption) (AuthResult, error) { - o := AcquireTokenByAuthCodeOptions{} + o := acquireTokenByAuthCodeOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return AuthResult{}, err } @@ -463,7 +449,7 @@ func (pca Client) AcquireTokenByAuthCode(ctx context.Context, code string, redir params := base.AcquireTokenAuthCodeParameters{ Scopes: scopes, Code: code, - Challenge: o.Challenge, + Challenge: o.challenge, Claims: o.claims, AppType: accesstokens.ATPublic, RedirectURI: redirectURI, @@ -485,13 +471,9 @@ func (pca Client) RemoveAccount(account Account) error { return nil } -// InteractiveAuthOptions contains the optional parameters used to acquire an access token for interactive auth code flow. -type InteractiveAuthOptions struct { - // Used to specify a custom port for the local server. http://localhost:portnumber - // All other URI components are ignored. - RedirectURI string - - claims, loginHint, tenantID, domainHint string +// interactiveAuthOptions contains the optional parameters used to acquire an access token for interactive auth code flow. +type interactiveAuthOptions struct { + claims, domainHint, loginHint, redirectURI, tenantID string } // AcquireInteractiveOption is implemented by options for AcquireTokenInteractive @@ -515,7 +497,7 @@ func WithLoginHint(username string) interface { switch t := a.(type) { case *authCodeURLOptions: t.loginHint = username - case *InteractiveAuthOptions: + case *interactiveAuthOptions: t.loginHint = username default: return fmt.Errorf("unexpected options type %T", a) @@ -542,7 +524,7 @@ func WithDomainHint(domain string) interface { switch t := a.(type) { case *authCodeURLOptions: t.domainHint = domain - case *InteractiveAuthOptions: + case *interactiveAuthOptions: t.domainHint = domain default: return fmt.Errorf("unexpected options type %T", a) @@ -553,7 +535,8 @@ func WithDomainHint(domain string) interface { } } -// WithRedirectURI uses the specified redirect URI for interactive auth. +// WithRedirectURI sets a port for the local server used in interactive authentication, for +// example http://localhost:port. All URI components other than the port are ignored. func WithRedirectURI(redirectURI string) interface { AcquireInteractiveOption options.CallOption @@ -565,8 +548,8 @@ func WithRedirectURI(redirectURI string) interface { CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *InteractiveAuthOptions: - t.RedirectURI = redirectURI + case *interactiveAuthOptions: + t.redirectURI = redirectURI default: return fmt.Errorf("unexpected options type %T", a) } @@ -581,7 +564,7 @@ func WithRedirectURI(redirectURI string) interface { // // Options: [WithDomainHint], [WithLoginHint], [WithRedirectURI], [WithTenantID] func (pca Client) AcquireTokenInteractive(ctx context.Context, scopes []string, opts ...AcquireInteractiveOption) (AuthResult, error) { - o := InteractiveAuthOptions{} + o := interactiveAuthOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return AuthResult{}, err } @@ -592,8 +575,8 @@ func (pca Client) AcquireTokenInteractive(ctx context.Context, scopes []string, return AuthResult{}, err } var redirectURL *url.URL - if o.RedirectURI != "" { - redirectURL, err = url.Parse(o.RedirectURI) + if o.redirectURI != "" { + redirectURL, err = url.Parse(o.redirectURI) if err != nil { return AuthResult{}, err } From 79fd37d58636c3d3928b76b8d608d91c03a29070 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Fri, 24 Feb 2023 15:53:26 -0800 Subject: [PATCH 11/14] Upgrade CI tools (#392) --- .github/workflows/go.yml | 2 +- .github/workflows/golangci-lint.yml | 4 ++-- apps/internal/oauth/oauth.go | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 374451e7..850ad671 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -16,7 +16,7 @@ jobs: strategy: matrix: - go: ["1.19", "1.18"] + go: ["1.19", "1.20"] steps: - name: Set up Go 1.x diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 94fae709..acf23dfc 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -18,13 +18,13 @@ jobs: steps: - uses: actions/setup-go@v3 with: - go-version: 1.19 + go-version: "1.20" - uses: actions/checkout@v3 - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: # Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version. - version: v1.49 + version: v1.51 # Optional: golangci-lint command line arguments. # args: --issues-exit-code=0 diff --git a/apps/internal/oauth/oauth.go b/apps/internal/oauth/oauth.go index e559d18e..5f136933 100644 --- a/apps/internal/oauth/oauth.go +++ b/apps/internal/oauth/oauth.go @@ -212,7 +212,6 @@ func (d DeviceCode) Token(ctx context.Context) (accesstokens.TokenResponse, erro } var cancel context.CancelFunc - d.Result.ExpiresOn.Sub(time.Now().UTC()) if deadline, ok := ctx.Deadline(); !ok || d.Result.ExpiresOn.Before(deadline) { ctx, cancel = context.WithDeadline(ctx, d.Result.ExpiresOn) } else { From ec469b9c377547820820f6250dd3ac236f2d2d80 Mon Sep 17 00:00:00 2001 From: John Doak <42280444+element-of-surprise@users.noreply.github.com> Date: Tue, 28 Feb 2023 14:49:28 -0800 Subject: [PATCH 12/14] ExportReplace supports Context (#378) --- apps/cache/cache.go | 17 ++- apps/confidential/confidential.go | 12 +- apps/confidential/confidential_test.go | 17 ++- apps/internal/base/base.go | 126 +++++++++++++----- apps/internal/base/base_test.go | 86 +++++++++++- apps/public/public.go | 9 +- apps/public/public_test.go | 24 +++- .../devapps/confidential_auth_code_sample.go | 4 +- apps/tests/devapps/device_code_flow_sample.go | 5 +- apps/tests/devapps/main.go | 7 +- apps/tests/devapps/sample_cache_accessor.go | 15 +-- .../tests/devapps/username_password_sample.go | 7 +- apps/tests/integration/integration_test.go | 7 +- 13 files changed, 249 insertions(+), 87 deletions(-) diff --git a/apps/cache/cache.go b/apps/cache/cache.go index 259ca6d5..84e1383c 100644 --- a/apps/cache/cache.go +++ b/apps/cache/cache.go @@ -11,6 +11,8 @@ implementers on the format being passed. */ package cache +import "context" + // Marshaler marshals data from an internal cache to bytes that can be stored. type Marshaler interface { Marshal() ([]byte, error) @@ -27,13 +29,18 @@ type Serializer interface { Unmarshaler } -// ExportReplace is used export or replace what is in the cache. +// ExportReplace exports and replaces in-memory cache data. It doesn't support nil Context or +// define the outcome of passing one. A Context without a timeout must receive a default timeout +// specified by the implementor. Retries must be implemented inside the implementation. type ExportReplace interface { // Replace replaces the cache with what is in external storage. - // key is the suggested key which can be used for partioning the cache - Replace(cache Unmarshaler, key string) + // key is the suggested key which can be used for partitioning the cache. + // Implementors should honor Context cancellations and return a context.Canceled or + // context.DeadlineExceeded in those cases. + Replace(ctx context.Context, cache Unmarshaler, key string) error // Export writes the binary representation of the cache (cache.Marshal()) to // external storage. This is considered opaque. - // key is the suggested key which can be used for partioning the cache - Export(cache Marshaler, key string) + // key is the suggested key which can be used for partitioning the cache. + // Context cancellations should be honored as in Replace. + Export(ctx context.Context, cache Marshaler, key string) error } diff --git a/apps/confidential/confidential.go b/apps/confidential/confidential.go index 49854786..08831dbf 100644 --- a/apps/confidential/confidential.go +++ b/apps/confidential/confidential.go @@ -50,8 +50,7 @@ duplication. .Net People, Take note on X509: This uses x509.Certificates and private keys. x509 does not store private keys. .Net has some x509.Certificate2 thing that has private keys, but that is just some bullcrap that .Net -added, it doesn't exist in real life. Seriously, "x509.Certificate2", bahahahaha. As such I've -put a PEM decoder into here. +added, it doesn't exist in real life. As such I've put a PEM decoder into here. */ // TODO(msal): This should have example code for each method on client using Go's example doc framework. @@ -702,12 +701,11 @@ func (cca Client) AcquireTokenOnBehalfOf(ctx context.Context, userAssertion stri } // Account gets the account in the token cache with the specified homeAccountID. -func (cca Client) Account(homeAccountID string) Account { - return cca.base.Account(homeAccountID) +func (cca Client) Account(ctx context.Context, accountID string) (Account, error) { + return cca.base.Account(ctx, accountID) } // RemoveAccount signs the account out and forgets account from token cache. -func (cca Client) RemoveAccount(account Account) error { - cca.base.RemoveAccount(account) - return nil +func (cca Client) RemoveAccount(ctx context.Context, account Account) error { + return cca.base.RemoveAccount(ctx, account) } diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index a2f28622..9023bf52 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -303,7 +303,10 @@ func TestAcquireTokenByAuthCode(t *testing.T) { if tk.AccessToken != token { t.Fatalf("unexpected access token %s", tk.AccessToken) } - account := client.Account(tk.Account.HomeAccountID) + account, err := client.Account(context.Background(), tk.Account.HomeAccountID) + if err != nil { + t.Fatal(err) + } if params.utid == "" { if actual := account.HomeAccountID; actual != "123-456.123-456" { t.Fatalf("expected %q, got %q", "123-456.123-456", actual) @@ -630,16 +633,20 @@ func TestTokenProviderOptions(t *testing.T) { // testCache is a simple in-memory cache.ExportReplace implementation type testCache map[string][]byte -func (c testCache) Export(m cache.Marshaler, key string) { - if v, err := m.Marshal(); err == nil { +func (c testCache) Export(ctx context.Context, m cache.Marshaler, key string) error { + v, err := m.Marshal() + if err == nil { c[key] = v } + return err } -func (c testCache) Replace(u cache.Unmarshaler, key string) { +func (c testCache) Replace(ctx context.Context, u cache.Unmarshaler, key string) error { + var err error if v, has := c[key]; has { - _ = u.Unmarshal(v) + err = u.Unmarshal(v) } + return err } func TestWithCache(t *testing.T) { diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index ed8715ce..c9a86a0b 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -45,8 +45,12 @@ type partitionedManager interface { type noopCacheAccessor struct{} -func (n noopCacheAccessor) Replace(cache cache.Unmarshaler, key string) {} -func (n noopCacheAccessor) Export(cache cache.Marshaler, key string) {} +func (n noopCacheAccessor) Replace(ctx context.Context, cache cache.Unmarshaler, key string) error { + return nil +} +func (n noopCacheAccessor) Export(ctx context.Context, cache cache.Marshaler, key string) error { + return nil +} // AcquireTokenSilentParameters contains the parameters to acquire a token silently (from cache). type AcquireTokenSilentParameters struct { @@ -279,12 +283,12 @@ func (b Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, s return baseURL.String(), nil } -func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilentParameters) (AuthResult, error) { +func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilentParameters) (ar AuthResult, err error) { // when tenant == "", the caller didn't specify a tenant and WithTenant will use the client's configured tenant tenant := silent.TenantID authParams, err := b.AuthParams.WithTenant(tenant) if err != nil { - return AuthResult{}, err + return ar, err } authParams.Scopes = silent.Scopes authParams.HomeAccountID = silent.Account.HomeAccountID @@ -296,37 +300,48 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen if authParams.AuthorizationType == authority.ATOnBehalfOf { if s, ok := b.pmanager.(cache.Serializer); ok { suggestedCacheKey := authParams.CacheKey(silent.IsAppCache) - b.cacheAccessor.Replace(s, suggestedCacheKey) - defer b.cacheAccessor.Export(s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey) + if err != nil { + return ar, err + } + defer func() { + err = b.export(ctx, s, suggestedCacheKey, err) + }() } storageTokenResponse, err = b.pmanager.Read(ctx, authParams) if err != nil { - return AuthResult{}, err + return ar, err } } else { if s, ok := b.manager.(cache.Serializer); ok { suggestedCacheKey := authParams.CacheKey(silent.IsAppCache) - b.cacheAccessor.Replace(s, suggestedCacheKey) - defer b.cacheAccessor.Export(s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey) + if err != nil { + return ar, err + } + defer func() { + err = b.export(ctx, s, suggestedCacheKey, err) + }() } authParams.AuthorizationType = authority.ATRefreshToken storageTokenResponse, err = b.manager.Read(ctx, authParams, silent.Account) if err != nil { - return AuthResult{}, err + return ar, err } } // ignore cached access tokens when given claims if silent.Claims == "" { - result, err := AuthResultFromStorage(storageTokenResponse) + ar, err = AuthResultFromStorage(storageTokenResponse) if err == nil { - return result, nil + return ar, err } } // redeem a cached refresh token, if available if reflect.ValueOf(storageTokenResponse.RefreshToken).IsZero() { - return AuthResult{}, errors.New("no token found") + err = errors.New("no token found") + return ar, err } var cc *accesstokens.Credential if silent.RequestType == accesstokens.ATConfidential { @@ -335,10 +350,11 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen token, err := b.Token.Refresh(ctx, silent.RequestType, authParams, cc, storageTokenResponse.RefreshToken) if err != nil { - return AuthResult{}, err + return ar, err } - return b.AuthResultFromToken(ctx, authParams, token, true) + ar, err = b.AuthResultFromToken(ctx, authParams, token, true) + return ar, err } func (b Client) AcquireTokenByAuthCode(ctx context.Context, authCodeParams AcquireTokenAuthCodeParameters) (AuthResult, error) { @@ -401,67 +417,103 @@ func (b Client) AcquireTokenOnBehalfOf(ctx context.Context, onBehalfOfParams Acq return ar, err } -func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse, cacheWrite bool) (AuthResult, error) { +func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse, cacheWrite bool) (ar AuthResult, err error) { if !cacheWrite { return NewAuthResult(token, shared.Account{}) } var account shared.Account - var err error if authParams.AuthorizationType == authority.ATOnBehalfOf { if s, ok := b.pmanager.(cache.Serializer); ok { suggestedCacheKey := token.CacheKey(authParams) - b.cacheAccessor.Replace(s, suggestedCacheKey) - defer b.cacheAccessor.Export(s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey) + if err != nil { + return ar, err + } + defer func() { + err = b.export(ctx, s, suggestedCacheKey, err) + }() } account, err = b.pmanager.Write(authParams, token) if err != nil { - return AuthResult{}, err + return ar, err } } else { if s, ok := b.manager.(cache.Serializer); ok { suggestedCacheKey := token.CacheKey(authParams) - b.cacheAccessor.Replace(s, suggestedCacheKey) - defer b.cacheAccessor.Export(s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey) + if err != nil { + return ar, err + } + defer func() { + err = b.export(ctx, s, suggestedCacheKey, err) + }() } account, err = b.manager.Write(authParams, token) if err != nil { - return AuthResult{}, err + return ar, err } } - return NewAuthResult(token, account) + ar, err = NewAuthResult(token, account) + return ar, err } -func (b Client) AllAccounts() []shared.Account { +func (b Client) AllAccounts(ctx context.Context) (accts []shared.Account, err error) { if s, ok := b.manager.(cache.Serializer); ok { suggestedCacheKey := b.AuthParams.CacheKey(false) - b.cacheAccessor.Replace(s, suggestedCacheKey) - defer b.cacheAccessor.Export(s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey) + if err != nil { + return accts, err + } + defer func() { + err = b.export(ctx, s, suggestedCacheKey, err) + }() } - accounts := b.manager.AllAccounts() - return accounts + accts = b.manager.AllAccounts() + return accts, err } -func (b Client) Account(homeAccountID string) shared.Account { +func (b Client) Account(ctx context.Context, homeAccountID string) (acct shared.Account, err error) { authParams := b.AuthParams // This is a copy, as we dont' have a pointer receiver and .AuthParams is not a pointer. authParams.AuthorizationType = authority.AccountByID authParams.HomeAccountID = homeAccountID if s, ok := b.manager.(cache.Serializer); ok { suggestedCacheKey := b.AuthParams.CacheKey(false) - b.cacheAccessor.Replace(s, suggestedCacheKey) - defer b.cacheAccessor.Export(s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey) + if err != nil { + return acct, err + } + defer func() { + err = b.export(ctx, s, suggestedCacheKey, err) + }() } - account := b.manager.Account(homeAccountID) - return account + acct = b.manager.Account(homeAccountID) + return acct, err } // RemoveAccount removes all the ATs, RTs and IDTs from the cache associated with this account. -func (b Client) RemoveAccount(account shared.Account) { +func (b Client) RemoveAccount(ctx context.Context, account shared.Account) (err error) { if s, ok := b.manager.(cache.Serializer); ok { suggestedCacheKey := b.AuthParams.CacheKey(false) - b.cacheAccessor.Replace(s, suggestedCacheKey) - defer b.cacheAccessor.Export(s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey) + if err != nil { + return err + } + defer func() { + err = b.export(ctx, s, suggestedCacheKey, err) + }() } b.manager.RemoveAccount(account, b.AuthParams.ClientID) + return err +} + +// export helps other methods defer exporting the cache after possibly updating its in-memory content. +// err is the error the calling method will return. If err isn't nil, export returns it without +// exporting the cache. +func (b Client) export(ctx context.Context, marshal cache.Marshaler, key string, err error) error { + if err != nil { + return err + } + return b.cacheAccessor.Export(ctx, marshal, key) } diff --git a/apps/internal/base/base_test.go b/apps/internal/base/base_test.go index 33c5c792..f27f2ae9 100644 --- a/apps/internal/base/base_test.go +++ b/apps/internal/base/base_test.go @@ -5,11 +5,13 @@ package base import ( "context" + "errors" "fmt" "reflect" "testing" "time" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/internal/storage" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth" @@ -40,8 +42,8 @@ var ( testScopes = []string{"scope"} ) -func fakeClient(t *testing.T) Client { - client, err := New(fakeClientID, fmt.Sprintf("https://%s/%s", fakeAuthority, fakeTenantID), &oauth.Client{}) +func fakeClient(t *testing.T, opts ...Option) Client { + client, err := New(fakeClientID, fmt.Sprintf("https://%s/%s", fakeAuthority, fakeTenantID), &oauth.Client{}, opts...) if err != nil { t.Fatal(err) } @@ -192,6 +194,86 @@ func TestAcquireTokenSilentGrantedScopes(t *testing.T) { } } +// failCache helps tests inject cache I/O errors +type failCache struct { + exported bool + exportErr, replaceErr error +} + +func (c *failCache) Export(context.Context, cache.Marshaler, string) error { + c.exported = true + return c.exportErr +} + +func (c failCache) Replace(context.Context, cache.Unmarshaler, string) error { + return c.replaceErr +} + +func TestCacheIOErrors(t *testing.T) { + ctx := context.Background() + expected := errors.New("cache error") + for _, export := range []bool{true, false} { + name := "replace" + cache := failCache{} + if export { + cache.exportErr = expected + name = "export" + } else { + cache.replaceErr = expected + } + t.Run(name, func(t *testing.T) { + client := fakeClient(t, WithCacheAccessor(&cache)) + _, actual := client.Account(ctx, "...") + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + _, actual = client.AcquireTokenByAuthCode(ctx, AcquireTokenAuthCodeParameters{AppType: accesstokens.ATConfidential}) + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + _, actual = client.AcquireTokenOnBehalfOf(ctx, AcquireTokenOnBehalfOfParameters{Credential: &accesstokens.Credential{Secret: "..."}}) + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + _, actual = client.AcquireTokenSilent(ctx, AcquireTokenSilentParameters{}) + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + _, actual = client.AllAccounts(ctx) + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + _, actual = client.AuthResultFromToken(ctx, authority.AuthParams{}, accesstokens.TokenResponse{}, true) + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + actual = client.RemoveAccount(ctx, shared.Account{}) + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + }) + } + + // when the client fails to acquire a token, it should return an error instead of exporting the cache + t.Run("auth error", func(t *testing.T) { + cache := failCache{} + client := fakeClient(t, WithCacheAccessor(&cache)) + client.Token.AccessTokens.(*fake.AccessTokens).Err = true + _, err := client.AcquireTokenByAuthCode(ctx, AcquireTokenAuthCodeParameters{AppType: accesstokens.ATConfidential}) + if err == nil || cache.exported { + t.Fatal("client should have returned an error instead of exporting the cache") + } + _, err = client.AcquireTokenOnBehalfOf(ctx, AcquireTokenOnBehalfOfParameters{Credential: &accesstokens.Credential{Secret: "..."}}) + if err == nil || cache.exported { + t.Fatal("client should have returned an error instead of exporting the cache") + } + _, err = client.AcquireTokenSilent(ctx, AcquireTokenSilentParameters{}) + if err == nil || cache.exported { + t.Fatal("client should have returned an error instead of exporting the cache") + } + }) +} + func TestCreateAuthenticationResult(t *testing.T) { future := time.Now().Add(400 * time.Second) diff --git a/apps/public/public.go b/apps/public/public.go index 348bc5fc..cce05277 100644 --- a/apps/public/public.go +++ b/apps/public/public.go @@ -461,14 +461,13 @@ func (pca Client) AcquireTokenByAuthCode(ctx context.Context, code string, redir // Accounts gets all the accounts in the token cache. // If there are no accounts in the cache the returned slice is empty. -func (pca Client) Accounts() []Account { - return pca.base.AllAccounts() +func (pca Client) Accounts(ctx context.Context) ([]Account, error) { + return pca.base.AllAccounts(ctx) } // RemoveAccount signs the account out and forgets account from token cache. -func (pca Client) RemoveAccount(account Account) error { - pca.base.RemoveAccount(account) - return nil +func (pca Client) RemoveAccount(ctx context.Context, account Account) error { + return pca.base.RemoveAccount(ctx, account) } // interactiveAuthOptions contains the optional parameters used to acquire an access token for interactive auth code flow. diff --git a/apps/public/public_test.go b/apps/public/public_test.go index e7faf680..71abbcac 100644 --- a/apps/public/public_test.go +++ b/apps/public/public_test.go @@ -146,8 +146,12 @@ func TestAcquireTokenSilentWithTenantID(t *testing.T) { // cache should return the correct access token for each tenant var account Account - if accounts := client.Accounts(); len(accounts) == 2 { - // expecting one account for each tenant we authenticated in above + accounts, err := client.Accounts(ctx) + if err != nil { + t.Fatal(err) + } + // expecting one account for each tenant we authenticated in above + if len(accounts) == 2 { account = accounts[0] } else { t.Fatalf("expected 2 accounts but got %d", len(accounts)) @@ -363,16 +367,19 @@ type testCache struct { store map[string][]byte } -func (c *testCache) Export(m cache.Marshaler, key string) { - if v, err := m.Marshal(); err == nil { +func (c *testCache) Export(ctx context.Context, m cache.Marshaler, key string) error { + v, err := m.Marshal() + if err == nil { c.store[key] = v } + return err } -func (c *testCache) Replace(u cache.Unmarshaler, key string) { +func (c *testCache) Replace(ctx context.Context, u cache.Unmarshaler, key string) error { if v, has := c.store[key]; has { - _ = u.Unmarshal(v) + return u.Unmarshal(v) } + return nil } func TestWithCache(t *testing.T) { @@ -408,7 +415,10 @@ func TestWithCache(t *testing.T) { if err != nil { t.Fatal(err) } - accounts := client.Accounts() + accounts, err := client.Accounts(context.Background()) + if err != nil { + t.Fatal(err) + } if actual := len(accounts); actual != 1 { t.Fatalf("expected 1 account but cache contains %d", actual) } diff --git a/apps/tests/devapps/confidential_auth_code_sample.go b/apps/tests/devapps/confidential_auth_code_sample.go index 924a355a..836f6e7a 100644 --- a/apps/tests/devapps/confidential_auth_code_sample.go +++ b/apps/tests/devapps/confidential_auth_code_sample.go @@ -58,7 +58,7 @@ func getTokenConfidential(w http.ResponseWriter, r *http.Request) { // TODO(msal): Needs to use an x509 certificate like the other now that we are not using a // thumbprint directly. /* -func acquireByAuthorizationCodeConfidential() { +func acquireByAuthorizationCodeConfidential(ctx context.Context) { key, err := os.ReadFile(confidentialConfig.KeyFile) if err != nil { log.Fatal(err) @@ -77,7 +77,7 @@ func acquireByAuthorizationCodeConfidential() { log.Fatal(err) } var userAccount shared.Account - for _, account := range app.Accounts() { + for _, account := range app.Accounts(ctx) { if account.PreferredUsername == confidentialConfig.Username { userAccount = account } diff --git a/apps/tests/devapps/device_code_flow_sample.go b/apps/tests/devapps/device_code_flow_sample.go index bd4d694f..192189cc 100644 --- a/apps/tests/devapps/device_code_flow_sample.go +++ b/apps/tests/devapps/device_code_flow_sample.go @@ -21,7 +21,10 @@ func acquireTokenDeviceCode() { // look in the cache to see if the account to use has been cached var userAccount public.Account - accounts := app.Accounts() + accounts, err := app.Accounts(context.Background()) + if err != nil { + panic("failed to read the cache") + } for _, account := range accounts { if account.PreferredUsername == config.Username { userAccount = account diff --git a/apps/tests/devapps/main.go b/apps/tests/devapps/main.go index ff44dcba..6c1ad04e 100644 --- a/apps/tests/devapps/main.go +++ b/apps/tests/devapps/main.go @@ -1,15 +1,18 @@ package main import ( + "context" "os" ) var ( //config = CreateConfig("config.json") - cacheAccessor = &TokenCache{"serialized_cache.json"} + cacheAccessor = &TokenCache{file: "serialized_cache.json"} ) func main() { + ctx := context.Background() + // TODO(msal): This is pretty yikes. At least we should use the flag package. exampleType := os.Args[1] if exampleType == "1" { @@ -18,7 +21,7 @@ func main() { acquireByAuthorizationCodePublic() */ } else if exampleType == "3" { - acquireByUsernamePasswordPublic() + acquireByUsernamePasswordPublic(ctx) } else if exampleType == "4" { panic("currently not implemented") //acquireByAuthorizationCodeConfidential() diff --git a/apps/tests/devapps/sample_cache_accessor.go b/apps/tests/devapps/sample_cache_accessor.go index 07114a0c..bf4dca91 100644 --- a/apps/tests/devapps/sample_cache_accessor.go +++ b/apps/tests/devapps/sample_cache_accessor.go @@ -4,6 +4,7 @@ package main import ( + "context" "log" "os" @@ -14,24 +15,18 @@ type TokenCache struct { file string } -func (t *TokenCache) Replace(cache cache.Unmarshaler, key string) { +func (t *TokenCache) Replace(ctx context.Context, cache cache.Unmarshaler, key string) error { data, err := os.ReadFile(t.file) if err != nil { log.Println(err) } - err = cache.Unmarshal(data) - if err != nil { - log.Println(err) - } + return cache.Unmarshal(data) } -func (t *TokenCache) Export(cache cache.Marshaler, key string) { +func (t *TokenCache) Export(ctx context.Context, cache cache.Marshaler, key string) error { data, err := cache.Marshal() if err != nil { log.Println(err) } - err = os.WriteFile(t.file, data, 0600) - if err != nil { - log.Println(err) - } + return os.WriteFile(t.file, data, 0600) } diff --git a/apps/tests/devapps/username_password_sample.go b/apps/tests/devapps/username_password_sample.go index 4eb5c484..c5a5b709 100644 --- a/apps/tests/devapps/username_password_sample.go +++ b/apps/tests/devapps/username_password_sample.go @@ -11,7 +11,7 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" ) -func acquireByUsernamePasswordPublic() { +func acquireByUsernamePasswordPublic(ctx context.Context) { config := CreateConfig("config.json") app, err := public.New(config.ClientID, public.WithCache(cacheAccessor), public.WithAuthority(config.Authority)) if err != nil { @@ -20,7 +20,10 @@ func acquireByUsernamePasswordPublic() { // look in the cache to see if the account to use has been cached var userAccount public.Account - accounts := app.Accounts() + accounts, err := app.Accounts(ctx) + if err != nil { + panic("failed to read the cache") + } for _, account := range accounts { if account.PreferredUsername == config.Username { userAccount = account diff --git a/apps/tests/integration/integration_test.go b/apps/tests/integration/integration_test.go index 41b4b780..94006386 100644 --- a/apps/tests/integration/integration_test.go +++ b/apps/tests/integration/integration_test.go @@ -390,12 +390,15 @@ func TestRemoveAccount(t *testing.T) { if err != nil { t.Fatalf("TestRemoveAccount: on AcquireTokenByUsernamePassword(): got err == %s, want err == nil", errors.Verbose(err)) } - accounts := app.Accounts() + accounts, err := app.Accounts(ctx) + if err != nil { + t.Fatal(err) + } if len(accounts) == 0 { t.Fatal("TestRemoveAccount: No user accounts found in cache") } testAccount := accounts[0] // Only one account is populated and that is what we will remove. - err = app.RemoveAccount(testAccount) + err = app.RemoveAccount(ctx, testAccount) if err != nil { t.Fatalf("TestRemoveAccount: on RemoveAccount(): got err == %s, want err == nil", errors.Verbose(err)) } From bc40c6db09eab7a1fc1af58469021177ad25bd5a Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Wed, 1 Mar 2023 14:59:56 -0800 Subject: [PATCH 13/14] Add optional metadata for cache.ExportReplace (#395) --- apps/cache/cache.go | 28 +++++++++++++-------- apps/confidential/confidential_test.go | 18 ++++++------- apps/internal/base/base.go | 20 +++++++-------- apps/internal/base/base_test.go | 4 +-- apps/public/public_test.go | 14 +++++------ apps/tests/devapps/sample_cache_accessor.go | 4 +-- 6 files changed, 46 insertions(+), 42 deletions(-) diff --git a/apps/cache/cache.go b/apps/cache/cache.go index 84e1383c..19210883 100644 --- a/apps/cache/cache.go +++ b/apps/cache/cache.go @@ -29,18 +29,26 @@ type Serializer interface { Unmarshaler } +// ExportHints are suggestions for storing data. +type ExportHints struct { + // PartitionKey is a suggested key for partitioning the cache + PartitionKey string +} + +// ReplaceHints are suggestions for loading data. +type ReplaceHints struct { + // PartitionKey is a suggested key for partitioning the cache + PartitionKey string +} + // ExportReplace exports and replaces in-memory cache data. It doesn't support nil Context or // define the outcome of passing one. A Context without a timeout must receive a default timeout // specified by the implementor. Retries must be implemented inside the implementation. type ExportReplace interface { - // Replace replaces the cache with what is in external storage. - // key is the suggested key which can be used for partitioning the cache. - // Implementors should honor Context cancellations and return a context.Canceled or - // context.DeadlineExceeded in those cases. - Replace(ctx context.Context, cache Unmarshaler, key string) error - // Export writes the binary representation of the cache (cache.Marshal()) to - // external storage. This is considered opaque. - // key is the suggested key which can be used for partitioning the cache. - // Context cancellations should be honored as in Replace. - Export(ctx context.Context, cache Marshaler, key string) error + // Replace replaces the cache with what is in external storage. Implementors should honor + // Context cancellations and return context.Canceled or context.DeadlineExceeded in those cases. + Replace(ctx context.Context, cache Unmarshaler, hints ReplaceHints) error + // Export writes the binary representation of the cache (cache.Marshal()) to external storage. + // This is considered opaque. Context cancellations should be honored as in Replace. + Export(ctx context.Context, cache Marshaler, hints ExportHints) error } diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 9023bf52..6870ae95 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -633,20 +633,18 @@ func TestTokenProviderOptions(t *testing.T) { // testCache is a simple in-memory cache.ExportReplace implementation type testCache map[string][]byte -func (c testCache) Export(ctx context.Context, m cache.Marshaler, key string) error { - v, err := m.Marshal() - if err == nil { - c[key] = v +func (c testCache) Export(ctx context.Context, m cache.Marshaler, h cache.ExportHints) error { + if v, err := m.Marshal(); err == nil { + c[h.PartitionKey] = v } - return err + return nil } -func (c testCache) Replace(ctx context.Context, u cache.Unmarshaler, key string) error { - var err error - if v, has := c[key]; has { - err = u.Unmarshal(v) +func (c testCache) Replace(ctx context.Context, u cache.Unmarshaler, h cache.ReplaceHints) error { + if v, has := c[h.PartitionKey]; has { + _ = u.Unmarshal(v) } - return err + return nil } func TestWithCache(t *testing.T) { diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index c9a86a0b..00617abf 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -45,10 +45,10 @@ type partitionedManager interface { type noopCacheAccessor struct{} -func (n noopCacheAccessor) Replace(ctx context.Context, cache cache.Unmarshaler, key string) error { +func (n noopCacheAccessor) Replace(ctx context.Context, u cache.Unmarshaler, h cache.ReplaceHints) error { return nil } -func (n noopCacheAccessor) Export(ctx context.Context, cache cache.Marshaler, key string) error { +func (n noopCacheAccessor) Export(ctx context.Context, m cache.Marshaler, h cache.ExportHints) error { return nil } @@ -300,7 +300,7 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen if authParams.AuthorizationType == authority.ATOnBehalfOf { if s, ok := b.pmanager.(cache.Serializer); ok { suggestedCacheKey := authParams.CacheKey(silent.IsAppCache) - err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) if err != nil { return ar, err } @@ -315,7 +315,7 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen } else { if s, ok := b.manager.(cache.Serializer); ok { suggestedCacheKey := authParams.CacheKey(silent.IsAppCache) - err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) if err != nil { return ar, err } @@ -426,7 +426,7 @@ func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.Au if authParams.AuthorizationType == authority.ATOnBehalfOf { if s, ok := b.pmanager.(cache.Serializer); ok { suggestedCacheKey := token.CacheKey(authParams) - err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) if err != nil { return ar, err } @@ -441,7 +441,7 @@ func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.Au } else { if s, ok := b.manager.(cache.Serializer); ok { suggestedCacheKey := token.CacheKey(authParams) - err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) if err != nil { return ar, err } @@ -461,7 +461,7 @@ func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.Au func (b Client) AllAccounts(ctx context.Context) (accts []shared.Account, err error) { if s, ok := b.manager.(cache.Serializer); ok { suggestedCacheKey := b.AuthParams.CacheKey(false) - err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) if err != nil { return accts, err } @@ -480,7 +480,7 @@ func (b Client) Account(ctx context.Context, homeAccountID string) (acct shared. authParams.HomeAccountID = homeAccountID if s, ok := b.manager.(cache.Serializer); ok { suggestedCacheKey := b.AuthParams.CacheKey(false) - err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) if err != nil { return acct, err } @@ -496,7 +496,7 @@ func (b Client) Account(ctx context.Context, homeAccountID string) (acct shared. func (b Client) RemoveAccount(ctx context.Context, account shared.Account) (err error) { if s, ok := b.manager.(cache.Serializer); ok { suggestedCacheKey := b.AuthParams.CacheKey(false) - err = b.cacheAccessor.Replace(ctx, s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) if err != nil { return err } @@ -515,5 +515,5 @@ func (b Client) export(ctx context.Context, marshal cache.Marshaler, key string, if err != nil { return err } - return b.cacheAccessor.Export(ctx, marshal, key) + return b.cacheAccessor.Export(ctx, marshal, cache.ExportHints{PartitionKey: key}) } diff --git a/apps/internal/base/base_test.go b/apps/internal/base/base_test.go index f27f2ae9..7aac8102 100644 --- a/apps/internal/base/base_test.go +++ b/apps/internal/base/base_test.go @@ -200,12 +200,12 @@ type failCache struct { exportErr, replaceErr error } -func (c *failCache) Export(context.Context, cache.Marshaler, string) error { +func (c *failCache) Export(context.Context, cache.Marshaler, cache.ExportHints) error { c.exported = true return c.exportErr } -func (c failCache) Replace(context.Context, cache.Unmarshaler, string) error { +func (c failCache) Replace(context.Context, cache.Unmarshaler, cache.ReplaceHints) error { return c.replaceErr } diff --git a/apps/public/public_test.go b/apps/public/public_test.go index 71abbcac..8f245900 100644 --- a/apps/public/public_test.go +++ b/apps/public/public_test.go @@ -363,27 +363,25 @@ func TestWithInstanceDiscovery(t *testing.T) { } // testCache is a simple in-memory cache.ExportReplace implementation -type testCache struct { - store map[string][]byte -} +type testCache map[string][]byte -func (c *testCache) Export(ctx context.Context, m cache.Marshaler, key string) error { +func (c testCache) Export(ctx context.Context, m cache.Marshaler, h cache.ExportHints) error { v, err := m.Marshal() if err == nil { - c.store[key] = v + c[h.PartitionKey] = v } return err } -func (c *testCache) Replace(ctx context.Context, u cache.Unmarshaler, key string) error { - if v, has := c.store[key]; has { +func (c testCache) Replace(ctx context.Context, u cache.Unmarshaler, h cache.ReplaceHints) error { + if v, has := c[h.PartitionKey]; has { return u.Unmarshal(v) } return nil } func TestWithCache(t *testing.T) { - cache := testCache{make(map[string][]byte)} + cache := make(testCache) accessToken, refreshToken := "*", "rt" clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`)) lmo := "login.microsoftonline.com" diff --git a/apps/tests/devapps/sample_cache_accessor.go b/apps/tests/devapps/sample_cache_accessor.go index bf4dca91..8c52a6c3 100644 --- a/apps/tests/devapps/sample_cache_accessor.go +++ b/apps/tests/devapps/sample_cache_accessor.go @@ -15,7 +15,7 @@ type TokenCache struct { file string } -func (t *TokenCache) Replace(ctx context.Context, cache cache.Unmarshaler, key string) error { +func (t *TokenCache) Replace(ctx context.Context, cache cache.Unmarshaler, hints cache.ReplaceHints) error { data, err := os.ReadFile(t.file) if err != nil { log.Println(err) @@ -23,7 +23,7 @@ func (t *TokenCache) Replace(ctx context.Context, cache cache.Unmarshaler, key s return cache.Unmarshal(data) } -func (t *TokenCache) Export(ctx context.Context, cache cache.Marshaler, key string) error { +func (t *TokenCache) Export(ctx context.Context, cache cache.Marshaler, hints cache.ExportHints) error { data, err := cache.Marshal() if err != nil { log.Println(err) From 9a63cc219c3822f23f3a54c62c1261b729a50536 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 2 Mar 2023 10:18:43 -0800 Subject: [PATCH 14/14] Confidential client requires authority (#394) --- apps/confidential/confidential.go | 46 ++---- apps/confidential/confidential_test.go | 134 ++++++++++++++---- .../internal/oauth/ops/authority/authority.go | 25 ++-- .../devapps/client_certificate_sample.go | 2 +- apps/tests/devapps/client_secret_sample.go | 2 +- apps/tests/integration/integration_test.go | 6 +- 6 files changed, 129 insertions(+), 86 deletions(-) diff --git a/apps/confidential/confidential.go b/apps/confidential/confidential.go index 08831dbf..6612feb4 100644 --- a/apps/confidential/confidential.go +++ b/apps/confidential/confidential.go @@ -18,7 +18,6 @@ import ( "encoding/pem" "errors" "fmt" - "net/url" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" @@ -249,27 +248,9 @@ type clientOptions struct { httpClient ops.HTTPClient } -func (o clientOptions) validate() error { - u, err := url.Parse(o.authority) - if err != nil { - return fmt.Errorf("the Authority(%s) does not parse as a valid URL", o.authority) - } - if u.Scheme != "https" { - return fmt.Errorf("the Authority(%s) does not appear to use https", o.authority) - } - return nil -} - // Option is an optional argument to New(). type Option func(o *clientOptions) -// WithAuthority allows you to provide a custom authority for use in the client. -func WithAuthority(authority string) Option { - return func(o *clientOptions) { - o.authority = authority - } -} - // WithCache provides an accessor that will read and write authentication data to an externally managed cache. func WithCache(accessor cache.ExportReplace) Option { return func(o *clientOptions) { @@ -324,37 +305,30 @@ func WithAzureRegion(val string) Option { } } -// New is the constructor for Client. userID is the unique identifier of the user this client -// will store credentials for (a Client is per user). clientID is the Azure clientID and cred is -// the type of credential to use. -func New(clientID string, cred Credential, options ...Option) (Client, error) { +// New is the constructor for Client. authority is the URL of a token authority such as "https://login.microsoftonline.com/". +// If the Client will connect directly to AD FS, use "adfs" for the tenant. clientID is the application's client ID (also called its +// "application ID"). +func New(authority, clientID string, cred Credential, options ...Option) (Client, error) { internalCred, err := cred.toInternal() if err != nil { return Client{}, err } opts := clientOptions{ - authority: base.AuthorityPublicCloud, - httpClient: shared.DefaultClient, + authority: authority, + // if the caller specified a token provider, it will handle all details of authentication, using Client only as a token cache + disableInstanceDiscovery: cred.tokenProvider != nil, + httpClient: shared.DefaultClient, } - for _, o := range options { o(&opts) } - if err := opts.validate(); err != nil { - return Client{}, err - } - baseOpts := []base.Option{ base.WithCacheAccessor(opts.accessor), base.WithClientCapabilities(opts.capabilities), + base.WithInstanceDiscovery(!opts.disableInstanceDiscovery), base.WithRegionDetection(opts.azureRegion), base.WithX5C(opts.sendX5C), - base.WithInstanceDiscovery(!opts.disableInstanceDiscovery), - } - if cred.tokenProvider != nil { - // The caller will handle all details of authentication, using Client only as a token cache. - baseOpts = append(baseOpts, base.WithInstanceDiscovery(false)) } base, err := base.New(clientID, opts.authority, oauth.New(opts.httpClient), baseOpts...) if err != nil { @@ -480,7 +454,7 @@ func WithClaims(claims string) interface { } } -// WithTenantID specifies a tenant for a single authentication. It may be different than the tenant set in [New] by [WithAuthority]. +// WithTenantID specifies a tenant for a single authentication. It may be different than the tenant set in [New]. // This option is valid for any token acquisition method. func WithTenantID(tenantID string) interface { AcquireByAuthCodeOption diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 6870ae95..1d529269 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -65,7 +65,9 @@ func TestCertFromPEM(t *testing.T) { const ( authorityFmt = "https://%s/%s" + fakeAuthority = "https://fake_authority/fake" fakeClientID = "fake_client_id" + fakeSecret = "fake_secret" fakeTokenEndpoint = "https://fake_authority/fake/token" localhost = "http://localhost" refresh = "fake_refresh" @@ -75,8 +77,7 @@ const ( var tokenScope = []string{"the_scope"} func fakeClient(tk accesstokens.TokenResponse, credential Credential, options ...Option) (Client, error) { - options = append(options, WithAuthority("https://fake_authority/fake")) - client, err := New("fake_client_id", credential, options...) + client, err := New(fakeAuthority, fakeClientID, credential, options...) if err != nil { return Client{}, err } @@ -164,19 +165,20 @@ func TestAcquireTokenByCredential(t *testing.T) { func TestAcquireTokenOnBehalfOf(t *testing.T) { // this test is an offline version of TestOnBehalfOf in integration_test.go - cred, err := NewCredFromSecret("secret") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } lmo := "login.microsoftonline.com" + tenant := "tenant" assertion := "assertion" mockClient := mock.Client{} // TODO: OBO does instance discovery twice before first token request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351 - mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, "common"))) - mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "common"))) + mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant))) + mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant))) mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(token, "", "rt", "", 3600))) - client, err := New("clientID", cred, WithHTTPClient(&mockClient)) + client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } @@ -248,7 +250,7 @@ func TestAcquireTokenByAssertionCallback(t *testing.T) { } func TestAcquireTokenByAuthCode(t *testing.T) { - cred, err := NewCredFromSecret("fake_secret") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } @@ -332,7 +334,7 @@ func TestAcquireTokenByAuthCode(t *testing.T) { } func TestAcquireTokenSilentTenants(t *testing.T) { - cred, err := NewCredFromSecret("secret") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } @@ -340,7 +342,7 @@ func TestAcquireTokenSilentTenants(t *testing.T) { lmo := "login.microsoftonline.com" mockClient := mock.Client{} mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenants[0]))) - client, err := New(fakeClientID, cred, WithHTTPClient(&mockClient)) + client, err := New(fmt.Sprintf(authorityFmt, lmo, tenants[0]), fakeClientID, cred, WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } @@ -368,13 +370,28 @@ func TestAcquireTokenSilentTenants(t *testing.T) { } } +func TestAuthorityValidation(t *testing.T) { + cred, err := NewCredFromSecret(fakeSecret) + if err != nil { + t.Fatal(err) + } + for _, a := range []string{"", "https://login.microsoftonline.com", "http://login.microsoftonline.com/tenant"} { + t.Run(a, func(t *testing.T) { + _, err := New(a, fakeClientID, cred) + if err == nil || !strings.Contains(err.Error(), "authority") { + t.Fatalf("expected an error about the invalid authority, got %v", err) + } + }) + } +} + func TestInvalidCredential(t *testing.T) { for _, cred := range []Credential{ {}, NewCredFromAssertionCallback(nil), } { t.Run("", func(t *testing.T) { - _, err := New(fakeClientID, cred) + _, err := New(fakeAuthority, fakeClientID, cred) if err == nil { t.Fatal("expected an error") } @@ -565,7 +582,7 @@ func TestNewCredFromTokenProvider(t *testing.T) { ExpiresInSeconds: expiresIn, }, nil }) - client, err := New(fakeClientID, cred, WithHTTPClient(&errorClient{})) + client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{})) if err != nil { t.Fatal(err) } @@ -596,7 +613,7 @@ func TestNewCredFromTokenProviderError(t *testing.T) { cred := NewCredFromTokenProvider(func(ctx context.Context, tpp exported.TokenProviderParameters) (exported.TokenProviderResult, error) { return exported.TokenProviderResult{}, errors.New(expectedError) }) - client, err := New(fakeClientID, cred) + client, err := New(fakeAuthority, fakeClientID, cred) if err != nil { t.Fatal(err) } @@ -617,7 +634,7 @@ func TestTokenProviderOptions(t *testing.T) { } return TokenProviderResult{AccessToken: accessToken, ExpiresInSeconds: 3600}, nil }) - client, err := New(fakeClientID, cred, WithHTTPClient(&errorClient{})) + client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{})) if err != nil { t.Fatal(err) } @@ -657,11 +674,11 @@ func TestWithCache(t *testing.T) { mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenantA))) mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenantA, authorityA), "", "", 3600))) - cred, err := NewCredFromSecret("...") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } - client, err := New(fakeClientID, cred, WithAuthority(authorityA), WithCache(&cache), WithHTTPClient(&mockClient)) + client, err := New(authorityA, fakeClientID, cred, WithCache(&cache), WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } @@ -679,7 +696,7 @@ func TestWithCache(t *testing.T) { } // a client configured for a different tenant should be able to authenticate silently with the shared cache's data - client, err = New(fakeClientID, cred, WithAuthority(authorityB), WithCache(&cache), WithHTTPClient(&mockClient)) + client, err = New(authorityB, fakeClientID, cred, WithCache(&cache), WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } @@ -700,13 +717,13 @@ func TestWithCache(t *testing.T) { } func TestWithClaims(t *testing.T) { - cred, err := NewCredFromSecret("secret") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } accessToken := "at" lmo, tenant := "login.microsoftonline.com", "tenant" - authority := fmt.Sprintf("https://%s/%s", lmo, tenant) + authority := fmt.Sprintf(authorityFmt, lmo, tenant) for _, test := range []struct { capabilities []string claims, expected string @@ -774,7 +791,7 @@ func TestWithClaims(t *testing.T) { validate(t, r.Form) }), ) - client, err := New(fakeClientID, cred, WithAuthority(authority), WithClientCapabilities(test.capabilities), WithHTTPClient(&mockClient)) + client, err := New(authority, fakeClientID, cred, WithClientCapabilities(test.capabilities), WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } @@ -873,7 +890,7 @@ func TestWithTenantID(t *testing.T) { } { for _, method := range []string{"authcode", "authcodeURL", "credential", "obo"} { t.Run(method, func(t *testing.T) { - cred, err := NewCredFromSecret("secret") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } @@ -891,7 +908,7 @@ func TestWithTenantID(t *testing.T) { mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)), mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }), ) - client, err := New(fakeClientID, cred, WithAuthority(test.authority), WithHTTPClient(&mockClient)) + client, err := New(test.authority, fakeClientID, cred, WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } @@ -955,6 +972,65 @@ func TestWithTenantID(t *testing.T) { }) } } + + // if every auth call specifies a different tenant, Client shouldn't send requests to its configured authority + t.Run("enables fake authority", func(t *testing.T) { + host := "host" + defaultTenant := "default" + cred, err := NewCredFromSecret(fakeSecret) + if err != nil { + t.Fatal(err) + } + URL := "" + mockClient := mock.Client{} + client, err := New(fmt.Sprintf(authorityFmt, host, defaultTenant), fakeClientID, cred, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + checkForWrongTenant := func(r *http.Request) { + if u := r.URL.String(); strings.Contains(u, defaultTenant) { + t.Fatalf("unexpected request to the default authority: %q", u) + } + } + ctx := context.Background() + for i := 0; i < 3; i++ { + tenant := fmt.Sprint(i) + expected := fmt.Sprintf(authorityFmt, host, tenant) + // TODO: prevent redundant discovery requests https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351 + mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant)), mock.WithCallback(checkForWrongTenant)) + mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(host, tenant)), mock.WithCallback(checkForWrongTenant)) + mockClient.AppendResponse( + mock.WithBody(mock.GetAccessTokenBody(accessToken, "", "", "", 3600)), + mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }), + ) + if i == 0 { + // TODO: see above (first silent auth rediscovers instance metadata) + mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant)), mock.WithCallback(checkForWrongTenant)) + } + ar, err := client.AcquireTokenByAuthCode(ctx, "auth code", localhost, tokenScope, WithTenantID(tenant)) + if err != nil { + t.Fatal(err) + } + if !strings.HasPrefix(URL, expected) { + t.Fatalf(`expected "%s", got "%s"`, expected, URL) + } + if ar.AccessToken != accessToken { + t.Fatalf("unexpected access token %q", ar.AccessToken) + } + // silent authentication should now succeed for the given tenant... + if ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant)); err != nil { + t.Fatal(err) + } + if ar.AccessToken != accessToken { + t.Fatal("cached access token should match the one returned by AcquireToken...") + } + // ...but fail for another tenant + otherTenant := "not-" + tenant + if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(otherTenant)); err == nil { + t.Fatal("expected an error") + } + } + }) } func TestWithInstanceDiscovery(t *testing.T) { @@ -968,7 +1044,7 @@ func TestWithInstanceDiscovery(t *testing.T) { for _, method := range []string{"authcode", "credential", "obo"} { t.Run(method, func(t *testing.T) { authority := stackurl + tenant - cred, err := NewCredFromSecret("secret") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } @@ -982,7 +1058,7 @@ func TestWithInstanceDiscovery(t *testing.T) { mockClient.AppendResponse( mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)), ) - client, err := New(fakeClientID, cred, WithAuthority(authority), WithHTTPClient(&mockClient), WithInstanceDiscovery(false)) + client, err := New(authority, fakeClientID, cred, WithHTTPClient(&mockClient), WithInstanceDiscovery(false)) if err != nil { t.Fatal(err) } @@ -1029,7 +1105,7 @@ func TestWithPortAuthority(t *testing.T) { host := sl + port tenant := "00000000-0000-0000-0000-000000000000" authority := fmt.Sprintf("https://%s%s/%s", sl, port, tenant) - cred, err := NewCredFromSecret("secret") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } @@ -1043,7 +1119,7 @@ func TestWithPortAuthority(t *testing.T) { mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)), mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }), ) - client, err := New(fakeClientID, cred, WithAuthority(authority), WithHTTPClient(&mockClient)) + client, err := New(authority, fakeClientID, cred, WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } @@ -1072,11 +1148,11 @@ func TestWithPortAuthority(t *testing.T) { func TestWithLoginHint(t *testing.T) { upn := "user@localhost" - cred, err := NewCredFromSecret("...") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } - client, err := New(fakeClientID, cred, WithHTTPClient(&errorClient{})) + client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{})) if err != nil { t.Fatal(err) } @@ -1112,11 +1188,11 @@ func TestWithLoginHint(t *testing.T) { func TestWithDomainHint(t *testing.T) { domain := "contoso.com" - cred, err := NewCredFromSecret("...") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } - client, err := New(fakeClientID, cred, WithHTTPClient(&errorClient{})) + client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{})) if err != nil { t.Fatal(err) } diff --git a/apps/internal/oauth/ops/authority/authority.go b/apps/internal/oauth/ops/authority/authority.go index de5f053f..5bebdb8e 100644 --- a/apps/internal/oauth/ops/authority/authority.go +++ b/apps/internal/oauth/ops/authority/authority.go @@ -309,31 +309,24 @@ func firstPathSegment(u *url.URL) (string, error) { return pathParts[1], nil } - return "", errors.New("authority does not have two segments") + return "", errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/"`) } // NewInfoFromAuthorityURI creates an AuthorityInfo instance from the authority URL provided. -func NewInfoFromAuthorityURI(authorityURI string, validateAuthority bool, instanceDiscoveryDisabled bool) (Info, error) { - authorityURI = strings.ToLower(authorityURI) - var authorityType string - u, err := url.Parse(authorityURI) - if err != nil { - return Info{}, fmt.Errorf("authorityURI passed could not be parsed: %w", err) - } - if u.Scheme != "https" { - return Info{}, fmt.Errorf("authorityURI(%s) must have scheme https", authorityURI) +func NewInfoFromAuthorityURI(authority string, validateAuthority bool, instanceDiscoveryDisabled bool) (Info, error) { + u, err := url.Parse(strings.ToLower(authority)) + if err != nil || u.Scheme != "https" { + return Info{}, errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/"`) } tenant, err := firstPathSegment(u) - if tenant == "adfs" { - authorityType = ADFS - } else { - authorityType = AAD - } - if err != nil { return Info{}, err } + authorityType := AAD + if tenant == "adfs" { + authorityType = ADFS + } // u.Host includes the port, if any, which is required for private cloud deployments return Info{ diff --git a/apps/tests/devapps/client_certificate_sample.go b/apps/tests/devapps/client_certificate_sample.go index c9bc2e8e..11f3acef 100644 --- a/apps/tests/devapps/client_certificate_sample.go +++ b/apps/tests/devapps/client_certificate_sample.go @@ -30,7 +30,7 @@ func acquireTokenClientCertificate() { if err != nil { log.Fatal(err) } - app, err := confidential.New(config.ClientID, cred, confidential.WithAuthority(config.Authority), confidential.WithCache(cacheAccessor)) + app, err := confidential.New(config.Authority, config.ClientID, cred, confidential.WithCache(cacheAccessor)) if err != nil { log.Fatal(err) } diff --git a/apps/tests/devapps/client_secret_sample.go b/apps/tests/devapps/client_secret_sample.go index 03f8d697..28918486 100644 --- a/apps/tests/devapps/client_secret_sample.go +++ b/apps/tests/devapps/client_secret_sample.go @@ -17,7 +17,7 @@ func acquireTokenClientSecret() { if err != nil { log.Fatal(err) } - app, err := confidential.New(config.ClientID, cred, confidential.WithAuthority(config.Authority), confidential.WithCache(cacheAccessor)) + app, err := confidential.New(config.Authority, config.ClientID, cred, confidential.WithCache(cacheAccessor)) if err != nil { log.Fatal(err) } diff --git a/apps/tests/integration/integration_test.go b/apps/tests/integration/integration_test.go index 94006386..55cbf1ef 100644 --- a/apps/tests/integration/integration_test.go +++ b/apps/tests/integration/integration_test.go @@ -97,7 +97,7 @@ func newLabClient() (*labClient, error) { return nil, fmt.Errorf("could not create a cred from a secret: %w", err) } - app, err := confidential.New(clientID, cred, confidential.WithAuthority(microsoftAuthority)) + app, err := confidential.New(microsoftAuthority, clientID, cred) if err != nil { return nil, err } @@ -227,7 +227,7 @@ func TestConfidentialClientwithSecret(t *testing.T) { panic(errors.Verbose(err)) } - app, err := confidential.New(clientID, cred, confidential.WithAuthority(microsoftAuthority)) + app, err := confidential.New(microsoftAuthority, clientID, cred) if err != nil { panic(errors.Verbose(err)) } @@ -290,7 +290,7 @@ func TestOnBehalfOf(t *testing.T) { if err != nil { panic(errors.Verbose(err)) } - cca, err := confidential.New(ccaClientID, cred) + cca, err := confidential.New("https://login.microsoftonline.com/common", ccaClientID, cred) if err != nil { panic(errors.Verbose(err)) }