Skip to content

Commit

Permalink
Merge pull request #376 from AzureAD/release-0.8.1
Browse files Browse the repository at this point in the history
Release 0.8.1
  • Loading branch information
rayluo committed Jan 25, 2023
2 parents 578cd5d + b86a39e commit 24a6783
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 45 deletions.
83 changes: 65 additions & 18 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import (
"github.com/kylelemons/godebug/pretty"
)

const localhost = "http://localhost"

// errorClient is an HTTP client for tests that should fail when confidential.Client sends a request
type errorClient struct{}

Expand Down Expand Up @@ -159,6 +161,51 @@ func TestAcquireTokenByCredential(t *testing.T) {
}
}

func TestAcquireTokenOnBehalfOf(t *testing.T) {
// this test is an offline version of TestOnBehalfOf in integration_test.go
cred, err := NewCredFromSecret("secret")
if err != nil {
t.Fatal(err)
}
lmo := "login.microsoftonline.com"
assertion := "assertion"
mockClient := mock.Client{}
// TODO: OBO does instance discovery twice before first token request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, "common")))
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "common")))
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(token, "", "rt", "", 3600)))

client, err := New("clientID", cred, WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}
tk, err := client.AcquireTokenOnBehalfOf(context.Background(), assertion, tokenScope)
if err != nil {
t.Fatal(err)
}
if tk.AccessToken != token {
t.Fatalf("wanted %q, got %q", token, tk.AccessToken)
}
// should return the cached access token
tk, err = client.AcquireTokenOnBehalfOf(context.Background(), assertion, tokenScope)
if err != nil {
t.Fatal(err)
}
if tk.AccessToken != token {
t.Fatalf("wanted %q, got %q", token, tk.AccessToken)
}
// new assertion should trigger new token request
token2 := token + "2"
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(token2, "", "rt", "", 3600)))
tk, err = client.AcquireTokenOnBehalfOf(context.Background(), assertion+"2", tokenScope)
if err != nil {
t.Fatal(err)
}
if tk.AccessToken != token2 {
t.Fatal("expected a new token")
}
}

func TestAcquireTokenByAssertionCallback(t *testing.T) {
calls := 0
key := struct{}{}
Expand Down Expand Up @@ -289,7 +336,7 @@ func TestAcquireTokenSilentTenants(t *testing.T) {
lmo := "login.microsoftonline.com"
mockClient := mock.Client{}
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenants[0])))
client, err := New("client-id", cred, WithHTTPClient(&mockClient))
client, err := New(fakeClientID, cred, WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -334,7 +381,7 @@ func TestInvalidCredential(t *testing.T) {
NewCredFromCert(nil, key),
} {
t.Run("", func(t *testing.T) {
_, err := New("client-id", cred)
_, err := New(fakeClientID, cred)
if err == nil {
t.Fatal("expected an error")
}
Expand Down Expand Up @@ -514,7 +561,7 @@ func TestNewCredFromTokenProvider(t *testing.T) {
ExpiresInSeconds: expiresIn,
}, nil
})
client, err := New("client-id", cred, WithHTTPClient(&errorClient{}))
client, err := New(fakeClientID, cred, WithHTTPClient(&errorClient{}))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -545,7 +592,7 @@ func TestNewCredFromTokenProviderError(t *testing.T) {
cred := NewCredFromTokenProvider(func(ctx context.Context, tpp exported.TokenProviderParameters) (exported.TokenProviderResult, error) {
return exported.TokenProviderResult{}, errors.New(expectedError)
})
client, err := New("client-id", cred)
client, err := New(fakeClientID, cred)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -654,7 +701,7 @@ func TestWithClaims(t *testing.T) {
validate(t, r.Form)
}),
)
client, err := New("client-id", cred, WithAuthority(authority), WithClientCapabilities(test.capabilities), WithHTTPClient(&mockClient))
client, err := New(fakeClientID, cred, WithAuthority(authority), WithClientCapabilities(test.capabilities), WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}
Expand All @@ -665,10 +712,10 @@ func TestWithClaims(t *testing.T) {
var ar AuthResult
switch method {
case "authcode":
ar, err = client.AcquireTokenByAuthCode(ctx, "code", "https://localhost", tokenScope, WithClaims(test.claims))
ar, err = client.AcquireTokenByAuthCode(ctx, "code", localhost, tokenScope, WithClaims(test.claims))
case "authcodeURL":
u := ""
if u, err = client.AuthCodeURL(ctx, "client-id", "https://localhost", tokenScope, WithClaims(test.claims)); err == nil {
if u, err = client.AuthCodeURL(ctx, "client-id", localhost, tokenScope, WithClaims(test.claims)); err == nil {
var parsed *url.URL
if parsed, err = url.Parse(u); err == nil {
validate(t, parsed.Query())
Expand Down Expand Up @@ -771,7 +818,7 @@ func TestWithTenantID(t *testing.T) {
mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)),
mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }),
)
client, err := New("client-id", cred, WithAuthority(test.authority), WithHTTPClient(&mockClient))
client, err := New(fakeClientID, cred, WithAuthority(test.authority), WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}
Expand All @@ -782,9 +829,9 @@ func TestWithTenantID(t *testing.T) {
var ar AuthResult
switch method {
case "authcode":
ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", "https://localhost", tokenScope, WithTenantID(test.tenant))
ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", localhost, tokenScope, WithTenantID(test.tenant))
case "authcodeURL":
URL, err = client.AuthCodeURL(ctx, "client-id", "https://localhost", tokenScope, WithTenantID(test.tenant))
URL, err = client.AuthCodeURL(ctx, "client-id", localhost, tokenScope, WithTenantID(test.tenant))
case "credential":
ar, err = client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(test.tenant))
case "obo":
Expand Down Expand Up @@ -862,7 +909,7 @@ func TestWithInstanceDiscovery(t *testing.T) {
mockClient.AppendResponse(
mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)),
)
client, err := New("client-id", cred, WithAuthority(authority), WithHTTPClient(&mockClient), WithInstanceDiscovery(false))
client, err := New(fakeClientID, cred, WithAuthority(authority), WithHTTPClient(&mockClient), WithInstanceDiscovery(false))
if err != nil {
t.Fatal(err)
}
Expand All @@ -873,7 +920,7 @@ func TestWithInstanceDiscovery(t *testing.T) {
var ar AuthResult
switch method {
case "authcode":
ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", "https://localhost", tokenScope)
ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", localhost, tokenScope)
case "credential":
ar, err = client.AcquireTokenByCredential(ctx, tokenScope)
case "obo":
Expand Down Expand Up @@ -923,7 +970,7 @@ func TestWithPortAuthority(t *testing.T) {
mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)),
mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }),
)
client, err := New("client-id", cred, WithAuthority(authority), WithHTTPClient(&mockClient))
client, err := New(fakeClientID, cred, WithAuthority(authority), WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}
Expand All @@ -932,7 +979,7 @@ func TestWithPortAuthority(t *testing.T) {
t.Fatal("silent auth should fail because the cache is empty")
}
var ar AuthResult
ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", "https://localhost", tokenScope)
ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", localhost, tokenScope)
if err != nil {
t.Fatal(err)
}
Expand All @@ -956,7 +1003,7 @@ func TestWithLoginHint(t *testing.T) {
if err != nil {
t.Fatal(err)
}
client, err := New("client-id", cred, WithHTTPClient(&errorClient{}))
client, err := New(fakeClientID, cred, WithHTTPClient(&errorClient{}))
if err != nil {
t.Fatal(err)
}
Expand All @@ -967,7 +1014,7 @@ func TestWithLoginHint(t *testing.T) {
if expectHint {
opts = append(opts, WithLoginHint(upn))
}
u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, opts...)
u, err := client.AuthCodeURL(context.Background(), "id", localhost, tokenScope, opts...)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -996,7 +1043,7 @@ func TestWithDomainHint(t *testing.T) {
if err != nil {
t.Fatal(err)
}
client, err := New("client-id", cred, WithHTTPClient(&errorClient{}))
client, err := New(fakeClientID, cred, WithHTTPClient(&errorClient{}))
if err != nil {
t.Fatal(err)
}
Expand All @@ -1007,7 +1054,7 @@ func TestWithDomainHint(t *testing.T) {
if expectHint {
opts = append(opts, WithDomainHint(domain))
}
u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, opts...)
u, err := client.AuthCodeURL(context.Background(), "id", localhost, tokenScope, opts...)
if err != nil {
t.Fatal(err)
}
Expand Down
20 changes: 1 addition & 19 deletions apps/internal/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,26 +280,8 @@ func (b Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, s
}

func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilentParameters) (AuthResult, error) {
// when tenant == "", the caller didn't specify a tenant and WithTenant will use the client's configured tenant
tenant := silent.TenantID
if tenant == "" {
// the caller didn't specify a tenant, so we'll use the client's configured tenant or the given account's home tenant
switch tenant = b.AuthParams.AuthorityInfo.Tenant; tenant {
case "common", "organizations":
if _, homeTenant, found := strings.Cut(silent.Account.HomeAccountID, "."); found {
// note that both public and confidential clients allow specifying an account for silent auth
tenant = homeTenant
} else if !b.AuthParams.IsConfidentialClient {
// public client requires the caller to identify a specific user for silent authentication
return AuthResult{}, errors.New("use the WithSilentAccount option to specify an account")
}
// else we have a confidential client and no account specified. We can't return an error here because
// the caller may have configured the client with a custom token provider, in which case the client
// handles caching and the token provider is responsible for everything else, including determining
// the correct tenant.
default:
// use the client's configured tenant
}
}
authParams, err := b.AuthParams.WithTenant(tenant)
if err != nil {
return AuthResult{}, err
Expand Down
1 change: 0 additions & 1 deletion apps/internal/base/internal/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ func (m *Manager) Write(authParameters authority.AuthParams, tokenResponse acces
realm := authParameters.AuthorityInfo.Tenant
clientID := authParameters.ClientID
target := strings.Join(tokenResponse.GrantedScopes.Slice, scopeSeparator)

cachedAt := time.Now()

var account shared.Account
Expand Down
2 changes: 1 addition & 1 deletion apps/internal/mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo string, e

func GetIDToken(tenant, issuer string) string {
now := time.Now().Unix()
payload := []byte(fmt.Sprintf(`{"aud": "%s","exp": %d,"iat": %d,"iss": "%s"}`, tenant, now+3600, now, issuer))
payload := []byte(fmt.Sprintf(`{"aud": "%s","exp": %d,"iat": %d,"iss": "%s","tid": "%s"}`, tenant, now+3600, now, issuer, tenant))
return fmt.Sprintf("header.%s.signature", base64.RawStdEncoding.EncodeToString(payload))
}

Expand Down
2 changes: 1 addition & 1 deletion apps/internal/version/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
package version

// Version is the version of this client package that is communicated to the server.
const Version = "0.8.0"
const Version = "0.8.1"
47 changes: 42 additions & 5 deletions apps/public/public_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"github.com/kylelemons/godebug/pretty"
)

const authorityFmt = "https://%s/%s"

var tokenScope = []string{"the_scope"}

func fakeBrowserOpenURL(authURL string) error {
Expand Down Expand Up @@ -77,12 +79,47 @@ func TestAcquireTokenInteractive(t *testing.T) {
}
}

func TestAcquireTokenSilentTenants(t *testing.T) {
func TestAcquireTokenSilentHomeTenantAliases(t *testing.T) {
accessToken := "*"
homeTenant := "home-tenant"
clientInfo := base64.RawStdEncoding.EncodeToString([]byte(
fmt.Sprintf(`{"uid":"uid","utid":"%s"}`, homeTenant),
))
lmo := "login.microsoftonline.com"
for _, alias := range []string{"common", "organizations"} {
mockClient := mock.Client{}
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, alias)))
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(homeTenant, fmt.Sprintf(authorityFmt, lmo, homeTenant)), "rt", clientInfo, 3600)))
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, homeTenant)))
client, err := New("client-id", WithAuthority(fmt.Sprintf(authorityFmt, lmo, alias)), WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}
// the auth flow isn't important, we just need to populate the cache
ar, err := client.AcquireTokenByAuthCode(context.Background(), "code", "https://localhost", tokenScope)
if err != nil {
t.Fatal(err)
}
if ar.AccessToken != accessToken {
t.Fatalf("expected %q, got %q", accessToken, ar.AccessToken)
}
account := ar.Account
ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account))
if err != nil {
t.Fatal(err)
}
if ar.AccessToken != accessToken {
t.Fatalf("expected %q, got %q", accessToken, ar.AccessToken)
}
}
}

func TestAcquireTokenSilentWithTenantID(t *testing.T) {
tenantA, tenantB := "a", "b"
lmo := "login.microsoftonline.com"
mockClient := mock.Client{}
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenantA)))
client, err := New("client-id", WithAuthority(fmt.Sprintf("https://%s/%s", lmo, tenantA)), WithHTTPClient(&mockClient))
client, err := New("client-id", WithAuthority(fmt.Sprintf(authorityFmt, lmo, tenantA)), WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}
Expand All @@ -96,7 +133,7 @@ func TestAcquireTokenSilentTenants(t *testing.T) {
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant)))
mockClient.AppendResponse(mock.WithBody([]byte(`{"account_type":"Managed","cloud_audience_urn":"urn","cloud_instance_name":"...","domain_name":"..."}`)))
mockClient.AppendResponse(mock.WithBody(
mock.GetAccessTokenBody(tenant, mock.GetIDToken(tenant, fmt.Sprintf("https://%s/%s", lmo, tenant)), "rt-"+tenant, clientInfo, 3600)),
mock.GetAccessTokenBody(tenant, mock.GetIDToken(tenant, fmt.Sprintf(authorityFmt, lmo, tenant)), "rt-"+tenant, clientInfo, 3600)),
)
ar, err := client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithTenantID(tenant))
if err != nil {
Expand Down Expand Up @@ -344,7 +381,7 @@ func TestWithCache(t *testing.T) {
clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`))
lmo := "login.microsoftonline.com"
tenantA, tenantB := "a", "b"
authorityA, authorityB := fmt.Sprintf("https://%s/%s", lmo, tenantA), fmt.Sprintf("https://%s/%s", lmo, tenantB)
authorityA, authorityB := fmt.Sprintf(authorityFmt, lmo, tenantA), fmt.Sprintf(authorityFmt, lmo, tenantB)
mockClient := mock.Client{}
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenantA)))
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenantA, authorityA), refreshToken, clientInfo, 3600)))
Expand Down Expand Up @@ -410,7 +447,7 @@ func TestWithClaims(t *testing.T) {

clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`))
lmo, tenant := "login.microsoftonline.com", "tenant"
authority := fmt.Sprintf("https://%s/%s", lmo, tenant)
authority := fmt.Sprintf(authorityFmt, lmo, tenant)
accessToken, idToken, refreshToken := "at", mock.GetIDToken(tenant, lmo), "rt"
for _, test := range []struct {
capabilities []string
Expand Down

0 comments on commit 24a6783

Please sign in to comment.