Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

token: support for a custom refresh func #476

Merged
merged 3 commits into from
Oct 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions autorest/adal/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ type RefresherWithContext interface {
// a successful token refresh
type TokenRefreshCallback func(Token) error

// TokenRefresh is a type representing a custom callback to refresh a token
type TokenRefresh func(ctx context.Context, resource string) (*Token, error)

// Token encapsulates the access token used to authorize Azure requests.
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response
type Token struct {
Expand Down Expand Up @@ -344,10 +347,11 @@ func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, err

// ServicePrincipalToken encapsulates a Token created for a Service Principal.
type ServicePrincipalToken struct {
inner servicePrincipalToken
refreshLock *sync.RWMutex
sender Sender
refreshCallbacks []TokenRefreshCallback
inner servicePrincipalToken
refreshLock *sync.RWMutex
sender Sender
customRefreshFunc TokenRefresh
refreshCallbacks []TokenRefreshCallback
// MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token.
MaxMSIRefreshAttempts int
}
Expand All @@ -362,6 +366,11 @@ func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCa
spt.refreshCallbacks = callbacks
}

// SetCustomRefreshFunc sets a custom refresh function used to refresh the token.
func (spt *ServicePrincipalToken) SetCustomRefreshFunc(customRefreshFunc TokenRefresh) {
spt.customRefreshFunc = customRefreshFunc
}

// MarshalJSON implements the json.Marshaler interface.
func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) {
return json.Marshal(spt.inner)
Expand Down Expand Up @@ -833,6 +842,15 @@ func isIMDS(u url.URL) bool {
}

func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error {
if spt.customRefreshFunc != nil {
token, err := spt.customRefreshFunc(ctx, resource)
if err != nil {
return err
}
spt.inner.Token = *token
return spt.InvokeRefreshCallbacks(spt.inner.Token)
}

req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil)
if err != nil {
return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err)
Expand Down
38 changes: 38 additions & 0 deletions autorest/adal/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,24 @@ func TestServicePrincipalTokenSetAutoRefresh(t *testing.T) {
}
}

func TestServicePrincipalTokenSetCustomRefreshFunc(t *testing.T) {
spt := newServicePrincipalToken()

var refreshFunc TokenRefresh = func(context context.Context, resource string) (*Token, error) {
return nil, nil
}

if spt.customRefreshFunc != nil {
t.Fatalf("adal: ServicePrincipalToken#SetCustomRefreshFunc had a default custom refresh func when it shouldn't")
}

spt.SetCustomRefreshFunc(refreshFunc)

if spt.customRefreshFunc == nil {
t.Fatalf("adal: ServicePrincipalToken#SetCustomRefreshFunc didn't have a refresh func")
}
}

func TestServicePrincipalTokenSetRefreshWithin(t *testing.T) {
spt := newServicePrincipalToken()

Expand All @@ -123,6 +141,26 @@ func TestServicePrincipalTokenSetSender(t *testing.T) {
}
}

func TestServicePrincipalTokenRefreshUsesCustomRefreshFunc(t *testing.T) {
spt := newServicePrincipalToken()

called := false
var refreshFunc TokenRefresh = func(context context.Context, resource string) (*Token, error) {
called = true
return &Token{}, nil
}
spt.SetCustomRefreshFunc(refreshFunc)
if called {
t.Fatalf("adal: ServicePrincipalToken#refreshInternal called the refresh function prior to refreshing")
}

spt.refreshInternal(context.Background(), "https://example.com")

if !called {
t.Fatalf("adal: ServicePrincipalToken#refreshInternal didn't call the refresh function")
}
}

func TestServicePrincipalTokenRefreshUsesPOST(t *testing.T) {
spt := newServicePrincipalToken()

Expand Down