diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 374451e7..850ad671 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -16,7 +16,7 @@ jobs: strategy: matrix: - go: ["1.19", "1.18"] + go: ["1.19", "1.20"] steps: - name: Set up Go 1.x diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 94fae709..acf23dfc 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -18,13 +18,13 @@ jobs: steps: - uses: actions/setup-go@v3 with: - go-version: 1.19 + go-version: "1.20" - uses: actions/checkout@v3 - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: # Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version. - version: v1.49 + version: v1.51 # Optional: golangci-lint command line arguments. # args: --issues-exit-code=0 diff --git a/apps/cache/cache.go b/apps/cache/cache.go index 259ca6d5..19210883 100644 --- a/apps/cache/cache.go +++ b/apps/cache/cache.go @@ -11,6 +11,8 @@ implementers on the format being passed. */ package cache +import "context" + // Marshaler marshals data from an internal cache to bytes that can be stored. type Marshaler interface { Marshal() ([]byte, error) @@ -27,13 +29,26 @@ type Serializer interface { Unmarshaler } -// ExportReplace is used export or replace what is in the cache. +// ExportHints are suggestions for storing data. +type ExportHints struct { + // PartitionKey is a suggested key for partitioning the cache + PartitionKey string +} + +// ReplaceHints are suggestions for loading data. +type ReplaceHints struct { + // PartitionKey is a suggested key for partitioning the cache + PartitionKey string +} + +// ExportReplace exports and replaces in-memory cache data. It doesn't support nil Context or +// define the outcome of passing one. A Context without a timeout must receive a default timeout +// specified by the implementor. Retries must be implemented inside the implementation. type ExportReplace interface { - // Replace replaces the cache with what is in external storage. - // key is the suggested key which can be used for partioning the cache - Replace(cache Unmarshaler, key string) - // Export writes the binary representation of the cache (cache.Marshal()) to - // external storage. This is considered opaque. - // key is the suggested key which can be used for partioning the cache - Export(cache Marshaler, key string) + // Replace replaces the cache with what is in external storage. Implementors should honor + // Context cancellations and return context.Canceled or context.DeadlineExceeded in those cases. + Replace(ctx context.Context, cache Unmarshaler, hints ReplaceHints) error + // Export writes the binary representation of the cache (cache.Marshal()) to external storage. + // This is considered opaque. Context cancellations should be honored as in Replace. + Export(ctx context.Context, cache Marshaler, hints ExportHints) error } diff --git a/apps/confidential/confidential.go b/apps/confidential/confidential.go index 85a1ba6d..6612feb4 100644 --- a/apps/confidential/confidential.go +++ b/apps/confidential/confidential.go @@ -18,7 +18,6 @@ import ( "encoding/pem" "errors" "fmt" - "net/url" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" @@ -50,8 +49,7 @@ duplication. .Net People, Take note on X509: This uses x509.Certificates and private keys. x509 does not store private keys. .Net has some x509.Certificate2 thing that has private keys, but that is just some bullcrap that .Net -added, it doesn't exist in real life. Seriously, "x509.Certificate2", bahahahaha. As such I've -put a PEM decoder into here. +added, it doesn't exist in real life. As such I've put a PEM decoder into here. */ // TODO(msal): This should have example code for each method on client using Go's example doc framework. @@ -63,7 +61,7 @@ type AuthResult = base.AuthResult type Account = shared.Account -// CertFromPEM converts a PEM file (.pem or .key) for use with NewCredFromCert(). The file +// CertFromPEM converts a PEM file (.pem or .key) for use with [NewCredFromCert]. The file // must contain the public certificate and the private key. If a PEM block is encrypted and // password is not an empty string, it attempts to decrypt the PEM blocks using the password. // Multiple certs are due to certificate chaining for use cases like TLS that sign from root to leaf. @@ -179,33 +177,15 @@ func NewCredFromSecret(secret string) (Credential, error) { return Credential{secret: secret}, nil } -// NewCredFromAssertion creates a Credential from a signed assertion. -// -// Deprecated: a Credential created by this function can't refresh the -// assertion when it expires. Use NewCredFromAssertionCallback instead. -func NewCredFromAssertion(assertion string) (Credential, error) { - if assertion == "" { - return Credential{}, errors.New("assertion can't be empty string") - } - return NewCredFromAssertionCallback(func(context.Context, AssertionRequestOptions) (string, error) { return assertion, nil }), nil -} - // NewCredFromAssertionCallback creates a Credential that invokes a callback to get assertions // authenticating the application. The callback must be thread safe. func NewCredFromAssertionCallback(callback func(context.Context, AssertionRequestOptions) (string, error)) Credential { return Credential{assertionCallback: callback} } -// NewCredFromCert creates a Credential from an x509.Certificate and an RSA private key. -// CertFromPEM() can be used to get these values from a PEM file. -func NewCredFromCert(cert *x509.Certificate, key crypto.PrivateKey) Credential { - cred, _ := NewCredFromCertChain([]*x509.Certificate{cert}, key) - return cred -} - -// NewCredFromCertChain creates a Credential from a chain of x509.Certificates and an RSA private key -// as returned by CertFromPEM(). -func NewCredFromCertChain(certs []*x509.Certificate, key crypto.PrivateKey) (Credential, error) { +// NewCredFromCert creates a Credential from a certificate or chain of certificates and an RSA private key +// as returned by [CertFromPEM]. +func NewCredFromCert(certs []*x509.Certificate, key crypto.PrivateKey) (Credential, error) { cred := Credential{key: key} k, ok := key.(*rsa.PrivateKey) if !ok { @@ -255,73 +235,32 @@ func AutoDetectRegion() string { // For more information, visit https://docs.microsoft.com/azure/active-directory/develop/msal-client-applications type Client struct { base base.Client - cred *accesstokens.Credential - - // userID is some unique identifier for a user. It actually isn't used by us at all, it - // simply acts as another hint that a confidential.Client is for a single user. - userID string } -// Options are optional settings for New(). These options are set using various functions +// clientOptions are optional settings for New(). These options are set using various functions // returning Option calls. -type Options struct { - // Accessor controls cache persistence. - // By default there is no cache persistence. This can be set using the WithAccessor() option. - Accessor cache.ExportReplace - - // The host of the Azure Active Directory authority. - // The default is https://login.microsoftonline.com/common. This can be changed using the - // WithAuthority() option. - Authority string - - // The HTTP client used for making requests. - // It defaults to a shared http.Client. - HTTPClient ops.HTTPClient - - // SendX5C specifies if x5c claim(public key of the certificate) should be sent to STS. - SendX5C bool - - // Instructs MSAL Go to use an Azure regional token service with sepcified AzureRegion. - AzureRegion string - - capabilities []string - - disableInstanceDiscovery bool -} - -func (o Options) validate() error { - u, err := url.Parse(o.Authority) - if err != nil { - return fmt.Errorf("the Authority(%s) does not parse as a valid URL", o.Authority) - } - if u.Scheme != "https" { - return fmt.Errorf("the Authority(%s) does not appear to use https", o.Authority) - } - return nil +type clientOptions struct { + accessor cache.ExportReplace + authority, azureRegion string + capabilities []string + disableInstanceDiscovery, sendX5C bool + httpClient ops.HTTPClient } // Option is an optional argument to New(). -type Option func(o *Options) +type Option func(o *clientOptions) -// WithAuthority allows you to provide a custom authority for use in the client. -func WithAuthority(authority string) Option { - return func(o *Options) { - o.Authority = authority - } -} - -// WithAccessor provides a cache accessor that will read and write to some externally managed cache -// that may or may not be shared with other applications. -func WithAccessor(accessor cache.ExportReplace) Option { - return func(o *Options) { - o.Accessor = accessor +// WithCache provides an accessor that will read and write authentication data to an externally managed cache. +func WithCache(accessor cache.ExportReplace) Option { + return func(o *clientOptions) { + o.accessor = accessor } } // WithClientCapabilities allows configuring one or more client capabilities such as "CP1" func WithClientCapabilities(capabilities []string) Option { - return func(o *Options) { + return func(o *clientOptions) { // there's no danger of sharing the slice's underlying memory with the application because // this slice is simply passed to base.WithClientCapabilities, which copies its data o.capabilities = capabilities @@ -330,21 +269,21 @@ func WithClientCapabilities(capabilities []string) Option { // WithHTTPClient allows for a custom HTTP client to be set. func WithHTTPClient(httpClient ops.HTTPClient) Option { - return func(o *Options) { - o.HTTPClient = httpClient + return func(o *clientOptions) { + o.httpClient = httpClient } } // WithX5C specifies if x5c claim(public key of the certificate) should be sent to STS to enable Subject Name Issuer Authentication. func WithX5C() Option { - return func(o *Options) { - o.SendX5C = true + return func(o *clientOptions) { + o.sendX5C = true } } // WithInstanceDiscovery set to false to disable authority validation (to support private cloud scenarios) func WithInstanceDiscovery(enabled bool) Option { - return func(o *Options) { + return func(o *clientOptions) { o.disableInstanceDiscovery = !enabled } } @@ -361,44 +300,37 @@ func WithInstanceDiscovery(enabled bool) Option { // If auto-detection fails, the non-regional endpoint will be used. // If an invalid region name is provided, the non-regional endpoint MIGHT be used or the token request MIGHT fail. func WithAzureRegion(val string) Option { - return func(o *Options) { - o.AzureRegion = val + return func(o *clientOptions) { + o.azureRegion = val } } -// New is the constructor for Client. userID is the unique identifier of the user this client -// will store credentials for (a Client is per user). clientID is the Azure clientID and cred is -// the type of credential to use. -func New(clientID string, cred Credential, options ...Option) (Client, error) { +// New is the constructor for Client. authority is the URL of a token authority such as "https://login.microsoftonline.com/". +// If the Client will connect directly to AD FS, use "adfs" for the tenant. clientID is the application's client ID (also called its +// "application ID"). +func New(authority, clientID string, cred Credential, options ...Option) (Client, error) { internalCred, err := cred.toInternal() if err != nil { return Client{}, err } - opts := Options{ - Authority: base.AuthorityPublicCloud, - HTTPClient: shared.DefaultClient, + opts := clientOptions{ + authority: authority, + // if the caller specified a token provider, it will handle all details of authentication, using Client only as a token cache + disableInstanceDiscovery: cred.tokenProvider != nil, + httpClient: shared.DefaultClient, } - for _, o := range options { o(&opts) } - if err := opts.validate(); err != nil { - return Client{}, err - } - baseOpts := []base.Option{ - base.WithCacheAccessor(opts.Accessor), + base.WithCacheAccessor(opts.accessor), base.WithClientCapabilities(opts.capabilities), - base.WithRegionDetection(opts.AzureRegion), - base.WithX5C(opts.SendX5C), base.WithInstanceDiscovery(!opts.disableInstanceDiscovery), + base.WithRegionDetection(opts.azureRegion), + base.WithX5C(opts.sendX5C), } - if cred.tokenProvider != nil { - // The caller will handle all details of authentication, using Client only as a token cache. - baseOpts = append(baseOpts, base.WithInstanceDiscovery(false)) - } - base, err := base.New(clientID, opts.Authority, oauth.New(opts.HTTPClient), baseOpts...) + base, err := base.New(clientID, opts.authority, oauth.New(opts.httpClient), baseOpts...) if err != nil { return Client{}, err } @@ -407,11 +339,6 @@ func New(clientID string, cred Credential, options ...Option) (Client, error) { return Client{base: base, cred: internalCred}, nil } -// UserID is the unique user identifier this client if for. -func (cca Client) UserID() string { - return cca.userID -} - // authCodeURLOptions contains options for AuthCodeURL type authCodeURLOptions struct { claims, loginHint, tenantID, domainHint string @@ -508,13 +435,13 @@ func WithClaims(claims string) interface { CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *AcquireTokenByAuthCodeOptions: + case *acquireTokenByAuthCodeOptions: t.claims = claims case *acquireTokenByCredentialOptions: t.claims = claims case *acquireTokenOnBehalfOfOptions: t.claims = claims - case *AcquireTokenSilentOptions: + case *acquireTokenSilentOptions: t.claims = claims case *authCodeURLOptions: t.claims = claims @@ -527,7 +454,7 @@ func WithClaims(claims string) interface { } } -// WithTenantID specifies a tenant for a single authentication. It may be different than the tenant set in [New] by [WithAuthority]. +// WithTenantID specifies a tenant for a single authentication. It may be different than the tenant set in [New]. // This option is valid for any token acquisition method. func WithTenantID(tenantID string) interface { AcquireByAuthCodeOption @@ -548,13 +475,13 @@ func WithTenantID(tenantID string) interface { CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *AcquireTokenByAuthCodeOptions: + case *acquireTokenByAuthCodeOptions: t.tenantID = tenantID case *acquireTokenByCredentialOptions: t.tenantID = tenantID case *acquireTokenOnBehalfOfOptions: t.tenantID = tenantID - case *AcquireTokenSilentOptions: + case *acquireTokenSilentOptions: t.tenantID = tenantID case *authCodeURLOptions: t.tenantID = tenantID @@ -567,12 +494,10 @@ func WithTenantID(tenantID string) interface { } } -// AcquireTokenSilentOptions are all the optional settings to an AcquireTokenSilent() call. +// acquireTokenSilentOptions are all the optional settings to an AcquireTokenSilent() call. // These are set by using various AcquireTokenSilentOption functions. -type AcquireTokenSilentOptions struct { - // Account represents the account to use. To set, use the WithSilentAccount() option. - Account Account - +type acquireTokenSilentOptions struct { + account Account claims, tenantID string } @@ -581,11 +506,6 @@ type AcquireSilentOption interface { acquireSilentOption() } -// AcquireTokenSilentOption changes options inside AcquireTokenSilentOptions used in .AcquireTokenSilent(). -type AcquireTokenSilentOption func(a *AcquireTokenSilentOptions) - -func (AcquireTokenSilentOption) acquireSilentOption() {} - // WithSilentAccount uses the passed account during an AcquireTokenSilent() call. func WithSilentAccount(account Account) interface { AcquireSilentOption @@ -598,8 +518,8 @@ func WithSilentAccount(account Account) interface { CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *AcquireTokenSilentOptions: - t.Account = account + case *acquireTokenSilentOptions: + t.account = account default: return fmt.Errorf("unexpected options type %T", a) } @@ -613,7 +533,7 @@ func WithSilentAccount(account Account) interface { // // Options: [WithClaims], [WithSilentAccount], [WithTenantID] func (cca Client) AcquireTokenSilent(ctx context.Context, scopes []string, opts ...AcquireSilentOption) (AuthResult, error) { - o := AcquireTokenSilentOptions{} + o := acquireTokenSilentOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return AuthResult{}, err } @@ -624,21 +544,19 @@ func (cca Client) AcquireTokenSilent(ctx context.Context, scopes []string, opts silentParameters := base.AcquireTokenSilentParameters{ Scopes: scopes, - Account: o.Account, + Account: o.account, RequestType: accesstokens.ATConfidential, Credential: cca.cred, - IsAppCache: o.Account.IsZero(), + IsAppCache: o.account.IsZero(), TenantID: o.tenantID, } return cca.base.AcquireTokenSilent(ctx, silentParameters) } -// AcquireTokenByAuthCodeOptions contains the optional parameters used to acquire an access token using the authorization code flow. -type AcquireTokenByAuthCodeOptions struct { - Challenge string - - claims, tenantID string +// acquireTokenByAuthCodeOptions contains the optional parameters used to acquire an access token using the authorization code flow. +type acquireTokenByAuthCodeOptions struct { + challenge, claims, tenantID string } // AcquireByAuthCodeOption is implemented by options for AcquireTokenByAuthCode @@ -646,11 +564,6 @@ type AcquireByAuthCodeOption interface { acquireByAuthCodeOption() } -// AcquireTokenByAuthCodeOption changes options inside AcquireTokenByAuthCodeOptions used in .AcquireTokenByAuthCode(). -type AcquireTokenByAuthCodeOption func(a *AcquireTokenByAuthCodeOptions) - -func (AcquireTokenByAuthCodeOption) acquireByAuthCodeOption() {} - // WithChallenge allows you to provide a challenge for the .AcquireTokenByAuthCode() call. func WithChallenge(challenge string) interface { AcquireByAuthCodeOption @@ -663,8 +576,8 @@ func WithChallenge(challenge string) interface { CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *AcquireTokenByAuthCodeOptions: - t.Challenge = challenge + case *acquireTokenByAuthCodeOptions: + t.challenge = challenge default: return fmt.Errorf("unexpected options type %T", a) } @@ -679,7 +592,7 @@ func WithChallenge(challenge string) interface { // // Options: [WithChallenge], [WithClaims], [WithTenantID] func (cca Client) AcquireTokenByAuthCode(ctx context.Context, code string, redirectURI string, scopes []string, opts ...AcquireByAuthCodeOption) (AuthResult, error) { - o := AcquireTokenByAuthCodeOptions{} + o := acquireTokenByAuthCodeOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return AuthResult{}, err } @@ -687,7 +600,7 @@ func (cca Client) AcquireTokenByAuthCode(ctx context.Context, code string, redir params := base.AcquireTokenAuthCodeParameters{ Scopes: scopes, Code: code, - Challenge: o.Challenge, + Challenge: o.challenge, Claims: o.claims, AppType: accesstokens.ATConfidential, Credential: cca.cred, // This setting differs from public.Client.AcquireTokenByAuthCode @@ -762,12 +675,11 @@ func (cca Client) AcquireTokenOnBehalfOf(ctx context.Context, userAssertion stri } // Account gets the account in the token cache with the specified homeAccountID. -func (cca Client) Account(homeAccountID string) Account { - return cca.base.Account(homeAccountID) +func (cca Client) Account(ctx context.Context, accountID string) (Account, error) { + return cca.base.Account(ctx, accountID) } // RemoveAccount signs the account out and forgets account from token cache. -func (cca Client) RemoveAccount(account Account) error { - cca.base.RemoveAccount(account) - return nil +func (cca Client) RemoveAccount(ctx context.Context, account Account) error { + return cca.base.RemoveAccount(ctx, account) } diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 65b1c6f0..1d529269 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock" @@ -31,8 +32,6 @@ import ( "github.com/kylelemons/godebug/pretty" ) -const localhost = "http://localhost" - // errorClient is an HTTP client for tests that should fail when confidential.Client sends a request type errorClient struct{} @@ -65,17 +64,20 @@ func TestCertFromPEM(t *testing.T) { } const ( + authorityFmt = "https://%s/%s" + fakeAuthority = "https://fake_authority/fake" fakeClientID = "fake_client_id" + fakeSecret = "fake_secret" fakeTokenEndpoint = "https://fake_authority/fake/token" - token = "fake_token" + localhost = "http://localhost" refresh = "fake_refresh" + token = "fake_token" ) var tokenScope = []string{"the_scope"} func fakeClient(tk accesstokens.TokenResponse, credential Credential, options ...Option) (Client, error) { - options = append(options, WithAuthority("https://fake_authority/fake")) - client, err := New("fake_client_id", credential, options...) + client, err := New(fakeAuthority, fakeClientID, credential, options...) if err != nil { return Client{}, err } @@ -163,19 +165,20 @@ func TestAcquireTokenByCredential(t *testing.T) { func TestAcquireTokenOnBehalfOf(t *testing.T) { // this test is an offline version of TestOnBehalfOf in integration_test.go - cred, err := NewCredFromSecret("secret") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } lmo := "login.microsoftonline.com" + tenant := "tenant" assertion := "assertion" mockClient := mock.Client{} // TODO: OBO does instance discovery twice before first token request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351 - mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, "common"))) - mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "common"))) + mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant))) + mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant))) mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(token, "", "rt", "", 3600))) - client, err := New("clientID", cred, WithHTTPClient(&mockClient)) + client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } @@ -247,7 +250,7 @@ func TestAcquireTokenByAssertionCallback(t *testing.T) { } func TestAcquireTokenByAuthCode(t *testing.T) { - cred, err := NewCredFromSecret("fake_secret") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } @@ -302,7 +305,10 @@ func TestAcquireTokenByAuthCode(t *testing.T) { if tk.AccessToken != token { t.Fatalf("unexpected access token %s", tk.AccessToken) } - account := client.Account(tk.Account.HomeAccountID) + account, err := client.Account(context.Background(), tk.Account.HomeAccountID) + if err != nil { + t.Fatal(err) + } if params.utid == "" { if actual := account.HomeAccountID; actual != "123-456.123-456" { t.Fatalf("expected %q, got %q", "123-456.123-456", actual) @@ -328,7 +334,7 @@ func TestAcquireTokenByAuthCode(t *testing.T) { } func TestAcquireTokenSilentTenants(t *testing.T) { - cred, err := NewCredFromSecret("secret") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } @@ -336,7 +342,7 @@ func TestAcquireTokenSilentTenants(t *testing.T) { lmo := "login.microsoftonline.com" mockClient := mock.Client{} mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenants[0]))) - client, err := New(fakeClientID, cred, WithHTTPClient(&mockClient)) + client, err := New(fmt.Sprintf(authorityFmt, lmo, tenants[0]), fakeClientID, cred, WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } @@ -364,24 +370,28 @@ func TestAcquireTokenSilentTenants(t *testing.T) { } } -func TestInvalidCredential(t *testing.T) { - data, err := os.ReadFile("../testdata/test-cert.pem") +func TestAuthorityValidation(t *testing.T) { + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } - certs, key, err := CertFromPEM(data, "") - if err != nil { - t.Fatal(err) + for _, a := range []string{"", "https://login.microsoftonline.com", "http://login.microsoftonline.com/tenant"} { + t.Run(a, func(t *testing.T) { + _, err := New(a, fakeClientID, cred) + if err == nil || !strings.Contains(err.Error(), "authority") { + t.Fatalf("expected an error about the invalid authority, got %v", err) + } + }) } +} + +func TestInvalidCredential(t *testing.T) { for _, cred := range []Credential{ {}, NewCredFromAssertionCallback(nil), - NewCredFromCert(nil, nil), - NewCredFromCert(certs[0], nil), - NewCredFromCert(nil, key), } { t.Run("", func(t *testing.T) { - _, err := New(fakeClientID, cred) + _, err := New(fakeAuthority, fakeClientID, cred) if err == nil { t.Fatal("expected an error") } @@ -389,7 +399,7 @@ func TestInvalidCredential(t *testing.T) { } } -func TestNewCredFromCertChain(t *testing.T) { +func TestNewCredFromCert(t *testing.T) { for _, file := range []struct { path string numCerts int @@ -423,7 +433,7 @@ func TestNewCredFromCertChain(t *testing.T) { t.Fatal("expected an RSA private key") } verifyingKey := &k.PublicKey - cred, err := NewCredFromCertChain(certs, key) + cred, err := NewCredFromCert(certs, key) if err != nil { t.Fatal(err) } @@ -506,7 +516,7 @@ func TestNewCredFromCertChain(t *testing.T) { } } -func TestNewCredFromCertChainError(t *testing.T) { +func TestNewCredFromCertError(t *testing.T) { data, err := os.ReadFile("../testdata/test-cert.pem") if err != nil { t.Fatal(err) @@ -528,12 +538,23 @@ func TestNewCredFromCertChainError(t *testing.T) { {[]*x509.Certificate{nil}, key}, } { t.Run("", func(t *testing.T) { - _, err := NewCredFromCertChain(test.certs, test.key) + _, err := NewCredFromCert(test.certs, test.key) if err == nil { t.Fatal("expected an error") } }) } + + // the key in this file doesn't match the cert loaded above + if data, err = os.ReadFile("../testdata/test-cert-chain.pem"); err != nil { + t.Fatal(err) + } + if _, key, err = CertFromPEM(data, ""); err != nil { + t.Fatal(err) + } + if _, err = NewCredFromCert(certs, key); err == nil { + t.Fatal("expected an error because key doesn't match certs") + } } func TestNewCredFromTokenProvider(t *testing.T) { @@ -561,7 +582,7 @@ func TestNewCredFromTokenProvider(t *testing.T) { ExpiresInSeconds: expiresIn, }, nil }) - client, err := New(fakeClientID, cred, WithHTTPClient(&errorClient{})) + client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{})) if err != nil { t.Fatal(err) } @@ -592,7 +613,7 @@ func TestNewCredFromTokenProviderError(t *testing.T) { cred := NewCredFromTokenProvider(func(ctx context.Context, tpp exported.TokenProviderParameters) (exported.TokenProviderResult, error) { return exported.TokenProviderResult{}, errors.New(expectedError) }) - client, err := New(fakeClientID, cred) + client, err := New(fakeAuthority, fakeClientID, cred) if err != nil { t.Fatal(err) } @@ -613,7 +634,7 @@ func TestTokenProviderOptions(t *testing.T) { } return TokenProviderResult{AccessToken: accessToken, ExpiresInSeconds: 3600}, nil }) - client, err := New("id", cred, WithHTTPClient(&errorClient{})) + client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{})) if err != nil { t.Fatal(err) } @@ -626,14 +647,83 @@ func TestTokenProviderOptions(t *testing.T) { } } +// testCache is a simple in-memory cache.ExportReplace implementation +type testCache map[string][]byte + +func (c testCache) Export(ctx context.Context, m cache.Marshaler, h cache.ExportHints) error { + if v, err := m.Marshal(); err == nil { + c[h.PartitionKey] = v + } + return nil +} + +func (c testCache) Replace(ctx context.Context, u cache.Unmarshaler, h cache.ReplaceHints) error { + if v, has := c[h.PartitionKey]; has { + _ = u.Unmarshal(v) + } + return nil +} + +func TestWithCache(t *testing.T) { + cache := make(testCache) + accessToken := "*" + lmo := "login.microsoftonline.com" + tenantA, tenantB := "a", "b" + authorityA, authorityB := fmt.Sprintf(authorityFmt, lmo, tenantA), fmt.Sprintf(authorityFmt, lmo, tenantB) + mockClient := mock.Client{} + mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenantA))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenantA, authorityA), "", "", 3600))) + + cred, err := NewCredFromSecret(fakeSecret) + if err != nil { + t.Fatal(err) + } + client, err := New(authorityA, fakeClientID, cred, WithCache(&cache), WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + // The particular flow isn't important, we just need to populate the cache. Auth code is the simplest for this test + ar, err := client.AcquireTokenByAuthCode(context.Background(), "code", "https://localhost", tokenScope) + if err != nil { + t.Fatal(err) + } + if ar.AccessToken != accessToken { + t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) + } + account := ar.Account + if actual := account.Realm; actual != tenantA { + t.Fatalf(`unexpected realm "%s"`, actual) + } + + // a client configured for a different tenant should be able to authenticate silently with the shared cache's data + client, err = New(authorityB, fakeClientID, cred, WithCache(&cache), WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + // this should succeed because the cache contains an access token from tenantA + mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenantA))) + ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account), WithTenantID(tenantA)) + if err != nil { + t.Fatal(err) + } + if ar.AccessToken != accessToken { + t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) + } + // this should fail because the cache doesn't contain an access token from tenantB + ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account)) + if err == nil { + t.Fatal("expected an error because the cache doesn't have an appropriate access token") + } +} + func TestWithClaims(t *testing.T) { - cred, err := NewCredFromSecret("secret") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } accessToken := "at" lmo, tenant := "login.microsoftonline.com", "tenant" - authority := fmt.Sprintf("https://%s/%s", lmo, tenant) + authority := fmt.Sprintf(authorityFmt, lmo, tenant) for _, test := range []struct { capabilities []string claims, expected string @@ -701,7 +791,7 @@ func TestWithClaims(t *testing.T) { validate(t, r.Form) }), ) - client, err := New(fakeClientID, cred, WithAuthority(authority), WithClientCapabilities(test.capabilities), WithHTTPClient(&mockClient)) + client, err := New(authority, fakeClientID, cred, WithClientCapabilities(test.capabilities), WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } @@ -715,7 +805,7 @@ func TestWithClaims(t *testing.T) { ar, err = client.AcquireTokenByAuthCode(ctx, "code", localhost, tokenScope, WithClaims(test.claims)) case "authcodeURL": u := "" - if u, err = client.AuthCodeURL(ctx, "client-id", localhost, tokenScope, WithClaims(test.claims)); err == nil { + if u, err = client.AuthCodeURL(ctx, fakeClientID, localhost, tokenScope, WithClaims(test.claims)); err == nil { var parsed *url.URL if parsed, err = url.Parse(u); err == nil { validate(t, parsed.Query()) @@ -800,7 +890,7 @@ func TestWithTenantID(t *testing.T) { } { for _, method := range []string{"authcode", "authcodeURL", "credential", "obo"} { t.Run(method, func(t *testing.T) { - cred, err := NewCredFromSecret("secret") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } @@ -818,7 +908,7 @@ func TestWithTenantID(t *testing.T) { mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)), mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }), ) - client, err := New(fakeClientID, cred, WithAuthority(test.authority), WithHTTPClient(&mockClient)) + client, err := New(test.authority, fakeClientID, cred, WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } @@ -831,7 +921,7 @@ func TestWithTenantID(t *testing.T) { case "authcode": ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", localhost, tokenScope, WithTenantID(test.tenant)) case "authcodeURL": - URL, err = client.AuthCodeURL(ctx, "client-id", localhost, tokenScope, WithTenantID(test.tenant)) + URL, err = client.AuthCodeURL(ctx, fakeClientID, localhost, tokenScope, WithTenantID(test.tenant)) case "credential": ar, err = client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(test.tenant)) case "obo": @@ -882,6 +972,65 @@ func TestWithTenantID(t *testing.T) { }) } } + + // if every auth call specifies a different tenant, Client shouldn't send requests to its configured authority + t.Run("enables fake authority", func(t *testing.T) { + host := "host" + defaultTenant := "default" + cred, err := NewCredFromSecret(fakeSecret) + if err != nil { + t.Fatal(err) + } + URL := "" + mockClient := mock.Client{} + client, err := New(fmt.Sprintf(authorityFmt, host, defaultTenant), fakeClientID, cred, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + checkForWrongTenant := func(r *http.Request) { + if u := r.URL.String(); strings.Contains(u, defaultTenant) { + t.Fatalf("unexpected request to the default authority: %q", u) + } + } + ctx := context.Background() + for i := 0; i < 3; i++ { + tenant := fmt.Sprint(i) + expected := fmt.Sprintf(authorityFmt, host, tenant) + // TODO: prevent redundant discovery requests https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351 + mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant)), mock.WithCallback(checkForWrongTenant)) + mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(host, tenant)), mock.WithCallback(checkForWrongTenant)) + mockClient.AppendResponse( + mock.WithBody(mock.GetAccessTokenBody(accessToken, "", "", "", 3600)), + mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }), + ) + if i == 0 { + // TODO: see above (first silent auth rediscovers instance metadata) + mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant)), mock.WithCallback(checkForWrongTenant)) + } + ar, err := client.AcquireTokenByAuthCode(ctx, "auth code", localhost, tokenScope, WithTenantID(tenant)) + if err != nil { + t.Fatal(err) + } + if !strings.HasPrefix(URL, expected) { + t.Fatalf(`expected "%s", got "%s"`, expected, URL) + } + if ar.AccessToken != accessToken { + t.Fatalf("unexpected access token %q", ar.AccessToken) + } + // silent authentication should now succeed for the given tenant... + if ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant)); err != nil { + t.Fatal(err) + } + if ar.AccessToken != accessToken { + t.Fatal("cached access token should match the one returned by AcquireToken...") + } + // ...but fail for another tenant + otherTenant := "not-" + tenant + if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(otherTenant)); err == nil { + t.Fatal("expected an error") + } + } + }) } func TestWithInstanceDiscovery(t *testing.T) { @@ -895,7 +1044,7 @@ func TestWithInstanceDiscovery(t *testing.T) { for _, method := range []string{"authcode", "credential", "obo"} { t.Run(method, func(t *testing.T) { authority := stackurl + tenant - cred, err := NewCredFromSecret("secret") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } @@ -909,7 +1058,7 @@ func TestWithInstanceDiscovery(t *testing.T) { mockClient.AppendResponse( mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)), ) - client, err := New(fakeClientID, cred, WithAuthority(authority), WithHTTPClient(&mockClient), WithInstanceDiscovery(false)) + client, err := New(authority, fakeClientID, cred, WithHTTPClient(&mockClient), WithInstanceDiscovery(false)) if err != nil { t.Fatal(err) } @@ -956,7 +1105,7 @@ func TestWithPortAuthority(t *testing.T) { host := sl + port tenant := "00000000-0000-0000-0000-000000000000" authority := fmt.Sprintf("https://%s%s/%s", sl, port, tenant) - cred, err := NewCredFromSecret("secret") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } @@ -970,7 +1119,7 @@ func TestWithPortAuthority(t *testing.T) { mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)), mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }), ) - client, err := New(fakeClientID, cred, WithAuthority(authority), WithHTTPClient(&mockClient)) + client, err := New(authority, fakeClientID, cred, WithHTTPClient(&mockClient)) if err != nil { t.Fatal(err) } @@ -999,11 +1148,11 @@ func TestWithPortAuthority(t *testing.T) { func TestWithLoginHint(t *testing.T) { upn := "user@localhost" - cred, err := NewCredFromSecret("...") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } - client, err := New(fakeClientID, cred, WithHTTPClient(&errorClient{})) + client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{})) if err != nil { t.Fatal(err) } @@ -1039,11 +1188,11 @@ func TestWithLoginHint(t *testing.T) { func TestWithDomainHint(t *testing.T) { domain := "contoso.com" - cred, err := NewCredFromSecret("...") + cred, err := NewCredFromSecret(fakeSecret) if err != nil { t.Fatal(err) } - client, err := New(fakeClientID, cred, WithHTTPClient(&errorClient{})) + client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{})) if err != nil { t.Fatal(err) } diff --git a/apps/confidential/examples_test.go b/apps/confidential/examples_test.go index e067e27a..e4c7180b 100644 --- a/apps/confidential/examples_test.go +++ b/apps/confidential/examples_test.go @@ -24,32 +24,9 @@ func ExampleNewCredFromCert_pem() { log.Fatal(err) } - // PEM files can have multiple certs. This is usually for certificate chaining where roots - // sign to leafs. Useful for TLS, not for this use case. - if len(certs) > 1 { - log.Fatal("too many certificates in PEM file") - } - - cred := confidential.NewCredFromCert(certs[0], priv) - fmt.Println(cred) // Simply here so cred is used, otherwise won't compile. -} - -func ExampleNewCredFromCertChain() { - b, err := os.ReadFile("key.pem") + cred, err := confidential.NewCredFromCert(certs, priv) if err != nil { - // TODO: handle error - } - - // CertFromPEM loads certificates and a private key from the PEM content. If - // the content is encrypted, the second argument must be the password. - certs, priv, err := confidential.CertFromPEM(b, "") - if err != nil { - // TODO: handle error - } - - cred, err := confidential.NewCredFromCertChain(certs, priv) - if err != nil { - // TODO: handle error + log.Fatal(err) } - _ = cred + fmt.Println(cred) // Simply here so cred is used, otherwise won't compile. } diff --git a/apps/design/design.md b/apps/design/design.md index 6e1121e4..a3da2ae5 100644 --- a/apps/design/design.md +++ b/apps/design/design.md @@ -5,6 +5,7 @@ Contributors: - Keegan Caruso(Keegan.Caruso@microsoft.com) - Joel Hendrix(jhendrix@microsoft.com) - Santiago Gonzalez(Santiago.Gonzalez@microsoft.com) +- Bogdan Gavril (bogavril@microsoft.com) ## History @@ -140,6 +141,10 @@ Since we have representation from the Go SDK team, we might have them go bridge the current implementation using some of that code so its possible for our users to store the cert in Keyvault. -## Notes: +## Logging -Do we need: AcquireTokenSilent()?? Seems like we could just bake this into other acquire calls automatically???? \ No newline at end of file +For errors, see [error design](../errors/error_design.md). + +This library does not log personal identifiable information (PII). For a definition of PII, see https://www.microsoft.com/en-us/trust-center/privacy/customer-data-definitions. MSAL Go does not log any of the 3 data categories listed there. + +The library may log information related to your organization, such as tenant id, authority, client id etc. as well as information that cannot be tied to a user such as request correlation id, HTTP status codes etc. diff --git a/apps/errors/error_design.md b/apps/errors/error_design.md index 34a699f4..7ef7862f 100644 --- a/apps/errors/error_design.md +++ b/apps/errors/error_design.md @@ -69,7 +69,7 @@ func (e CallErr) Error() string { // Verbose prints a versbose error message with the request or response. func (e CallErr) Verbose() string { - e.Resp.Request = nil // This brings in a bunch of TLS crap we don't need + e.Resp.Request = nil // This brings in a bunch of TLS stuff we don't need e.Resp.TLS = nil // Same return fmt.Sprintf("%s:\nRequest:\n%s\nResponse:\n%s", e.Err, prettyConf.Sprint(e.Req), prettyConf.Sprint(e.Resp)) } diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index ed8715ce..00617abf 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -45,8 +45,12 @@ type partitionedManager interface { type noopCacheAccessor struct{} -func (n noopCacheAccessor) Replace(cache cache.Unmarshaler, key string) {} -func (n noopCacheAccessor) Export(cache cache.Marshaler, key string) {} +func (n noopCacheAccessor) Replace(ctx context.Context, u cache.Unmarshaler, h cache.ReplaceHints) error { + return nil +} +func (n noopCacheAccessor) Export(ctx context.Context, m cache.Marshaler, h cache.ExportHints) error { + return nil +} // AcquireTokenSilentParameters contains the parameters to acquire a token silently (from cache). type AcquireTokenSilentParameters struct { @@ -279,12 +283,12 @@ func (b Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, s return baseURL.String(), nil } -func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilentParameters) (AuthResult, error) { +func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilentParameters) (ar AuthResult, err error) { // when tenant == "", the caller didn't specify a tenant and WithTenant will use the client's configured tenant tenant := silent.TenantID authParams, err := b.AuthParams.WithTenant(tenant) if err != nil { - return AuthResult{}, err + return ar, err } authParams.Scopes = silent.Scopes authParams.HomeAccountID = silent.Account.HomeAccountID @@ -296,37 +300,48 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen if authParams.AuthorizationType == authority.ATOnBehalfOf { if s, ok := b.pmanager.(cache.Serializer); ok { suggestedCacheKey := authParams.CacheKey(silent.IsAppCache) - b.cacheAccessor.Replace(s, suggestedCacheKey) - defer b.cacheAccessor.Export(s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) + if err != nil { + return ar, err + } + defer func() { + err = b.export(ctx, s, suggestedCacheKey, err) + }() } storageTokenResponse, err = b.pmanager.Read(ctx, authParams) if err != nil { - return AuthResult{}, err + return ar, err } } else { if s, ok := b.manager.(cache.Serializer); ok { suggestedCacheKey := authParams.CacheKey(silent.IsAppCache) - b.cacheAccessor.Replace(s, suggestedCacheKey) - defer b.cacheAccessor.Export(s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) + if err != nil { + return ar, err + } + defer func() { + err = b.export(ctx, s, suggestedCacheKey, err) + }() } authParams.AuthorizationType = authority.ATRefreshToken storageTokenResponse, err = b.manager.Read(ctx, authParams, silent.Account) if err != nil { - return AuthResult{}, err + return ar, err } } // ignore cached access tokens when given claims if silent.Claims == "" { - result, err := AuthResultFromStorage(storageTokenResponse) + ar, err = AuthResultFromStorage(storageTokenResponse) if err == nil { - return result, nil + return ar, err } } // redeem a cached refresh token, if available if reflect.ValueOf(storageTokenResponse.RefreshToken).IsZero() { - return AuthResult{}, errors.New("no token found") + err = errors.New("no token found") + return ar, err } var cc *accesstokens.Credential if silent.RequestType == accesstokens.ATConfidential { @@ -335,10 +350,11 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen token, err := b.Token.Refresh(ctx, silent.RequestType, authParams, cc, storageTokenResponse.RefreshToken) if err != nil { - return AuthResult{}, err + return ar, err } - return b.AuthResultFromToken(ctx, authParams, token, true) + ar, err = b.AuthResultFromToken(ctx, authParams, token, true) + return ar, err } func (b Client) AcquireTokenByAuthCode(ctx context.Context, authCodeParams AcquireTokenAuthCodeParameters) (AuthResult, error) { @@ -401,67 +417,103 @@ func (b Client) AcquireTokenOnBehalfOf(ctx context.Context, onBehalfOfParams Acq return ar, err } -func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse, cacheWrite bool) (AuthResult, error) { +func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse, cacheWrite bool) (ar AuthResult, err error) { if !cacheWrite { return NewAuthResult(token, shared.Account{}) } var account shared.Account - var err error if authParams.AuthorizationType == authority.ATOnBehalfOf { if s, ok := b.pmanager.(cache.Serializer); ok { suggestedCacheKey := token.CacheKey(authParams) - b.cacheAccessor.Replace(s, suggestedCacheKey) - defer b.cacheAccessor.Export(s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) + if err != nil { + return ar, err + } + defer func() { + err = b.export(ctx, s, suggestedCacheKey, err) + }() } account, err = b.pmanager.Write(authParams, token) if err != nil { - return AuthResult{}, err + return ar, err } } else { if s, ok := b.manager.(cache.Serializer); ok { suggestedCacheKey := token.CacheKey(authParams) - b.cacheAccessor.Replace(s, suggestedCacheKey) - defer b.cacheAccessor.Export(s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) + if err != nil { + return ar, err + } + defer func() { + err = b.export(ctx, s, suggestedCacheKey, err) + }() } account, err = b.manager.Write(authParams, token) if err != nil { - return AuthResult{}, err + return ar, err } } - return NewAuthResult(token, account) + ar, err = NewAuthResult(token, account) + return ar, err } -func (b Client) AllAccounts() []shared.Account { +func (b Client) AllAccounts(ctx context.Context) (accts []shared.Account, err error) { if s, ok := b.manager.(cache.Serializer); ok { suggestedCacheKey := b.AuthParams.CacheKey(false) - b.cacheAccessor.Replace(s, suggestedCacheKey) - defer b.cacheAccessor.Export(s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) + if err != nil { + return accts, err + } + defer func() { + err = b.export(ctx, s, suggestedCacheKey, err) + }() } - accounts := b.manager.AllAccounts() - return accounts + accts = b.manager.AllAccounts() + return accts, err } -func (b Client) Account(homeAccountID string) shared.Account { +func (b Client) Account(ctx context.Context, homeAccountID string) (acct shared.Account, err error) { authParams := b.AuthParams // This is a copy, as we dont' have a pointer receiver and .AuthParams is not a pointer. authParams.AuthorizationType = authority.AccountByID authParams.HomeAccountID = homeAccountID if s, ok := b.manager.(cache.Serializer); ok { suggestedCacheKey := b.AuthParams.CacheKey(false) - b.cacheAccessor.Replace(s, suggestedCacheKey) - defer b.cacheAccessor.Export(s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) + if err != nil { + return acct, err + } + defer func() { + err = b.export(ctx, s, suggestedCacheKey, err) + }() } - account := b.manager.Account(homeAccountID) - return account + acct = b.manager.Account(homeAccountID) + return acct, err } // RemoveAccount removes all the ATs, RTs and IDTs from the cache associated with this account. -func (b Client) RemoveAccount(account shared.Account) { +func (b Client) RemoveAccount(ctx context.Context, account shared.Account) (err error) { if s, ok := b.manager.(cache.Serializer); ok { suggestedCacheKey := b.AuthParams.CacheKey(false) - b.cacheAccessor.Replace(s, suggestedCacheKey) - defer b.cacheAccessor.Export(s, suggestedCacheKey) + err = b.cacheAccessor.Replace(ctx, s, cache.ReplaceHints{PartitionKey: suggestedCacheKey}) + if err != nil { + return err + } + defer func() { + err = b.export(ctx, s, suggestedCacheKey, err) + }() } b.manager.RemoveAccount(account, b.AuthParams.ClientID) + return err +} + +// export helps other methods defer exporting the cache after possibly updating its in-memory content. +// err is the error the calling method will return. If err isn't nil, export returns it without +// exporting the cache. +func (b Client) export(ctx context.Context, marshal cache.Marshaler, key string, err error) error { + if err != nil { + return err + } + return b.cacheAccessor.Export(ctx, marshal, cache.ExportHints{PartitionKey: key}) } diff --git a/apps/internal/base/base_test.go b/apps/internal/base/base_test.go index 33c5c792..7aac8102 100644 --- a/apps/internal/base/base_test.go +++ b/apps/internal/base/base_test.go @@ -5,11 +5,13 @@ package base import ( "context" + "errors" "fmt" "reflect" "testing" "time" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/internal/storage" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth" @@ -40,8 +42,8 @@ var ( testScopes = []string{"scope"} ) -func fakeClient(t *testing.T) Client { - client, err := New(fakeClientID, fmt.Sprintf("https://%s/%s", fakeAuthority, fakeTenantID), &oauth.Client{}) +func fakeClient(t *testing.T, opts ...Option) Client { + client, err := New(fakeClientID, fmt.Sprintf("https://%s/%s", fakeAuthority, fakeTenantID), &oauth.Client{}, opts...) if err != nil { t.Fatal(err) } @@ -192,6 +194,86 @@ func TestAcquireTokenSilentGrantedScopes(t *testing.T) { } } +// failCache helps tests inject cache I/O errors +type failCache struct { + exported bool + exportErr, replaceErr error +} + +func (c *failCache) Export(context.Context, cache.Marshaler, cache.ExportHints) error { + c.exported = true + return c.exportErr +} + +func (c failCache) Replace(context.Context, cache.Unmarshaler, cache.ReplaceHints) error { + return c.replaceErr +} + +func TestCacheIOErrors(t *testing.T) { + ctx := context.Background() + expected := errors.New("cache error") + for _, export := range []bool{true, false} { + name := "replace" + cache := failCache{} + if export { + cache.exportErr = expected + name = "export" + } else { + cache.replaceErr = expected + } + t.Run(name, func(t *testing.T) { + client := fakeClient(t, WithCacheAccessor(&cache)) + _, actual := client.Account(ctx, "...") + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + _, actual = client.AcquireTokenByAuthCode(ctx, AcquireTokenAuthCodeParameters{AppType: accesstokens.ATConfidential}) + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + _, actual = client.AcquireTokenOnBehalfOf(ctx, AcquireTokenOnBehalfOfParameters{Credential: &accesstokens.Credential{Secret: "..."}}) + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + _, actual = client.AcquireTokenSilent(ctx, AcquireTokenSilentParameters{}) + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + _, actual = client.AllAccounts(ctx) + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + _, actual = client.AuthResultFromToken(ctx, authority.AuthParams{}, accesstokens.TokenResponse{}, true) + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + actual = client.RemoveAccount(ctx, shared.Account{}) + if !errors.Is(actual, expected) { + t.Fatalf(`expected "%v", got "%v"`, expected, actual) + } + }) + } + + // when the client fails to acquire a token, it should return an error instead of exporting the cache + t.Run("auth error", func(t *testing.T) { + cache := failCache{} + client := fakeClient(t, WithCacheAccessor(&cache)) + client.Token.AccessTokens.(*fake.AccessTokens).Err = true + _, err := client.AcquireTokenByAuthCode(ctx, AcquireTokenAuthCodeParameters{AppType: accesstokens.ATConfidential}) + if err == nil || cache.exported { + t.Fatal("client should have returned an error instead of exporting the cache") + } + _, err = client.AcquireTokenOnBehalfOf(ctx, AcquireTokenOnBehalfOfParameters{Credential: &accesstokens.Credential{Secret: "..."}}) + if err == nil || cache.exported { + t.Fatal("client should have returned an error instead of exporting the cache") + } + _, err = client.AcquireTokenSilent(ctx, AcquireTokenSilentParameters{}) + if err == nil || cache.exported { + t.Fatal("client should have returned an error instead of exporting the cache") + } + }) +} + func TestCreateAuthenticationResult(t *testing.T) { future := time.Now().Add(400 * time.Second) diff --git a/apps/internal/oauth/oauth.go b/apps/internal/oauth/oauth.go index f9108235..5f136933 100644 --- a/apps/internal/oauth/oauth.go +++ b/apps/internal/oauth/oauth.go @@ -171,7 +171,7 @@ func (t *Client) UsernamePassword(ctx context.Context, authParams authority.Auth userRealm, err := t.Authority.UserRealm(ctx, authParams) if err != nil { - return accesstokens.TokenResponse{}, fmt.Errorf("problem getting user realm(user: %s) from authority: %w", authParams.Username, err) + return accesstokens.TokenResponse{}, fmt.Errorf("problem getting user realm from authority: %w", err) } switch userRealm.AccountType { @@ -212,7 +212,6 @@ func (d DeviceCode) Token(ctx context.Context) (accesstokens.TokenResponse, erro } var cancel context.CancelFunc - d.Result.ExpiresOn.Sub(time.Now().UTC()) if deadline, ok := ctx.Deadline(); !ok || d.Result.ExpiresOn.Before(deadline) { ctx, cancel = context.WithDeadline(ctx, d.Result.ExpiresOn) } else { diff --git a/apps/internal/oauth/ops/authority/authority.go b/apps/internal/oauth/ops/authority/authority.go index de5f053f..5bebdb8e 100644 --- a/apps/internal/oauth/ops/authority/authority.go +++ b/apps/internal/oauth/ops/authority/authority.go @@ -309,31 +309,24 @@ func firstPathSegment(u *url.URL) (string, error) { return pathParts[1], nil } - return "", errors.New("authority does not have two segments") + return "", errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/"`) } // NewInfoFromAuthorityURI creates an AuthorityInfo instance from the authority URL provided. -func NewInfoFromAuthorityURI(authorityURI string, validateAuthority bool, instanceDiscoveryDisabled bool) (Info, error) { - authorityURI = strings.ToLower(authorityURI) - var authorityType string - u, err := url.Parse(authorityURI) - if err != nil { - return Info{}, fmt.Errorf("authorityURI passed could not be parsed: %w", err) - } - if u.Scheme != "https" { - return Info{}, fmt.Errorf("authorityURI(%s) must have scheme https", authorityURI) +func NewInfoFromAuthorityURI(authority string, validateAuthority bool, instanceDiscoveryDisabled bool) (Info, error) { + u, err := url.Parse(strings.ToLower(authority)) + if err != nil || u.Scheme != "https" { + return Info{}, errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/"`) } tenant, err := firstPathSegment(u) - if tenant == "adfs" { - authorityType = ADFS - } else { - authorityType = AAD - } - if err != nil { return Info{}, err } + authorityType := AAD + if tenant == "adfs" { + authorityType = ADFS + } // u.Host includes the port, if any, which is required for private cloud deployments return Info{ diff --git a/apps/internal/version/version.go b/apps/internal/version/version.go index c3651e63..4a20fef3 100644 --- a/apps/internal/version/version.go +++ b/apps/internal/version/version.go @@ -5,4 +5,4 @@ package version // Version is the version of this client package that is communicated to the server. -const Version = "0.8.1" +const Version = "0.9.0" diff --git a/apps/public/public.go b/apps/public/public.go index 0a3ffaff..cce05277 100644 --- a/apps/public/public.go +++ b/apps/public/public.go @@ -47,27 +47,17 @@ type AuthResult = base.AuthResult type Account = shared.Account -// Options configures the Client's behavior. -type Options struct { - // Accessor controls cache persistence. By default there is no cache persistence. - // This can be set with the WithCache() option. - Accessor cache.ExportReplace - - // The host of the Azure Active Directory authority. The default is https://login.microsoftonline.com/common. - // This can be changed with the WithAuthority() option. - Authority string - - // The HTTP client used for making requests. - // It defaults to a shared http.Client. - HTTPClient ops.HTTPClient - - capabilities []string - +// clientOptions configures the Client's behavior. +type clientOptions struct { + accessor cache.ExportReplace + authority string + capabilities []string disableInstanceDiscovery bool + httpClient ops.HTTPClient } -func (p *Options) validate() error { - u, err := url.Parse(p.Authority) +func (p *clientOptions) validate() error { + u, err := url.Parse(p.authority) if err != nil { return fmt.Errorf("Authority options cannot be URL parsed: %w", err) } @@ -78,25 +68,25 @@ func (p *Options) validate() error { } // Option is an optional argument to the New constructor. -type Option func(o *Options) +type Option func(o *clientOptions) // WithAuthority allows for a custom authority to be set. This must be a valid https url. func WithAuthority(authority string) Option { - return func(o *Options) { - o.Authority = authority + return func(o *clientOptions) { + o.authority = authority } } -// WithCache allows you to set some type of cache for storing authentication tokens. +// WithCache provides an accessor that will read and write authentication data to an externally managed cache. func WithCache(accessor cache.ExportReplace) Option { - return func(o *Options) { - o.Accessor = accessor + return func(o *clientOptions) { + o.accessor = accessor } } // WithClientCapabilities allows configuring one or more client capabilities such as "CP1" func WithClientCapabilities(capabilities []string) Option { - return func(o *Options) { + return func(o *clientOptions) { // there's no danger of sharing the slice's underlying memory with the application because // this slice is simply passed to base.WithClientCapabilities, which copies its data o.capabilities = capabilities @@ -105,14 +95,14 @@ func WithClientCapabilities(capabilities []string) Option { // WithHTTPClient allows for a custom HTTP client to be set. func WithHTTPClient(httpClient ops.HTTPClient) Option { - return func(o *Options) { - o.HTTPClient = httpClient + return func(o *clientOptions) { + o.httpClient = httpClient } } // WithInstanceDiscovery set to false to disable authority validation (to support private cloud scenarios) func WithInstanceDiscovery(enabled bool) Option { - return func(o *Options) { + return func(o *clientOptions) { o.disableInstanceDiscovery = !enabled } } @@ -125,9 +115,9 @@ type Client struct { // New is the constructor for Client. func New(clientID string, options ...Option) (Client, error) { - opts := Options{ - Authority: base.AuthorityPublicCloud, - HTTPClient: shared.DefaultClient, + opts := clientOptions{ + authority: base.AuthorityPublicCloud, + httpClient: shared.DefaultClient, } for _, o := range options { @@ -137,28 +127,28 @@ func New(clientID string, options ...Option) (Client, error) { return Client{}, err } - base, err := base.New(clientID, opts.Authority, oauth.New(opts.HTTPClient), base.WithCacheAccessor(opts.Accessor), base.WithClientCapabilities(opts.capabilities), base.WithInstanceDiscovery(!opts.disableInstanceDiscovery)) + base, err := base.New(clientID, opts.authority, oauth.New(opts.httpClient), base.WithCacheAccessor(opts.accessor), base.WithClientCapabilities(opts.capabilities), base.WithInstanceDiscovery(!opts.disableInstanceDiscovery)) if err != nil { return Client{}, err } return Client{base}, nil } -// createAuthCodeURLOptions contains options for CreateAuthCodeURL -type createAuthCodeURLOptions struct { +// authCodeURLOptions contains options for AuthCodeURL +type authCodeURLOptions struct { claims, loginHint, tenantID, domainHint string } -// CreateAuthCodeURLOption is implemented by options for CreateAuthCodeURL -type CreateAuthCodeURLOption interface { - createAuthCodeURLOption() +// AuthCodeURLOption is implemented by options for AuthCodeURL +type AuthCodeURLOption interface { + authCodeURLOption() } -// CreateAuthCodeURL creates a URL used to acquire an authorization code. +// AuthCodeURL creates a URL used to acquire an authorization code. // // Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID] -func (pca Client) CreateAuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...CreateAuthCodeURLOption) (string, error) { - o := createAuthCodeURLOptions{} +func (pca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...AuthCodeURLOption) (string, error) { + o := authCodeURLOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return "", err } @@ -181,7 +171,7 @@ func WithClaims(claims string) interface { AcquireByUsernamePasswordOption AcquireInteractiveOption AcquireSilentOption - CreateAuthCodeURLOption + AuthCodeURLOption options.CallOption } { return struct { @@ -190,23 +180,23 @@ func WithClaims(claims string) interface { AcquireByUsernamePasswordOption AcquireInteractiveOption AcquireSilentOption - CreateAuthCodeURLOption + AuthCodeURLOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *AcquireTokenByAuthCodeOptions: + case *acquireTokenByAuthCodeOptions: t.claims = claims case *acquireTokenByDeviceCodeOptions: t.claims = claims case *acquireTokenByUsernamePasswordOptions: t.claims = claims - case *AcquireTokenSilentOptions: + case *acquireTokenSilentOptions: t.claims = claims - case *createAuthCodeURLOptions: + case *authCodeURLOptions: t.claims = claims - case *InteractiveAuthOptions: + case *interactiveAuthOptions: t.claims = claims default: return fmt.Errorf("unexpected options type %T", a) @@ -225,7 +215,7 @@ func WithTenantID(tenantID string) interface { AcquireByUsernamePasswordOption AcquireInteractiveOption AcquireSilentOption - CreateAuthCodeURLOption + AuthCodeURLOption options.CallOption } { return struct { @@ -234,23 +224,23 @@ func WithTenantID(tenantID string) interface { AcquireByUsernamePasswordOption AcquireInteractiveOption AcquireSilentOption - CreateAuthCodeURLOption + AuthCodeURLOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *AcquireTokenByAuthCodeOptions: + case *acquireTokenByAuthCodeOptions: t.tenantID = tenantID case *acquireTokenByDeviceCodeOptions: t.tenantID = tenantID case *acquireTokenByUsernamePasswordOptions: t.tenantID = tenantID - case *AcquireTokenSilentOptions: + case *acquireTokenSilentOptions: t.tenantID = tenantID - case *createAuthCodeURLOptions: + case *authCodeURLOptions: t.tenantID = tenantID - case *InteractiveAuthOptions: + case *interactiveAuthOptions: t.tenantID = tenantID default: return fmt.Errorf("unexpected options type %T", a) @@ -261,12 +251,10 @@ func WithTenantID(tenantID string) interface { } } -// AcquireTokenSilentOptions are all the optional settings to an AcquireTokenSilent() call. +// acquireTokenSilentOptions are all the optional settings to an AcquireTokenSilent() call. // These are set by using various AcquireTokenSilentOption functions. -type AcquireTokenSilentOptions struct { - // Account represents the account to use. To set, use the WithSilentAccount() option. - Account Account - +type acquireTokenSilentOptions struct { + account Account claims, tenantID string } @@ -275,11 +263,6 @@ type AcquireSilentOption interface { acquireSilentOption() } -// AcquireTokenSilentOption changes options inside AcquireTokenSilentOptions used in .AcquireTokenSilent(). -type AcquireTokenSilentOption func(a *AcquireTokenSilentOptions) - -func (AcquireTokenSilentOption) acquireSilentOption() {} - // WithSilentAccount uses the passed account during an AcquireTokenSilent() call. func WithSilentAccount(account Account) interface { AcquireSilentOption @@ -292,8 +275,8 @@ func WithSilentAccount(account Account) interface { CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *AcquireTokenSilentOptions: - t.Account = account + case *acquireTokenSilentOptions: + t.account = account default: return fmt.Errorf("unexpected options type %T", a) } @@ -307,14 +290,14 @@ func WithSilentAccount(account Account) interface { // // Options: [WithClaims], [WithSilentAccount], [WithTenantID] func (pca Client) AcquireTokenSilent(ctx context.Context, scopes []string, opts ...AcquireSilentOption) (AuthResult, error) { - o := AcquireTokenSilentOptions{} + o := acquireTokenSilentOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return AuthResult{}, err } silentParameters := base.AcquireTokenSilentParameters{ Scopes: scopes, - Account: o.Account, + Account: o.account, Claims: o.claims, RequestType: accesstokens.ATPublic, IsAppCache: false, @@ -420,11 +403,9 @@ func (pca Client) AcquireTokenByDeviceCode(ctx context.Context, scopes []string, return DeviceCode{Result: dc.Result, authParams: authParams, client: pca, dc: dc}, nil } -// AcquireTokenByAuthCodeOptions contains the optional parameters used to acquire an access token using the authorization code flow. -type AcquireTokenByAuthCodeOptions struct { - Challenge string - - claims, tenantID string +// acquireTokenByAuthCodeOptions contains the optional parameters used to acquire an access token using the authorization code flow. +type acquireTokenByAuthCodeOptions struct { + challenge, claims, tenantID string } // AcquireByAuthCodeOption is implemented by options for AcquireTokenByAuthCode @@ -432,11 +413,6 @@ type AcquireByAuthCodeOption interface { acquireByAuthCodeOption() } -// AcquireTokenByAuthCodeOption changes options inside AcquireTokenByAuthCodeOptions used in .AcquireTokenByAuthCode(). -type AcquireTokenByAuthCodeOption func(a *AcquireTokenByAuthCodeOptions) - -func (AcquireTokenByAuthCodeOption) acquireByAuthCodeOption() {} - // WithChallenge allows you to provide a code for the .AcquireTokenByAuthCode() call. func WithChallenge(challenge string) interface { AcquireByAuthCodeOption @@ -449,8 +425,8 @@ func WithChallenge(challenge string) interface { CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *AcquireTokenByAuthCodeOptions: - t.Challenge = challenge + case *acquireTokenByAuthCodeOptions: + t.challenge = challenge default: return fmt.Errorf("unexpected options type %T", a) } @@ -465,7 +441,7 @@ func WithChallenge(challenge string) interface { // // Options: [WithChallenge], [WithClaims], [WithTenantID] func (pca Client) AcquireTokenByAuthCode(ctx context.Context, code string, redirectURI string, scopes []string, opts ...AcquireByAuthCodeOption) (AuthResult, error) { - o := AcquireTokenByAuthCodeOptions{} + o := acquireTokenByAuthCodeOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return AuthResult{}, err } @@ -473,7 +449,7 @@ func (pca Client) AcquireTokenByAuthCode(ctx context.Context, code string, redir params := base.AcquireTokenAuthCodeParameters{ Scopes: scopes, Code: code, - Challenge: o.Challenge, + Challenge: o.challenge, Claims: o.claims, AppType: accesstokens.ATPublic, RedirectURI: redirectURI, @@ -485,23 +461,18 @@ func (pca Client) AcquireTokenByAuthCode(ctx context.Context, code string, redir // Accounts gets all the accounts in the token cache. // If there are no accounts in the cache the returned slice is empty. -func (pca Client) Accounts() []Account { - return pca.base.AllAccounts() +func (pca Client) Accounts(ctx context.Context) ([]Account, error) { + return pca.base.AllAccounts(ctx) } // RemoveAccount signs the account out and forgets account from token cache. -func (pca Client) RemoveAccount(account Account) error { - pca.base.RemoveAccount(account) - return nil +func (pca Client) RemoveAccount(ctx context.Context, account Account) error { + return pca.base.RemoveAccount(ctx, account) } -// InteractiveAuthOptions contains the optional parameters used to acquire an access token for interactive auth code flow. -type InteractiveAuthOptions struct { - // Used to specify a custom port for the local server. http://localhost:portnumber - // All other URI components are ignored. - RedirectURI string - - claims, loginHint, tenantID, domainHint string +// interactiveAuthOptions contains the optional parameters used to acquire an access token for interactive auth code flow. +type interactiveAuthOptions struct { + claims, domainHint, loginHint, redirectURI, tenantID string } // AcquireInteractiveOption is implemented by options for AcquireTokenInteractive @@ -509,28 +480,23 @@ type AcquireInteractiveOption interface { acquireInteractiveOption() } -// InteractiveAuthOption changes options inside InteractiveAuthOptions used in .AcquireTokenInteractive(). -type InteractiveAuthOption func(*InteractiveAuthOptions) - -func (InteractiveAuthOption) acquireInteractiveOption() {} - // WithLoginHint pre-populates the login prompt with a username. func WithLoginHint(username string) interface { AcquireInteractiveOption - CreateAuthCodeURLOption + AuthCodeURLOption options.CallOption } { return struct { AcquireInteractiveOption - CreateAuthCodeURLOption + AuthCodeURLOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *createAuthCodeURLOptions: + case *authCodeURLOptions: t.loginHint = username - case *InteractiveAuthOptions: + case *interactiveAuthOptions: t.loginHint = username default: return fmt.Errorf("unexpected options type %T", a) @@ -544,20 +510,20 @@ func WithLoginHint(username string) interface { // WithDomainHint adds the IdP domain as domain_hint query parameter in the auth url. func WithDomainHint(domain string) interface { AcquireInteractiveOption - CreateAuthCodeURLOption + AuthCodeURLOption options.CallOption } { return struct { AcquireInteractiveOption - CreateAuthCodeURLOption + AuthCodeURLOption options.CallOption }{ CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *createAuthCodeURLOptions: + case *authCodeURLOptions: t.domainHint = domain - case *InteractiveAuthOptions: + case *interactiveAuthOptions: t.domainHint = domain default: return fmt.Errorf("unexpected options type %T", a) @@ -568,7 +534,8 @@ func WithDomainHint(domain string) interface { } } -// WithRedirectURI uses the specified redirect URI for interactive auth. +// WithRedirectURI sets a port for the local server used in interactive authentication, for +// example http://localhost:port. All URI components other than the port are ignored. func WithRedirectURI(redirectURI string) interface { AcquireInteractiveOption options.CallOption @@ -580,8 +547,8 @@ func WithRedirectURI(redirectURI string) interface { CallOption: options.NewCallOption( func(a any) error { switch t := a.(type) { - case *InteractiveAuthOptions: - t.RedirectURI = redirectURI + case *interactiveAuthOptions: + t.redirectURI = redirectURI default: return fmt.Errorf("unexpected options type %T", a) } @@ -596,7 +563,7 @@ func WithRedirectURI(redirectURI string) interface { // // Options: [WithDomainHint], [WithLoginHint], [WithRedirectURI], [WithTenantID] func (pca Client) AcquireTokenInteractive(ctx context.Context, scopes []string, opts ...AcquireInteractiveOption) (AuthResult, error) { - o := InteractiveAuthOptions{} + o := interactiveAuthOptions{} if err := options.ApplyOptions(&o, opts); err != nil { return AuthResult{}, err } @@ -607,8 +574,8 @@ func (pca Client) AcquireTokenInteractive(ctx context.Context, scopes []string, return AuthResult{}, err } var redirectURL *url.URL - if o.RedirectURI != "" { - redirectURL, err = url.Parse(o.RedirectURI) + if o.redirectURI != "" { + redirectURL, err = url.Parse(o.redirectURI) if err != nil { return AuthResult{}, err } diff --git a/apps/public/public_test.go b/apps/public/public_test.go index 9cf2d0ac..8f245900 100644 --- a/apps/public/public_test.go +++ b/apps/public/public_test.go @@ -146,8 +146,12 @@ func TestAcquireTokenSilentWithTenantID(t *testing.T) { // cache should return the correct access token for each tenant var account Account - if accounts := client.Accounts(); len(accounts) == 2 { - // expecting one account for each tenant we authenticated in above + accounts, err := client.Accounts(ctx) + if err != nil { + t.Fatal(err) + } + // expecting one account for each tenant we authenticated in above + if len(accounts) == 2 { account = accounts[0] } else { t.Fatalf("expected 2 accounts but got %d", len(accounts)) @@ -233,7 +237,7 @@ func TestAcquireTokenWithTenantID(t *testing.T) { case "authcode": ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", "https://localhost", tokenScope, WithTenantID(test.tenant)) case "authcodeURL": - URL, err = client.CreateAuthCodeURL(ctx, "client-id", "https://localhost", tokenScope, WithTenantID(test.tenant)) + URL, err = client.AuthCodeURL(ctx, "client-id", "https://localhost", tokenScope, WithTenantID(test.tenant)) case "devicecode": dc, err = client.AcquireTokenByDeviceCode(ctx, tokenScope, WithTenantID(test.tenant)) case "interactive": @@ -359,24 +363,25 @@ func TestWithInstanceDiscovery(t *testing.T) { } // testCache is a simple in-memory cache.ExportReplace implementation -type testCache struct { - store map[string][]byte -} +type testCache map[string][]byte -func (c *testCache) Export(m cache.Marshaler, key string) { - if v, err := m.Marshal(); err == nil { - c.store[key] = v +func (c testCache) Export(ctx context.Context, m cache.Marshaler, h cache.ExportHints) error { + v, err := m.Marshal() + if err == nil { + c[h.PartitionKey] = v } + return err } -func (c *testCache) Replace(u cache.Unmarshaler, key string) { - if v, has := c.store[key]; has { - _ = u.Unmarshal(v) +func (c testCache) Replace(ctx context.Context, u cache.Unmarshaler, h cache.ReplaceHints) error { + if v, has := c[h.PartitionKey]; has { + return u.Unmarshal(v) } + return nil } func TestWithCache(t *testing.T) { - cache := testCache{make(map[string][]byte)} + cache := make(testCache) accessToken, refreshToken := "*", "rt" clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`)) lmo := "login.microsoftonline.com" @@ -408,7 +413,10 @@ func TestWithCache(t *testing.T) { if err != nil { t.Fatal(err) } - accounts := client.Accounts() + accounts, err := client.Accounts(context.Background()) + if err != nil { + t.Fatal(err) + } if actual := len(accounts); actual != 1 { t.Fatalf("expected 1 account but cache contains %d", actual) } @@ -536,7 +544,7 @@ func TestWithClaims(t *testing.T) { ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", "https://localhost", tokenScope, WithClaims(test.claims)) case "authcodeURL": u := "" - if u, err = client.CreateAuthCodeURL(ctx, "client-id", "https://localhost", tokenScope, WithClaims(test.claims)); err == nil { + if u, err = client.AuthCodeURL(ctx, "client-id", "https://localhost", tokenScope, WithClaims(test.claims)); err == nil { var parsed *url.URL if parsed, err = url.Parse(u); err == nil { validate(t, parsed.Query()) @@ -693,7 +701,7 @@ func TestWithLoginHint(t *testing.T) { return fakeBrowserOpenURL(authURL) } acquireOpts := []AcquireInteractiveOption{} - urlOpts := []CreateAuthCodeURLOption{} + urlOpts := []AuthCodeURLOption{} if expectHint { acquireOpts = append(acquireOpts, WithLoginHint(upn)) urlOpts = append(urlOpts, WithLoginHint(upn)) @@ -705,7 +713,7 @@ func TestWithLoginHint(t *testing.T) { if !called { t.Fatal("browserOpenURL wasn't called") } - u, err := client.CreateAuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...) + u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...) if err == nil { var parsed *url.URL parsed, err = url.Parse(u) @@ -767,7 +775,7 @@ func TestWithDomainHint(t *testing.T) { return fakeBrowserOpenURL(authURL) } var acquireOpts []AcquireInteractiveOption - var urlOpts []CreateAuthCodeURLOption + var urlOpts []AuthCodeURLOption if expectHint { acquireOpts = append(acquireOpts, WithDomainHint(domain)) urlOpts = append(urlOpts, WithDomainHint(domain)) @@ -779,7 +787,7 @@ func TestWithDomainHint(t *testing.T) { if !called { t.Fatal("browserOpenURL wasn't called") } - u, err := client.CreateAuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...) + u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...) if err == nil { var parsed *url.URL parsed, err = url.Parse(u) diff --git a/apps/tests/devapps/authorization_code_sample.go b/apps/tests/devapps/authorization_code_sample.go index 46c4a269..dd3e792f 100644 --- a/apps/tests/devapps/authorization_code_sample.go +++ b/apps/tests/devapps/authorization_code_sample.go @@ -51,7 +51,7 @@ func redirectToURL(w http.ResponseWriter, r *http.Request) { authCodeURLParams := msal.CreateAuthorizationCodeURLParameters(config.ClientID, config.RedirectURI, config.Scopes) authCodeURLParams.CodeChallenge = config.CodeChallenge authCodeURLParams.State = config.State - authURL, err := publicClientApp.CreateAuthCodeURL(context.Background(), authCodeURLParams) + authURL, err := publicClientApp.AuthCodeURL(context.Background(), authCodeURLParams) if err != nil { log.Fatal(err) } diff --git a/apps/tests/devapps/client_certificate_sample.go b/apps/tests/devapps/client_certificate_sample.go index 919efb37..11f3acef 100644 --- a/apps/tests/devapps/client_certificate_sample.go +++ b/apps/tests/devapps/client_certificate_sample.go @@ -26,18 +26,11 @@ func acquireTokenClientCertificate() { if err != nil { log.Fatal(err) } - - // PEM files can have multiple certs. This is usually for certificate chaining where roots - // sign to leafs. Useful for TLS, not for this use case. - if len(certs) > 1 { - log.Fatal("too many certificates in PEM file") - } - - cred := confidential.NewCredFromCert(certs[0], privateKey) + cred, err := confidential.NewCredFromCert(certs, privateKey) if err != nil { log.Fatal(err) } - app, err := confidential.New(config.ClientID, cred, confidential.WithAuthority(config.Authority), confidential.WithAccessor(cacheAccessor)) + app, err := confidential.New(config.Authority, config.ClientID, cred, confidential.WithCache(cacheAccessor)) if err != nil { log.Fatal(err) } diff --git a/apps/tests/devapps/client_secret_sample.go b/apps/tests/devapps/client_secret_sample.go index daa6d60a..28918486 100644 --- a/apps/tests/devapps/client_secret_sample.go +++ b/apps/tests/devapps/client_secret_sample.go @@ -17,7 +17,7 @@ func acquireTokenClientSecret() { if err != nil { log.Fatal(err) } - app, err := confidential.New(config.ClientID, cred, confidential.WithAuthority(config.Authority), confidential.WithAccessor(cacheAccessor)) + app, err := confidential.New(config.Authority, config.ClientID, cred, confidential.WithCache(cacheAccessor)) if err != nil { log.Fatal(err) } diff --git a/apps/tests/devapps/confidential_auth_code_sample.go b/apps/tests/devapps/confidential_auth_code_sample.go index 35f31974..836f6e7a 100644 --- a/apps/tests/devapps/confidential_auth_code_sample.go +++ b/apps/tests/devapps/confidential_auth_code_sample.go @@ -17,7 +17,7 @@ func redirectToURLConfidential(w http.ResponseWriter, r *http.Request) { // Getting the URL to redirect to acquire the authorization code authCodeURLParams.CodeChallenge = confidentialConfig.CodeChallenge authCodeURLParams.State = confidentialConfig.State - authURL, err := app.CreateAuthCodeURL(context.Background(), confidentialConfig.ClientID, confidentialConfig.RedirectURI, confidentialConfig.Scopes) + authURL, err := app.AuthCodeURL(context.Background(), confidentialConfig.ClientID, confidentialConfig.RedirectURI, confidentialConfig.Scopes) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return @@ -58,7 +58,7 @@ func getTokenConfidential(w http.ResponseWriter, r *http.Request) { // TODO(msal): Needs to use an x509 certificate like the other now that we are not using a // thumbprint directly. /* -func acquireByAuthorizationCodeConfidential() { +func acquireByAuthorizationCodeConfidential(ctx context.Context) { key, err := os.ReadFile(confidentialConfig.KeyFile) if err != nil { log.Fatal(err) @@ -77,7 +77,7 @@ func acquireByAuthorizationCodeConfidential() { log.Fatal(err) } var userAccount shared.Account - for _, account := range app.Accounts() { + for _, account := range app.Accounts(ctx) { if account.PreferredUsername == confidentialConfig.Username { userAccount = account } diff --git a/apps/tests/devapps/device_code_flow_sample.go b/apps/tests/devapps/device_code_flow_sample.go index bd4d694f..192189cc 100644 --- a/apps/tests/devapps/device_code_flow_sample.go +++ b/apps/tests/devapps/device_code_flow_sample.go @@ -21,7 +21,10 @@ func acquireTokenDeviceCode() { // look in the cache to see if the account to use has been cached var userAccount public.Account - accounts := app.Accounts() + accounts, err := app.Accounts(context.Background()) + if err != nil { + panic("failed to read the cache") + } for _, account := range accounts { if account.PreferredUsername == config.Username { userAccount = account diff --git a/apps/tests/devapps/main.go b/apps/tests/devapps/main.go index ff44dcba..6c1ad04e 100644 --- a/apps/tests/devapps/main.go +++ b/apps/tests/devapps/main.go @@ -1,15 +1,18 @@ package main import ( + "context" "os" ) var ( //config = CreateConfig("config.json") - cacheAccessor = &TokenCache{"serialized_cache.json"} + cacheAccessor = &TokenCache{file: "serialized_cache.json"} ) func main() { + ctx := context.Background() + // TODO(msal): This is pretty yikes. At least we should use the flag package. exampleType := os.Args[1] if exampleType == "1" { @@ -18,7 +21,7 @@ func main() { acquireByAuthorizationCodePublic() */ } else if exampleType == "3" { - acquireByUsernamePasswordPublic() + acquireByUsernamePasswordPublic(ctx) } else if exampleType == "4" { panic("currently not implemented") //acquireByAuthorizationCodeConfidential() diff --git a/apps/tests/devapps/sample_cache_accessor.go b/apps/tests/devapps/sample_cache_accessor.go index 07114a0c..8c52a6c3 100644 --- a/apps/tests/devapps/sample_cache_accessor.go +++ b/apps/tests/devapps/sample_cache_accessor.go @@ -4,6 +4,7 @@ package main import ( + "context" "log" "os" @@ -14,24 +15,18 @@ type TokenCache struct { file string } -func (t *TokenCache) Replace(cache cache.Unmarshaler, key string) { +func (t *TokenCache) Replace(ctx context.Context, cache cache.Unmarshaler, hints cache.ReplaceHints) error { data, err := os.ReadFile(t.file) if err != nil { log.Println(err) } - err = cache.Unmarshal(data) - if err != nil { - log.Println(err) - } + return cache.Unmarshal(data) } -func (t *TokenCache) Export(cache cache.Marshaler, key string) { +func (t *TokenCache) Export(ctx context.Context, cache cache.Marshaler, hints cache.ExportHints) error { data, err := cache.Marshal() if err != nil { log.Println(err) } - err = os.WriteFile(t.file, data, 0600) - if err != nil { - log.Println(err) - } + return os.WriteFile(t.file, data, 0600) } diff --git a/apps/tests/devapps/username_password_sample.go b/apps/tests/devapps/username_password_sample.go index 4eb5c484..c5a5b709 100644 --- a/apps/tests/devapps/username_password_sample.go +++ b/apps/tests/devapps/username_password_sample.go @@ -11,7 +11,7 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" ) -func acquireByUsernamePasswordPublic() { +func acquireByUsernamePasswordPublic(ctx context.Context) { config := CreateConfig("config.json") app, err := public.New(config.ClientID, public.WithCache(cacheAccessor), public.WithAuthority(config.Authority)) if err != nil { @@ -20,7 +20,10 @@ func acquireByUsernamePasswordPublic() { // look in the cache to see if the account to use has been cached var userAccount public.Account - accounts := app.Accounts() + accounts, err := app.Accounts(ctx) + if err != nil { + panic("failed to read the cache") + } for _, account := range accounts { if account.PreferredUsername == config.Username { userAccount = account diff --git a/apps/tests/integration/integration_test.go b/apps/tests/integration/integration_test.go index 41b4b780..55cbf1ef 100644 --- a/apps/tests/integration/integration_test.go +++ b/apps/tests/integration/integration_test.go @@ -97,7 +97,7 @@ func newLabClient() (*labClient, error) { return nil, fmt.Errorf("could not create a cred from a secret: %w", err) } - app, err := confidential.New(clientID, cred, confidential.WithAuthority(microsoftAuthority)) + app, err := confidential.New(microsoftAuthority, clientID, cred) if err != nil { return nil, err } @@ -227,7 +227,7 @@ func TestConfidentialClientwithSecret(t *testing.T) { panic(errors.Verbose(err)) } - app, err := confidential.New(clientID, cred, confidential.WithAuthority(microsoftAuthority)) + app, err := confidential.New(microsoftAuthority, clientID, cred) if err != nil { panic(errors.Verbose(err)) } @@ -290,7 +290,7 @@ func TestOnBehalfOf(t *testing.T) { if err != nil { panic(errors.Verbose(err)) } - cca, err := confidential.New(ccaClientID, cred) + cca, err := confidential.New("https://login.microsoftonline.com/common", ccaClientID, cred) if err != nil { panic(errors.Verbose(err)) } @@ -390,12 +390,15 @@ func TestRemoveAccount(t *testing.T) { if err != nil { t.Fatalf("TestRemoveAccount: on AcquireTokenByUsernamePassword(): got err == %s, want err == nil", errors.Verbose(err)) } - accounts := app.Accounts() + accounts, err := app.Accounts(ctx) + if err != nil { + t.Fatal(err) + } if len(accounts) == 0 { t.Fatal("TestRemoveAccount: No user accounts found in cache") } testAccount := accounts[0] // Only one account is populated and that is what we will remove. - err = app.RemoveAccount(testAccount) + err = app.RemoveAccount(ctx, testAccount) if err != nil { t.Fatalf("TestRemoveAccount: on RemoveAccount(): got err == %s, want err == nil", errors.Verbose(err)) } diff --git a/go.mod b/go.mod index c0067a21..8bb370bd 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,11 @@ module github.com/AzureAD/microsoft-authentication-library-for-go go 1.18 require ( - github.com/golang-jwt/jwt/v4 v4.4.2 - github.com/google/uuid v1.1.1 + github.com/golang-jwt/jwt/v4 v4.4.3 + github.com/google/uuid v1.3.0 github.com/kylelemons/godebug v1.1.0 - github.com/montanaflynn/stats v0.6.6 - github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4 + github.com/montanaflynn/stats v0.7.0 + github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 ) -require golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c // indirect +require golang.org/x/sys v0.5.0 // indirect diff --git a/go.sum b/go.sum index ac58e5e3..26fce0bf 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,13 @@ -github.com/golang-jwt/jwt/v4 v4.4.2 h1:rcc4lwaZgFMCZ5jxF9ABolDcIHdBytAFgqFPbSJQAYs= -github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= -github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU= +github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/montanaflynn/stats v0.6.6 h1:Duep6KMIDpY4Yo11iFsvyqJDyfzLF9+sndUKT+v64GQ= -github.com/montanaflynn/stats v0.6.6/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= -github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4 h1:Qj1ukM4GlMWXNdMBuXcXfz/Kw9s1qm0CLY32QxuSImI= -github.com/pkg/browser v0.0.0-20210115035449-ce105d075bb4/go.mod h1:N6UoU20jOqggOuDwUaBQpluzLNDqif3kq9z2wpdYEfQ= -golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c h1:VwygUrnw9jn88c4u8GD3rZQbqrP/tgas88tPUbBxQrk= -golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +github.com/montanaflynn/stats v0.7.0 h1:r3y12KyNxj/Sb/iOE46ws+3mS1+MZca1wlHQFPsY/JU= +github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= +github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= +github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= +golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=