Skip to content

Commit

Permalink
token: support for a custom refresh func
Browse files Browse the repository at this point in the history
  • Loading branch information
tombuildsstuff committed Oct 20, 2019
1 parent 0b055be commit 6f1862b
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
29 changes: 25 additions & 4 deletions autorest/adal/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ type RefresherWithContext interface {
// a successful token refresh
type TokenRefreshCallback func(Token) error

// TokenRefresh is a type representing a custom callback to refresh a token
type TokenRefresh func(ctx context.Context, resource string) (*Token, error)

// Token encapsulates the access token used to authorize Azure requests.
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response
type Token struct {
Expand Down Expand Up @@ -344,10 +347,11 @@ func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, err

// ServicePrincipalToken encapsulates a Token created for a Service Principal.
type ServicePrincipalToken struct {
inner servicePrincipalToken
refreshLock *sync.RWMutex
sender Sender
refreshCallbacks []TokenRefreshCallback
inner servicePrincipalToken
refreshLock *sync.RWMutex
sender Sender
customRefreshFunc *TokenRefresh
refreshCallbacks []TokenRefreshCallback
// MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token.
MaxMSIRefreshAttempts int
}
Expand All @@ -362,6 +366,11 @@ func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCa
spt.refreshCallbacks = callbacks
}

// SetCustomRefreshFunc sets a custom refresh function used to refresh the token.
func (spt *ServicePrincipalToken) SetCustomRefreshFunc(customRefreshFunc *TokenRefresh) {
spt.customRefreshFunc = customRefreshFunc
}

// MarshalJSON implements the json.Marshaler interface.
func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) {
return json.Marshal(spt.inner)
Expand Down Expand Up @@ -833,6 +842,18 @@ func isIMDS(u url.URL) bool {
}

func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
if spt.customRefreshFunc != nil {
f := *spt.customRefreshFunc
token, err := f(ctx, resource)
if err != nil {
return err
}

spt.inner.Token = *token

return spt.InvokeRefreshCallbacks(*token)
}

req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
if err != nil {
return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
Expand Down
18 changes: 18 additions & 0 deletions autorest/adal/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,24 @@ func TestServicePrincipalTokenSetAutoRefresh(t *testing.T) {
}
}

func TestServicePrincipalTokenSetCustomRefreshFunc(t *testing.T) {
spt := newServicePrincipalToken()

var refreshFunc TokenRefresh = func(context context.Context, resource string) (*Token, error) {
return nil, nil
}

if spt.customRefreshFunc != nil {
t.Fatalf("adal: ServicePrincipalToken#SetCustomRefreshFunc had a default custom refresh func when it shouldn't")
}

spt.SetCustomRefreshFunc(&refreshFunc)

if spt.customRefreshFunc == nil {
t.Fatalf("adal: ServicePrincipalToken#SetCustomRefreshFunc didn't have a refresh func")
}
}

func TestServicePrincipalTokenSetRefreshWithin(t *testing.T) {
spt := newServicePrincipalToken()

Expand Down

0 comments on commit 6f1862b

Please sign in to comment.