diff --git a/autorest/adal/token.go b/autorest/adal/token.go index 7c7fca371..780f514e0 100644 --- a/autorest/adal/token.go +++ b/autorest/adal/token.go @@ -106,6 +106,9 @@ type RefresherWithContext interface { // a successful token refresh type TokenRefreshCallback func(Token) error +// TokenRefresh is a type representing a custom callback to refresh a token +type TokenRefresh func(ctx context.Context, resource string) (*Token, error) + // Token encapsulates the access token used to authorize Azure requests. // https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response type Token struct { @@ -344,10 +347,11 @@ func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, err // ServicePrincipalToken encapsulates a Token created for a Service Principal. type ServicePrincipalToken struct { - inner servicePrincipalToken - refreshLock *sync.RWMutex - sender Sender - refreshCallbacks []TokenRefreshCallback + inner servicePrincipalToken + refreshLock *sync.RWMutex + sender Sender + customRefreshFunc *TokenRefresh + refreshCallbacks []TokenRefreshCallback // MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token. MaxMSIRefreshAttempts int } @@ -362,6 +366,11 @@ func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCa spt.refreshCallbacks = callbacks } +// SetCustomRefreshFunc sets a custom refresh function used to refresh the token. +func (spt *ServicePrincipalToken) SetCustomRefreshFunc(customRefreshFunc *TokenRefresh) { + spt.customRefreshFunc = customRefreshFunc +} + // MarshalJSON implements the json.Marshaler interface. func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) { return json.Marshal(spt.inner) @@ -833,6 +842,18 @@ func isIMDS(u url.URL) bool { } func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error { + if spt.customRefreshFunc != nil { + f := *spt.customRefreshFunc + token, err := f(ctx, resource) + if err != nil { + return err + } + + spt.inner.Token = *token + + return spt.InvokeRefreshCallbacks(*token) + } + req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil) if err != nil { return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err) diff --git a/autorest/adal/token_test.go b/autorest/adal/token_test.go index d2cb53f69..736159bc5 100644 --- a/autorest/adal/token_test.go +++ b/autorest/adal/token_test.go @@ -100,6 +100,24 @@ func TestServicePrincipalTokenSetAutoRefresh(t *testing.T) { } } +func TestServicePrincipalTokenSetCustomRefreshFunc(t *testing.T) { + spt := newServicePrincipalToken() + + var refreshFunc TokenRefresh = func(context context.Context, resource string) (*Token, error) { + return nil, nil + } + + if spt.customRefreshFunc != nil { + t.Fatalf("adal: ServicePrincipalToken#SetCustomRefreshFunc had a default custom refresh func when it shouldn't") + } + + spt.SetCustomRefreshFunc(&refreshFunc) + + if spt.customRefreshFunc == nil { + t.Fatalf("adal: ServicePrincipalToken#SetCustomRefreshFunc didn't have a refresh func") + } +} + func TestServicePrincipalTokenSetRefreshWithin(t *testing.T) { spt := newServicePrincipalToken() @@ -123,6 +141,26 @@ func TestServicePrincipalTokenSetSender(t *testing.T) { } } +func TestServicePrincipalTokenRefreshUsesCustomRefreshFunc(t *testing.T) { + spt := newServicePrincipalToken() + + called := false + var refreshFunc TokenRefresh = func(context context.Context, resource string) (*Token, error) { + called = true + return &Token{}, nil + } + spt.SetCustomRefreshFunc(&refreshFunc) + if called { + t.Fatalf("adal: ServicePrincipalToken#refreshInternal called the refresh function prior to refreshing") + } + + spt.refreshInternal(context.Background(), "https://example.com") + + if !called { + t.Fatalf("adal: ServicePrincipalToken#refreshInternal didn't call the refresh function") + } +} + func TestServicePrincipalTokenRefreshUsesPOST(t *testing.T) { spt := newServicePrincipalToken()