Skip to content

Commit

Permalink
AzureCLICredential prefers expires_on value (#22299)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Jan 26, 2024
1 parent 6e01b90 commit aeb9fa4
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 42 deletions.
1 change: 1 addition & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
authentication in a Docker Desktop container

### Other Changes
* `AzureCLICredential` uses the CLI's `expires_on` value for token expiration

## 1.6.0-beta.1 (2024-01-17)

Expand Down
23 changes: 9 additions & 14 deletions sdk/azidentity/azure_cli_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,26 +163,21 @@ var defaultAzTokenProvider azTokenProvider = func(ctx context.Context, scopes []

func (c *AzureCLICredential) createAccessToken(tk []byte) (azcore.AccessToken, error) {
t := struct {
AccessToken string `json:"accessToken"`
Authority string `json:"_authority"`
ClientID string `json:"_clientId"`
ExpiresOn string `json:"expiresOn"`
IdentityProvider string `json:"identityProvider"`
IsMRRT bool `json:"isMRRT"`
RefreshToken string `json:"refreshToken"`
Resource string `json:"resource"`
TokenType string `json:"tokenType"`
UserID string `json:"userId"`
AccessToken string `json:"accessToken"`
Expires_On int64 `json:"expires_on"`
ExpiresOn string `json:"expiresOn"`
}{}
err := json.Unmarshal(tk, &t)
if err != nil {
return azcore.AccessToken{}, err
}

// the Azure CLI's "expiresOn" is local time
exp, err := time.ParseInLocation("2006-01-02 15:04:05.999999", t.ExpiresOn, time.Local)
if err != nil {
return azcore.AccessToken{}, fmt.Errorf("Error parsing token expiration time %q: %v", t.ExpiresOn, err)
exp := time.Unix(t.Expires_On, 0)
if t.Expires_On == 0 {
exp, err = time.ParseInLocation("2006-01-02 15:04:05.999999", t.ExpiresOn, time.Local)
if err != nil {
return azcore.AccessToken{}, fmt.Errorf("%s: error parsing token expiration time %q: %v", credNameAzureCLI, t.ExpiresOn, err)
}
}

converted := azcore.AccessToken{
Expand Down
82 changes: 54 additions & 28 deletions sdk/azidentity/azure_cli_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,37 @@ import (
"fmt"
"testing"
"time"

"github.com/stretchr/testify/require"
)

var (
mockAzTokenProviderSuccess = func(ctx context.Context, scopes []string, tenant, subscription string) ([]byte, error) {
return []byte(fmt.Sprintf(`{
"accessToken": "mocktoken",
"expiresOn": "2001-02-03 04:05:06.000007",
"subscription": %q,
// azTokenOutput returns JSON output similar to az account get-access-token.
// All versions of az return expiresOn, a local timestamp. v2.54.0+
// additionally return expires_on, a Unix timestamp. If the expires_on
// argument to this function is 0, the returned JSON omits expires_on.
func azTokenOutput(expiresOn string, expires_on int64) []byte {
e_o := ""
if expires_on != 0 {
e_o = fmt.Sprintf(`
"expires_on": %d,
`, expires_on)
}
return []byte(fmt.Sprintf(`{
"accessToken": %q,
"expiresOn": %q,%s
"subscription": "fake-subscription",
"tenant": %q,
"tokenType": "Bearer"
}`, tokenValue, expiresOn, e_o, fakeTenantID))
}

func mockAzTokenProviderFailure(context.Context, []string, string, string) ([]byte, error) {
return nil, newAuthenticationFailedError(credNameAzureCLI, "mock provider error", nil, nil)
}

func mockAzTokenProviderSuccess(ctx context.Context, scopes []string, tenant, subscription string) ([]byte, error) {
return azTokenOutput("2001-02-03 04:05:06.000007", 0), nil
}
`, subscription, tenant)), nil
}
mockAzTokenProviderFailure = func(context.Context, []string, string, string) ([]byte, error) {
return nil, newAuthenticationFailedError(credNameAzureCLI, "mock provider error", nil, nil)
}
)

func TestAzureCLICredential_DefaultChainError(t *testing.T) {
cred, err := NewAzureCLICredential(&AzureCLICredentialOptions{
Expand Down Expand Up @@ -72,22 +86,34 @@ func TestAzureCLICredential_Error(t *testing.T) {
}

func TestAzureCLICredential_GetTokenSuccess(t *testing.T) {
options := AzureCLICredentialOptions{}
options.tokenProvider = mockAzTokenProviderSuccess
cred, err := NewAzureCLICredential(&options)
if err != nil {
t.Fatal(err)
}
at, err := cred.GetToken(context.Background(), testTRO)
if err != nil {
t.Fatal(err)
}
if at.Token != "mocktoken" {
t.Fatalf("unexpected access token %q", at.Token)
}
expected := time.Date(2001, 2, 3, 4, 5, 6, 7000, time.Local).UTC()
if actual := at.ExpiresOn; !actual.Equal(expected) || actual.Location() != time.UTC {
t.Fatalf("expected %q, got %q", expected, actual)
expectedExpiresOn := time.Now().Add(time.Hour).UTC().Truncate(time.Second)
for _, withExpires_on := range []bool{false, true} {
name := "without expires_on"
if withExpires_on {
name = "with expires_on"
}
t.Run(name, func(t *testing.T) {
ExpiresOn := expectedExpiresOn.Local().Format("2006-01-02 15:04:05.999999999")
expires_on := int64(0)
if withExpires_on {
// set the wrong time for ExpiresOn so this test fails if the credential uses it
ExpiresOn = "2001-01-01 01:01:01.000000"
expires_on = expectedExpiresOn.Unix()
}
cred, err := NewAzureCLICredential(&AzureCLICredentialOptions{
tokenProvider: func(context.Context, []string, string, string) ([]byte, error) {
output := azTokenOutput(ExpiresOn, expires_on)
return output, nil
},
})
require.NoError(t, err)

actual, err := cred.GetToken(context.Background(), testTRO)
require.NoError(t, err)
require.True(t, actual.ExpiresOn.Equal(expectedExpiresOn))
require.Equal(t, time.UTC, actual.ExpiresOn.Location())
require.Equal(t, tokenValue, actual.Token)
})
}
}

Expand Down

0 comments on commit aeb9fa4

Please sign in to comment.