diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index 98f29d2e33af..e960660d69a8 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -12,6 +12,7 @@ authentication in a Docker Desktop container ### Other Changes +* `AzureCLICredential` uses the CLI's `expires_on` value for token expiration ## 1.6.0-beta.1 (2024-01-17) diff --git a/sdk/azidentity/azure_cli_credential.go b/sdk/azidentity/azure_cli_credential.go index 498c3586bc8f..b9976f5fedee 100644 --- a/sdk/azidentity/azure_cli_credential.go +++ b/sdk/azidentity/azure_cli_credential.go @@ -163,26 +163,21 @@ var defaultAzTokenProvider azTokenProvider = func(ctx context.Context, scopes [] func (c *AzureCLICredential) createAccessToken(tk []byte) (azcore.AccessToken, error) { t := struct { - AccessToken string `json:"accessToken"` - Authority string `json:"_authority"` - ClientID string `json:"_clientId"` - ExpiresOn string `json:"expiresOn"` - IdentityProvider string `json:"identityProvider"` - IsMRRT bool `json:"isMRRT"` - RefreshToken string `json:"refreshToken"` - Resource string `json:"resource"` - TokenType string `json:"tokenType"` - UserID string `json:"userId"` + AccessToken string `json:"accessToken"` + Expires_On int64 `json:"expires_on"` + ExpiresOn string `json:"expiresOn"` }{} err := json.Unmarshal(tk, &t) if err != nil { return azcore.AccessToken{}, err } - // the Azure CLI's "expiresOn" is local time - exp, err := time.ParseInLocation("2006-01-02 15:04:05.999999", t.ExpiresOn, time.Local) - if err != nil { - return azcore.AccessToken{}, fmt.Errorf("Error parsing token expiration time %q: %v", t.ExpiresOn, err) + exp := time.Unix(t.Expires_On, 0) + if t.Expires_On == 0 { + exp, err = time.ParseInLocation("2006-01-02 15:04:05.999999", t.ExpiresOn, time.Local) + if err != nil { + return azcore.AccessToken{}, fmt.Errorf("%s: error parsing token expiration time %q: %v", credNameAzureCLI, t.ExpiresOn, err) + } } converted := azcore.AccessToken{ diff --git a/sdk/azidentity/azure_cli_credential_test.go b/sdk/azidentity/azure_cli_credential_test.go index feae928575ad..82f09735246c 100644 --- a/sdk/azidentity/azure_cli_credential_test.go +++ b/sdk/azidentity/azure_cli_credential_test.go @@ -12,23 +12,37 @@ import ( "fmt" "testing" "time" + + "github.com/stretchr/testify/require" ) -var ( - mockAzTokenProviderSuccess = func(ctx context.Context, scopes []string, tenant, subscription string) ([]byte, error) { - return []byte(fmt.Sprintf(`{ - "accessToken": "mocktoken", - "expiresOn": "2001-02-03 04:05:06.000007", - "subscription": %q, +// azTokenOutput returns JSON output similar to az account get-access-token. +// All versions of az return expiresOn, a local timestamp. v2.54.0+ +// additionally return expires_on, a Unix timestamp. If the expires_on +// argument to this function is 0, the returned JSON omits expires_on. +func azTokenOutput(expiresOn string, expires_on int64) []byte { + e_o := "" + if expires_on != 0 { + e_o = fmt.Sprintf(` + "expires_on": %d, +`, expires_on) + } + return []byte(fmt.Sprintf(`{ + "accessToken": %q, + "expiresOn": %q,%s + "subscription": "fake-subscription", "tenant": %q, "tokenType": "Bearer" +}`, tokenValue, expiresOn, e_o, fakeTenantID)) +} + +func mockAzTokenProviderFailure(context.Context, []string, string, string) ([]byte, error) { + return nil, newAuthenticationFailedError(credNameAzureCLI, "mock provider error", nil, nil) +} + +func mockAzTokenProviderSuccess(ctx context.Context, scopes []string, tenant, subscription string) ([]byte, error) { + return azTokenOutput("2001-02-03 04:05:06.000007", 0), nil } -`, subscription, tenant)), nil - } - mockAzTokenProviderFailure = func(context.Context, []string, string, string) ([]byte, error) { - return nil, newAuthenticationFailedError(credNameAzureCLI, "mock provider error", nil, nil) - } -) func TestAzureCLICredential_DefaultChainError(t *testing.T) { cred, err := NewAzureCLICredential(&AzureCLICredentialOptions{ @@ -72,22 +86,34 @@ func TestAzureCLICredential_Error(t *testing.T) { } func TestAzureCLICredential_GetTokenSuccess(t *testing.T) { - options := AzureCLICredentialOptions{} - options.tokenProvider = mockAzTokenProviderSuccess - cred, err := NewAzureCLICredential(&options) - if err != nil { - t.Fatal(err) - } - at, err := cred.GetToken(context.Background(), testTRO) - if err != nil { - t.Fatal(err) - } - if at.Token != "mocktoken" { - t.Fatalf("unexpected access token %q", at.Token) - } - expected := time.Date(2001, 2, 3, 4, 5, 6, 7000, time.Local).UTC() - if actual := at.ExpiresOn; !actual.Equal(expected) || actual.Location() != time.UTC { - t.Fatalf("expected %q, got %q", expected, actual) + expectedExpiresOn := time.Now().Add(time.Hour).UTC().Truncate(time.Second) + for _, withExpires_on := range []bool{false, true} { + name := "without expires_on" + if withExpires_on { + name = "with expires_on" + } + t.Run(name, func(t *testing.T) { + ExpiresOn := expectedExpiresOn.Local().Format("2006-01-02 15:04:05.999999999") + expires_on := int64(0) + if withExpires_on { + // set the wrong time for ExpiresOn so this test fails if the credential uses it + ExpiresOn = "2001-01-01 01:01:01.000000" + expires_on = expectedExpiresOn.Unix() + } + cred, err := NewAzureCLICredential(&AzureCLICredentialOptions{ + tokenProvider: func(context.Context, []string, string, string) ([]byte, error) { + output := azTokenOutput(ExpiresOn, expires_on) + return output, nil + }, + }) + require.NoError(t, err) + + actual, err := cred.GetToken(context.Background(), testTRO) + require.NoError(t, err) + require.True(t, actual.ExpiresOn.Equal(expectedExpiresOn)) + require.Equal(t, time.UTC, actual.ExpiresOn.Location()) + require.Equal(t, tokenValue, actual.Token) + }) } }