Skip to content

Commit

Permalink
token: support for a custom refresh func (#476)
Browse files Browse the repository at this point in the history
* token: support for a custom refresh func

* pass closures by value

* minor clean-up
  • Loading branch information
tombuildsstuff authored and jhendrixMSFT committed Oct 23, 2019
1 parent 5f1f2ad commit 7820109
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 4 deletions.
26 changes: 22 additions & 4 deletions autorest/adal/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ type RefresherWithContext interface {
// a successful token refresh
type TokenRefreshCallback func(Token) error

// TokenRefresh is a type representing a custom callback to refresh a token
type TokenRefresh func(ctx context.Context, resource string) (*Token, error)

// Token encapsulates the access token used to authorize Azure requests.
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response
type Token struct {
Expand Down Expand Up @@ -344,10 +347,11 @@ func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, err

// ServicePrincipalToken encapsulates a Token created for a Service Principal.
type ServicePrincipalToken struct {
inner servicePrincipalToken
refreshLock *sync.RWMutex
sender Sender
refreshCallbacks []TokenRefreshCallback
inner servicePrincipalToken
refreshLock *sync.RWMutex
sender Sender
customRefreshFunc TokenRefresh
refreshCallbacks []TokenRefreshCallback
// MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token.
MaxMSIRefreshAttempts int
}
Expand All @@ -362,6 +366,11 @@ func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCa
spt.refreshCallbacks = callbacks
}

// SetCustomRefreshFunc sets a custom refresh function used to refresh the token.
func (spt *ServicePrincipalToken) SetCustomRefreshFunc(customRefreshFunc TokenRefresh) {
spt.customRefreshFunc = customRefreshFunc
}

// MarshalJSON implements the json.Marshaler interface.
func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) {
return json.Marshal(spt.inner)
Expand Down Expand Up @@ -833,6 +842,15 @@ func isIMDS(u url.URL) bool {
}

func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
if spt.customRefreshFunc != nil {
token, err := spt.customRefreshFunc(ctx, resource)
if err != nil {
return err
}
spt.inner.Token = *token
return spt.InvokeRefreshCallbacks(spt.inner.Token)
}

req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
if err != nil {
return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
Expand Down
38 changes: 38 additions & 0 deletions autorest/adal/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,24 @@ func TestServicePrincipalTokenSetAutoRefresh(t *testing.T) {
}
}

func TestServicePrincipalTokenSetCustomRefreshFunc(t *testing.T) {
spt := newServicePrincipalToken()

var refreshFunc TokenRefresh = func(context context.Context, resource string) (*Token, error) {
return nil, nil
}

if spt.customRefreshFunc != nil {
t.Fatalf("adal: ServicePrincipalToken#SetCustomRefreshFunc had a default custom refresh func when it shouldn't")
}

spt.SetCustomRefreshFunc(refreshFunc)

if spt.customRefreshFunc == nil {
t.Fatalf("adal: ServicePrincipalToken#SetCustomRefreshFunc didn't have a refresh func")
}
}

func TestServicePrincipalTokenSetRefreshWithin(t *testing.T) {
spt := newServicePrincipalToken()

Expand All @@ -123,6 +141,26 @@ func TestServicePrincipalTokenSetSender(t *testing.T) {
}
}

func TestServicePrincipalTokenRefreshUsesCustomRefreshFunc(t *testing.T) {
spt := newServicePrincipalToken()

called := false
var refreshFunc TokenRefresh = func(context context.Context, resource string) (*Token, error) {
called = true
return &Token{}, nil
}
spt.SetCustomRefreshFunc(refreshFunc)
if called {
t.Fatalf("adal: ServicePrincipalToken#refreshInternal called the refresh function prior to refreshing")
}

spt.refreshInternal(context.Background(), "https://example.com")

if !called {
t.Fatalf("adal: ServicePrincipalToken#refreshInternal didn't call the refresh function")
}
}

func TestServicePrincipalTokenRefreshUsesPOST(t *testing.T) {
spt := newServicePrincipalToken()

Expand Down

0 comments on commit 7820109

Please sign in to comment.