diff --git a/go.mod b/go.mod index 948e959b..031ecc70 100644 --- a/go.mod +++ b/go.mod @@ -10,12 +10,12 @@ require ( github.com/AzureAD/microsoft-authentication-library-for-go v1.2.0 github.com/golang-jwt/jwt/v4 v4.5.0 github.com/golang-jwt/jwt/v5 v5.0.0 - github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.4.0 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.4 + go.uber.org/mock v0.3.0 golang.org/x/crypto v0.14.0 gopkg.in/dnaeon/go-vcr.v3 v3.1.2 gopkg.in/retry.v1 v1.0.3 diff --git a/go.sum b/go.sum index 9a123c82..bf8dd9a6 100644 --- a/go.sum +++ b/go.sum @@ -63,8 +63,6 @@ github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJ github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= -github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -169,10 +167,11 @@ github.com/xlab/treeprint v1.2.0 h1:HzHnuAF1plUN2zGlAFHbSQP2qJ0ZAD3XF5XD7OesXRQ= github.com/xlab/treeprint v1.2.0/go.mod h1:gj5Gd3gPdKtR1ikdDK6fnFLdmIS0X30kTTuNd/WEJu0= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.starlark.net v0.0.0-20230525235612-a134d8f9ddca h1:VdD38733bfYv5tUZwEIskMM93VanwNIi5bIKnDrJdEY= go.starlark.net v0.0.0-20230525235612-a134d8f9ddca/go.mod h1:jxU+3+j+71eXOW14274+SmmuW82qJzl6iZSeqEtTGds= +go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo= +go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -187,7 +186,6 @@ golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvx golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -199,7 +197,6 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= @@ -213,7 +210,6 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -222,9 +218,7 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -257,7 +251,6 @@ golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.8.0 h1:vSDcovVPld282ceKgDimkRSC8kpaH1dgyc9UMzlt84Y= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pkg/cmd/token.go b/pkg/cmd/token.go index cd407bc0..4ce28511 100644 --- a/pkg/cmd/token.go +++ b/pkg/cmd/token.go @@ -1,6 +1,10 @@ package cmd import ( + "context" + "os" + "os/signal" + "github.com/Azure/kubelogin/pkg/internal/token" "github.com/spf13/cobra" ) @@ -16,6 +20,10 @@ func newTokenCmd() *cobra.Command { RunE: func(c *cobra.Command, args []string) error { o.UpdateFromEnv() + ctx := context.Background() + ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) + defer cancel() + if err := o.Validate(); err != nil { return err } @@ -24,7 +32,7 @@ func newTokenCmd() *cobra.Command { if err != nil { return err } - if err := plugin.Do(); err != nil { + if err := plugin.Do(ctx); err != nil { return err } return nil diff --git a/pkg/internal/token/azurecli.go b/pkg/internal/token/azurecli.go index caf64827..e661629f 100644 --- a/pkg/internal/token/azurecli.go +++ b/pkg/internal/token/azurecli.go @@ -40,7 +40,7 @@ func newAzureCLIToken(resourceID string, tenantID string, timeout time.Duration) } // Token fetches an azcore.AccessToken from the Azure CLI SDK and converts it to an adal.Token for use with kubelogin. -func (p *AzureCLIToken) Token() (adal.Token, error) { +func (p *AzureCLIToken) Token(ctx context.Context) (adal.Token, error) { emptyToken := adal.Token{} // Request a new Azure CLI token provider @@ -51,7 +51,7 @@ func (p *AzureCLIToken) Token() (adal.Token, error) { return emptyToken, fmt.Errorf("unable to create credential. Received: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), p.timeout) + ctx, cancel := context.WithTimeout(ctx, p.timeout) defer cancel() // Use the token provider to get a new token with the new context diff --git a/pkg/internal/token/azurecli_test.go b/pkg/internal/token/azurecli_test.go index f92707cb..8536b8a1 100644 --- a/pkg/internal/token/azurecli_test.go +++ b/pkg/internal/token/azurecli_test.go @@ -1,6 +1,7 @@ package token import ( + "context" "testing" "github.com/Azure/kubelogin/pkg/internal/testutils" @@ -17,7 +18,7 @@ func TestNewAzureCLITokenEmpty(t *testing.T) { func TestNewAzureCLIToken(t *testing.T) { azcli := AzureCLIToken{} - _, err := azcli.Token() + _, err := azcli.Token(context.TODO()) if !testutils.ErrorContains(err, "expected an empty error but received:") { t.Errorf("unexpected error: %v", err) diff --git a/pkg/internal/token/devicecode.go b/pkg/internal/token/devicecode.go index a5c0c5a5..af42f42f 100644 --- a/pkg/internal/token/devicecode.go +++ b/pkg/internal/token/devicecode.go @@ -1,6 +1,7 @@ package token import ( + "context" "errors" "fmt" "os" @@ -35,7 +36,7 @@ func newDeviceCodeTokenProvider(oAuthConfig adal.OAuthConfig, clientID, resource }, nil } -func (p *deviceCodeTokenProvider) Token() (adal.Token, error) { +func (p *deviceCodeTokenProvider) Token(ctx context.Context) (adal.Token, error) { emptyToken := adal.Token{} client := &autorest.Client{} deviceCode, err := adal.InitiateDeviceAuth(client, p.oAuthConfig, p.clientID, p.resourceID) @@ -48,7 +49,7 @@ func (p *deviceCodeTokenProvider) Token() (adal.Token, error) { return emptyToken, fmt.Errorf("prompting the device code message: %s", err) } - token, err := adal.WaitForUserCompletion(client, deviceCode) + token, err := adal.WaitForUserCompletionWithContext(ctx, client, deviceCode) if err != nil { return emptyToken, fmt.Errorf("waiting for device code authentication to complete: %s", err) } diff --git a/pkg/internal/token/devicecode_test.go b/pkg/internal/token/devicecode_test.go index 9f4505dc..812b004e 100644 --- a/pkg/internal/token/devicecode_test.go +++ b/pkg/internal/token/devicecode_test.go @@ -1,6 +1,7 @@ package token import ( + "context" "fmt" "strings" "testing" @@ -50,7 +51,7 @@ func TestNewDeviceCodeTokenProviderEmpty(t *testing.T) { func TestNewDeviceCodeToken(t *testing.T) { deviceCode := deviceCodeTokenProvider{} - _, err := deviceCode.Token() + _, err := deviceCode.Token(context.TODO()) if !testutils.ErrorContains(err, "initialing the device code authentication:") { t.Errorf("unexpected error: %v", err) diff --git a/pkg/internal/token/execCredentialPlugin.go b/pkg/internal/token/execCredentialPlugin.go index 4490ed13..675ed6c4 100644 --- a/pkg/internal/token/execCredentialPlugin.go +++ b/pkg/internal/token/execCredentialPlugin.go @@ -3,6 +3,7 @@ package token //go:generate sh -c "mockgen -destination mock_$GOPACKAGE/execCredentialPlugin.go github.com/Azure/kubelogin/pkg/internal/token ExecCredentialPlugin" import ( + "context" "fmt" "os" "time" @@ -16,7 +17,7 @@ const ( ) type ExecCredentialPlugin interface { - Do() error + Do(ctx context.Context) error } type execCredentialPlugin struct { @@ -31,7 +32,7 @@ type execCredentialPlugin struct { func New(o *Options) (ExecCredentialPlugin, error) { klog.V(10).Info(o.ToString()) - provider, err := newTokenProvider(o) + provider, err := NewTokenProvider(o) if err != nil { return nil, err } @@ -49,7 +50,7 @@ func New(o *Options) (ExecCredentialPlugin, error) { }, nil } -func (p *execCredentialPlugin) Do() error { +func (p *execCredentialPlugin) Do(ctx context.Context) error { var ( token adal.Token err error @@ -87,7 +88,7 @@ func (p *execCredentialPlugin) Do() error { return fmt.Errorf("failed to get refresher: %s", err) } klog.V(5).Info("refresh token") - token, err := refresher.Token() + token, err := refresher.Token(ctx) // if refresh fails, we will login using token provider if err != nil { klog.V(5).Infof("refresh failed, will continue to login: %s", err) @@ -98,7 +99,7 @@ func (p *execCredentialPlugin) Do() error { if tokenRefreshed { klog.V(10).Info("token refreshed") - // if refresh succeeds, save tooken, and return + // if refresh succeeds, save token, and return if err := p.tokenCache.Write(p.o.tokenCacheFile, token); err != nil { return fmt.Errorf("failed to write to store: %s", err) } @@ -112,7 +113,7 @@ func (p *execCredentialPlugin) Do() error { klog.V(5).Info("acquire new token") // run the underlying provider - token, err = p.provider.Token() + token, err = p.provider.Token(ctx) if err != nil { return fmt.Errorf("failed to get token: %s", err) } diff --git a/pkg/internal/token/execCredentialPlugin_test.go b/pkg/internal/token/execCredentialPlugin_test.go index 091c4969..4053ad9e 100644 --- a/pkg/internal/token/execCredentialPlugin_test.go +++ b/pkg/internal/token/execCredentialPlugin_test.go @@ -1,6 +1,7 @@ package token import ( + "context" "encoding/json" "errors" "fmt" @@ -10,7 +11,7 @@ import ( "github.com/Azure/go-autorest/autorest/adal" "github.com/Azure/kubelogin/pkg/internal/token/mock_token" - "github.com/golang/mock/gomock" + "go.uber.org/mock/gomock" ) func TestExecCredentialPlugin(t *testing.T) { @@ -45,7 +46,7 @@ func TestExecCredentialPlugin(t *testing.T) { }, setupExpectations: func(tc testContext) { tc.tokenCache.EXPECT().Read(cacheFile).Return(adal.Token{}, nil) - tc.tokenProvider.EXPECT().Token().Return(adal.Token{}, nil) + tc.tokenProvider.EXPECT().Token(gomock.Any()).Return(adal.Token{}, nil) tc.tokenCache.EXPECT().Write(cacheFile, adal.Token{}).Return(nil) tc.pluginWriter.EXPECT().Write(adal.Token{}, os.Stdout) }, @@ -97,7 +98,7 @@ func TestExecCredentialPlugin(t *testing.T) { ExpiresOn: json.Number(fmt.Sprintf("%d", time.Now().AddDate(1, 0, 0).Unix())), } tc.tokenCache.EXPECT().Read(cacheFile).Return(cachedToken, nil) - tc.tokenProvider.EXPECT().Token().Return(refreshedToken, nil) + tc.tokenProvider.EXPECT().Token(gomock.Any()).Return(refreshedToken, nil) tc.tokenCache.EXPECT().Write(cacheFile, refreshedToken).Return(nil) tc.pluginWriter.EXPECT().Write(refreshedToken, os.Stdout) }, @@ -124,8 +125,9 @@ func TestExecCredentialPlugin(t *testing.T) { execCredentialWriter: pluginWriter, } + ctx := context.TODO() errMessage := "" - if err := plugin.Do(); err != nil { + if err := plugin.Do(ctx); err != nil { errMessage = err.Error() } if errMessage != data.expectedError { diff --git a/pkg/internal/token/federatedIdentity.go b/pkg/internal/token/federatedIdentity.go index 31d7f8ca..1753316c 100644 --- a/pkg/internal/token/federatedIdentity.go +++ b/pkg/internal/token/federatedIdentity.go @@ -67,7 +67,7 @@ func newWorkloadIdentityToken(clientID, federatedTokenFile, authorityHost, serve }, nil } -func (p *workloadIdentityToken) Token() (adal.Token, error) { +func (p *workloadIdentityToken) Token(ctx context.Context) (adal.Token, error) { emptyToken := adal.Token{} resource := strings.TrimSuffix(p.serverID, "/") @@ -76,7 +76,7 @@ func (p *workloadIdentityToken) Token() (adal.Token, error) { resource += defaultScope } - result, err := p.client.AcquireTokenByCredential(context.Background(), []string{resource}) + result, err := p.client.AcquireTokenByCredential(ctx, []string{resource}) if err != nil { return emptyToken, fmt.Errorf("failed to acquire token. %s", err) } diff --git a/pkg/internal/token/interactive.go b/pkg/internal/token/interactive.go index bf00943a..a3d2bebd 100644 --- a/pkg/internal/token/interactive.go +++ b/pkg/internal/token/interactive.go @@ -46,12 +46,11 @@ func newInteractiveTokenProvider(oAuthConfig adal.OAuthConfig, clientID, resourc } // Token fetches an azcore.AccessToken from the interactive browser SDK and converts it to an adal.Token for use with kubelogin. -func (p *InteractiveToken) Token() (adal.Token, error) { - return p.TokenWithOptions(nil) +func (p *InteractiveToken) Token(ctx context.Context) (adal.Token, error) { + return p.TokenWithOptions(ctx, nil) } -func (p *InteractiveToken) TokenWithOptions(options *azcore.ClientOptions) (adal.Token, error) { - ctx := context.Background() +func (p *InteractiveToken) TokenWithOptions(ctx context.Context, options *azcore.ClientOptions) (adal.Token, error) { emptyToken := adal.Token{} // Request a new Interactive token provider diff --git a/pkg/internal/token/manualtoken.go b/pkg/internal/token/manualtoken.go index a5977bbc..6afe331d 100644 --- a/pkg/internal/token/manualtoken.go +++ b/pkg/internal/token/manualtoken.go @@ -1,6 +1,7 @@ package token import ( + "context" "errors" "fmt" @@ -40,7 +41,7 @@ func newManualToken(oAuthConfig adal.OAuthConfig, clientID, resourceID, tenantID return provider, nil } -func (p *manualToken) Token() (adal.Token, error) { +func (p *manualToken) Token(ctx context.Context) (adal.Token, error) { emptyToken := adal.Token{} callback := func(t adal.Token) error { return nil @@ -55,7 +56,7 @@ func (p *manualToken) Token() (adal.Token, error) { return emptyToken, fmt.Errorf("failed to create service principal from manual token for token refresh: %s", err) } - err = spt.Refresh() + err = spt.RefreshWithContext(ctx) if err != nil { return emptyToken, err } diff --git a/pkg/internal/token/manualtoken_test.go b/pkg/internal/token/manualtoken_test.go index feed9ec9..acd885c0 100644 --- a/pkg/internal/token/manualtoken_test.go +++ b/pkg/internal/token/manualtoken_test.go @@ -1,6 +1,7 @@ package token import ( + "context" "errors" "testing" @@ -87,7 +88,7 @@ func TestManualTokenToken(t *testing.T) { provider, _ := newManualToken(oAuthConfig, clientID, resourceID, tenantID, token) // Test successful token refresh - if _, err := provider.Token(); err == nil { + if _, err := provider.Token(context.TODO()); err == nil { if err == nil { t.Errorf("Expected no error, but got %v", err) } diff --git a/pkg/internal/token/mock_token/execCredentialPlugin.go b/pkg/internal/token/mock_token/execCredentialPlugin.go new file mode 100644 index 00000000..d2cba8c3 --- /dev/null +++ b/pkg/internal/token/mock_token/execCredentialPlugin.go @@ -0,0 +1,53 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/Azure/kubelogin/pkg/internal/token (interfaces: ExecCredentialPlugin) +// +// Generated by this command: +// +// mockgen -destination mock_token/execCredentialPlugin.go github.com/Azure/kubelogin/pkg/internal/token ExecCredentialPlugin +// +// Package mock_token is a generated GoMock package. +package mock_token + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockExecCredentialPlugin is a mock of ExecCredentialPlugin interface. +type MockExecCredentialPlugin struct { + ctrl *gomock.Controller + recorder *MockExecCredentialPluginMockRecorder +} + +// MockExecCredentialPluginMockRecorder is the mock recorder for MockExecCredentialPlugin. +type MockExecCredentialPluginMockRecorder struct { + mock *MockExecCredentialPlugin +} + +// NewMockExecCredentialPlugin creates a new mock instance. +func NewMockExecCredentialPlugin(ctrl *gomock.Controller) *MockExecCredentialPlugin { + mock := &MockExecCredentialPlugin{ctrl: ctrl} + mock.recorder = &MockExecCredentialPluginMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockExecCredentialPlugin) EXPECT() *MockExecCredentialPluginMockRecorder { + return m.recorder +} + +// Do mocks base method. +func (m *MockExecCredentialPlugin) Do(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Do", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Do indicates an expected call of Do. +func (mr *MockExecCredentialPluginMockRecorder) Do(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Do", reflect.TypeOf((*MockExecCredentialPlugin)(nil).Do), arg0) +} diff --git a/pkg/internal/token/mock_token/execCredentialWriter.go b/pkg/internal/token/mock_token/execCredentialWriter.go index e4e9c466..f12a0c25 100644 --- a/pkg/internal/token/mock_token/execCredentialWriter.go +++ b/pkg/internal/token/mock_token/execCredentialWriter.go @@ -1,14 +1,19 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/Azure/kubelogin/pkg/internal/token (interfaces: ExecCredentialWriter) - +// +// Generated by this command: +// +// mockgen -destination mock_token/execCredentialWriter.go github.com/Azure/kubelogin/pkg/internal/token ExecCredentialWriter +// // Package mock_token is a generated GoMock package. package mock_token import ( - adal "github.com/Azure/go-autorest/autorest/adal" - gomock "github.com/golang/mock/gomock" io "io" reflect "reflect" + + adal "github.com/Azure/go-autorest/autorest/adal" + gomock "go.uber.org/mock/gomock" ) // MockExecCredentialWriter is a mock of ExecCredentialWriter interface. @@ -43,7 +48,7 @@ func (m *MockExecCredentialWriter) Write(arg0 adal.Token, arg1 io.Writer) error } // Write indicates an expected call of Write. -func (mr *MockExecCredentialWriterMockRecorder) Write(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockExecCredentialWriterMockRecorder) Write(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockExecCredentialWriter)(nil).Write), arg0, arg1) } diff --git a/pkg/internal/token/mock_token/provider.go b/pkg/internal/token/mock_token/provider.go index f42f0ea9..ac804e02 100644 --- a/pkg/internal/token/mock_token/provider.go +++ b/pkg/internal/token/mock_token/provider.go @@ -1,13 +1,19 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/Azure/kubelogin/pkg/internal/token (interfaces: TokenProvider) - +// +// Generated by this command: +// +// mockgen -destination mock_token/provider.go github.com/Azure/kubelogin/pkg/internal/token TokenProvider +// // Package mock_token is a generated GoMock package. package mock_token import ( - adal "github.com/Azure/go-autorest/autorest/adal" - gomock "github.com/golang/mock/gomock" + context "context" reflect "reflect" + + adal "github.com/Azure/go-autorest/autorest/adal" + gomock "go.uber.org/mock/gomock" ) // MockTokenProvider is a mock of TokenProvider interface. @@ -34,16 +40,16 @@ func (m *MockTokenProvider) EXPECT() *MockTokenProviderMockRecorder { } // Token mocks base method. -func (m *MockTokenProvider) Token() (adal.Token, error) { +func (m *MockTokenProvider) Token(arg0 context.Context) (adal.Token, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Token") + ret := m.ctrl.Call(m, "Token", arg0) ret0, _ := ret[0].(adal.Token) ret1, _ := ret[1].(error) return ret0, ret1 } // Token indicates an expected call of Token. -func (mr *MockTokenProviderMockRecorder) Token() *gomock.Call { +func (mr *MockTokenProviderMockRecorder) Token(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Token", reflect.TypeOf((*MockTokenProvider)(nil).Token)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Token", reflect.TypeOf((*MockTokenProvider)(nil).Token), arg0) } diff --git a/pkg/internal/token/mock_token/tokenCache.go b/pkg/internal/token/mock_token/tokenCache.go index f4014287..9317c0d3 100644 --- a/pkg/internal/token/mock_token/tokenCache.go +++ b/pkg/internal/token/mock_token/tokenCache.go @@ -1,13 +1,18 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/Azure/kubelogin/pkg/internal/token (interfaces: TokenCache) - +// +// Generated by this command: +// +// mockgen -destination mock_token/tokenCache.go github.com/Azure/kubelogin/pkg/internal/token TokenCache +// // Package mock_token is a generated GoMock package. package mock_token import ( - adal "github.com/Azure/go-autorest/autorest/adal" - gomock "github.com/golang/mock/gomock" reflect "reflect" + + adal "github.com/Azure/go-autorest/autorest/adal" + gomock "go.uber.org/mock/gomock" ) // MockTokenCache is a mock of TokenCache interface. @@ -43,7 +48,7 @@ func (m *MockTokenCache) Read(arg0 string) (adal.Token, error) { } // Read indicates an expected call of Read. -func (mr *MockTokenCacheMockRecorder) Read(arg0 interface{}) *gomock.Call { +func (mr *MockTokenCacheMockRecorder) Read(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockTokenCache)(nil).Read), arg0) } @@ -57,7 +62,7 @@ func (m *MockTokenCache) Write(arg0 string, arg1 adal.Token) error { } // Write indicates an expected call of Write. -func (mr *MockTokenCacheMockRecorder) Write(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockTokenCacheMockRecorder) Write(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockTokenCache)(nil).Write), arg0, arg1) } diff --git a/pkg/internal/token/msi.go b/pkg/internal/token/msi.go index 433ea8d9..5bc6f2bd 100644 --- a/pkg/internal/token/msi.go +++ b/pkg/internal/token/msi.go @@ -1,6 +1,7 @@ package token import ( + "context" "errors" "fmt" @@ -27,7 +28,7 @@ func newManagedIdentityToken(clientID, identityResourceID, resourceID string) (T return provider, nil } -func (p *managedIdentityToken) Token() (adal.Token, error) { +func (p *managedIdentityToken) Token(ctx context.Context) (adal.Token, error) { var ( spt *adal.ServicePrincipalToken err error @@ -77,7 +78,7 @@ func (p *managedIdentityToken) Token() (adal.Token, error) { } } - err = spt.Refresh() + err = spt.RefreshWithContext(ctx) if err != nil { return emptyToken, err } diff --git a/pkg/internal/token/provider.go b/pkg/internal/token/provider.go index fc506d27..c29fa653 100644 --- a/pkg/internal/token/provider.go +++ b/pkg/internal/token/provider.go @@ -3,6 +3,7 @@ package token //go:generate sh -c "mockgen -destination mock_$GOPACKAGE/provider.go github.com/Azure/kubelogin/pkg/internal/token TokenProvider" import ( + "context" "errors" "fmt" @@ -12,10 +13,11 @@ import ( ) type TokenProvider interface { - Token() (adal.Token, error) + Token(ctx context.Context) (adal.Token, error) } -func newTokenProvider(o *Options) (TokenProvider, error) { +// NewTokenProvider creates the TokenProvider instance with giving options. +func NewTokenProvider(o *Options) (TokenProvider, error) { oAuthConfig, err := getOAuthConfig(o.Environment, o.TenantID, o.IsLegacy) if err != nil { return nil, fmt.Errorf("failed to get oAuthConfig. isLegacy: %t, err: %s", o.IsLegacy, err) diff --git a/pkg/internal/token/provider_test.go b/pkg/internal/token/provider_test.go index 20168fa7..91178fef 100644 --- a/pkg/internal/token/provider_test.go +++ b/pkg/internal/token/provider_test.go @@ -8,13 +8,13 @@ import ( ) func TestNewTokenProvider(t *testing.T) { - t.Run("newTokenProvider should return error on failure to get oAuthConfig", func(t *testing.T) { + t.Run("NewTokenProvider should return error on failure to get oAuthConfig", func(t *testing.T) { options := &Options{ Environment: "badenvironment", TenantID: "testtenant", IsLegacy: false, } - provider, err := newTokenProvider(options) + provider, err := NewTokenProvider(options) if err == nil || provider != nil { t.Errorf("expected error but got nil") } @@ -23,14 +23,14 @@ func TestNewTokenProvider(t *testing.T) { } }) - t.Run("newTokenProvider should return error on failure to parse PoP claims", func(t *testing.T) { + t.Run("NewTokenProvider should return error on failure to parse PoP claims", func(t *testing.T) { options := &Options{ TenantID: "testtenant", IsLegacy: false, IsPoPTokenEnabled: true, PoPTokenClaims: "1=2", } - provider, err := newTokenProvider(options) + provider, err := NewTokenProvider(options) if err == nil || provider != nil { t.Errorf("expected error but got nil") } @@ -39,11 +39,11 @@ func TestNewTokenProvider(t *testing.T) { } }) - t.Run("newTokenProvider should return error on invalid login method", func(t *testing.T) { + t.Run("NewTokenProvider should return error on invalid login method", func(t *testing.T) { options := &Options{ LoginMethod: "unsupported", } - provider, err := newTokenProvider(options) + provider, err := NewTokenProvider(options) if err == nil || provider != nil { t.Errorf("expected error but got nil") } @@ -52,7 +52,7 @@ func TestNewTokenProvider(t *testing.T) { } }) - t.Run("newTokenProvider should return interactive token provider with correct fields", func(t *testing.T) { + t.Run("NewTokenProvider should return interactive token provider with correct fields", func(t *testing.T) { options := &Options{ TenantID: "testtenant", ClientID: "testclient", @@ -61,7 +61,7 @@ func TestNewTokenProvider(t *testing.T) { PoPTokenClaims: "u=testhost", LoginMethod: "interactive", } - provider, err := newTokenProvider(options) + provider, err := NewTokenProvider(options) if err != nil || provider == nil { t.Errorf("expected no error but got: %s", err) } @@ -81,7 +81,7 @@ func TestNewTokenProvider(t *testing.T) { } }) - t.Run("newTokenProvider should return SPN token provider using client secret with correct fields", func(t *testing.T) { + t.Run("NewTokenProvider should return SPN token provider using client secret with correct fields", func(t *testing.T) { options := &Options{ TenantID: "testtenant", ClientID: "testclient", @@ -91,7 +91,7 @@ func TestNewTokenProvider(t *testing.T) { PoPTokenClaims: "u=testhost, 1=2", LoginMethod: "spn", } - provider, err := newTokenProvider(options) + provider, err := NewTokenProvider(options) if err != nil || provider == nil { t.Errorf("expected no error but got: %s", err) } @@ -114,7 +114,7 @@ func TestNewTokenProvider(t *testing.T) { } }) - t.Run("newTokenProvider should return SPN token provider using client cert with correct fields", func(t *testing.T) { + t.Run("NewTokenProvider should return SPN token provider using client cert with correct fields", func(t *testing.T) { options := &Options{ TenantID: "testtenant", ClientID: "testclient", @@ -123,7 +123,7 @@ func TestNewTokenProvider(t *testing.T) { ClientCertPassword: "testcertpass", LoginMethod: "spn", } - provider, err := newTokenProvider(options) + provider, err := NewTokenProvider(options) if err != nil || provider == nil { t.Errorf("expected no error but got: %s", err) } @@ -148,7 +148,7 @@ func TestNewTokenProvider(t *testing.T) { } }) - t.Run("newTokenProvider should return resource owner token provider with correct fields", func(t *testing.T) { + t.Run("NewTokenProvider should return resource owner token provider with correct fields", func(t *testing.T) { options := &Options{ TenantID: "testtenant", ClientID: "testclient", @@ -157,7 +157,7 @@ func TestNewTokenProvider(t *testing.T) { Password: "testpass", LoginMethod: "ropc", } - provider, err := newTokenProvider(options) + provider, err := NewTokenProvider(options) if err != nil || provider == nil { t.Errorf("expected no error but got: %s", err) } @@ -179,14 +179,14 @@ func TestNewTokenProvider(t *testing.T) { } }) - t.Run("newTokenProvider should return resource owner token provider with correct fields", func(t *testing.T) { + t.Run("NewTokenProvider should return resource owner token provider with correct fields", func(t *testing.T) { options := &Options{ ClientID: "testclient", ServerID: "testserver", IdentityResourceID: "testidentity", LoginMethod: "msi", } - provider, err := newTokenProvider(options) + provider, err := NewTokenProvider(options) if err != nil || provider == nil { t.Errorf("expected no error but got: %s", err) } @@ -202,13 +202,13 @@ func TestNewTokenProvider(t *testing.T) { } }) - t.Run("newTokenProvider should return azure CLI token provider with correct fields", func(t *testing.T) { + t.Run("NewTokenProvider should return azure CLI token provider with correct fields", func(t *testing.T) { options := &Options{ ServerID: "testserver", TenantID: "testtenant", LoginMethod: "azurecli", } - provider, err := newTokenProvider(options) + provider, err := NewTokenProvider(options) if err != nil || provider == nil { t.Errorf("expected no error but got: %s", err) } @@ -221,7 +221,7 @@ func TestNewTokenProvider(t *testing.T) { } }) - t.Run("newTokenProvider should return workload identity token provider with correct fields", func(t *testing.T) { + t.Run("NewTokenProvider should return workload identity token provider with correct fields", func(t *testing.T) { options := &Options{ TenantID: "testtenant", ClientID: "testclient", @@ -231,7 +231,7 @@ func TestNewTokenProvider(t *testing.T) { LoginMethod: "workloadidentity", } t.Run("with token file", func(t *testing.T) { - provider, err := newTokenProvider(options) + provider, err := NewTokenProvider(options) if err != nil || provider == nil { t.Errorf("expected no error but got: %s", err) } @@ -244,7 +244,7 @@ func TestNewTokenProvider(t *testing.T) { options.FederatedTokenFile = "" t.Setenv(actionsIDTokenRequestToken, "fake-token") t.Setenv(actionsIDTokenRequestURL, "fake-url") - provider, err := newTokenProvider(options) + provider, err := NewTokenProvider(options) if err != nil || provider == nil { t.Errorf("expected no error but got: %s", err) } diff --git a/pkg/internal/token/ropc.go b/pkg/internal/token/ropc.go index 011f151a..eb2fdf89 100644 --- a/pkg/internal/token/ropc.go +++ b/pkg/internal/token/ropc.go @@ -1,6 +1,7 @@ package token import ( + "context" "errors" "fmt" @@ -43,7 +44,7 @@ func newResourceOwnerToken(oAuthConfig adal.OAuthConfig, clientID, username, pas }, nil } -func (p *resourceOwnerToken) Token() (adal.Token, error) { +func (p *resourceOwnerToken) Token(ctx context.Context) (adal.Token, error) { emptyToken := adal.Token{} callback := func(t adal.Token) error { return nil @@ -59,7 +60,7 @@ func (p *resourceOwnerToken) Token() (adal.Token, error) { return emptyToken, fmt.Errorf("failed to create service principal token from username password: %s", err) } - err = spt.Refresh() + err = spt.RefreshWithContext(ctx) if err != nil { return emptyToken, err } diff --git a/pkg/internal/token/serviceprincipaltoken.go b/pkg/internal/token/serviceprincipaltoken.go index 4bd2c90f..5c039991 100644 --- a/pkg/internal/token/serviceprincipaltoken.go +++ b/pkg/internal/token/serviceprincipaltoken.go @@ -67,12 +67,11 @@ func newServicePrincipalTokenProvider( } // Token fetches an azcore.AccessToken from the Azure SDK and converts it to an adal.Token for use with kubelogin. -func (p *servicePrincipalToken) Token() (adal.Token, error) { - return p.TokenWithOptions(nil) +func (p *servicePrincipalToken) Token(ctx context.Context) (adal.Token, error) { + return p.TokenWithOptions(ctx, nil) } -func (p *servicePrincipalToken) TokenWithOptions(options *azcore.ClientOptions) (adal.Token, error) { - ctx := context.Background() +func (p *servicePrincipalToken) TokenWithOptions(ctx context.Context, options *azcore.ClientOptions) (adal.Token, error) { emptyToken := adal.Token{} var accessToken string var expirationTimeUnix int64 diff --git a/pkg/internal/token/serviceprincipaltoken_legacy.go b/pkg/internal/token/serviceprincipaltoken_legacy.go index bdb46612..640df68e 100644 --- a/pkg/internal/token/serviceprincipaltoken_legacy.go +++ b/pkg/internal/token/serviceprincipaltoken_legacy.go @@ -1,6 +1,7 @@ package token import ( + "context" "errors" "fmt" "os" @@ -51,7 +52,7 @@ func newLegacyServicePrincipalToken(oAuthConfig adal.OAuthConfig, clientID, clie }, nil } -func (p *legacyServicePrincipalToken) Token() (adal.Token, error) { +func (p *legacyServicePrincipalToken) Token(ctx context.Context) (adal.Token, error) { emptyToken := adal.Token{} var ( @@ -91,7 +92,7 @@ func (p *legacyServicePrincipalToken) Token() (adal.Token, error) { } } - err = spt.EnsureFresh() + err = spt.EnsureFreshWithContext(ctx) if err != nil { return emptyToken, err } diff --git a/pkg/internal/token/serviceprincipaltoken_legacy_test.go b/pkg/internal/token/serviceprincipaltoken_legacy_test.go index 19260585..1bb1ba48 100644 --- a/pkg/internal/token/serviceprincipaltoken_legacy_test.go +++ b/pkg/internal/token/serviceprincipaltoken_legacy_test.go @@ -8,7 +8,7 @@ import ( func TestNewLegacyServicePrincipalToken(t *testing.T) { t.Run("new spn token provider with legacy should not result in error", func(t *testing.T) { - _, err := newTokenProvider(&Options{ + _, err := NewTokenProvider(&Options{ LoginMethod: ServicePrincipalLogin, IsLegacy: true, TenantID: "tenantID", diff --git a/pkg/internal/token/serviceprincipaltoken_test.go b/pkg/internal/token/serviceprincipaltoken_test.go index a3392002..8813b0d1 100644 --- a/pkg/internal/token/serviceprincipaltoken_test.go +++ b/pkg/internal/token/serviceprincipaltoken_test.go @@ -1,6 +1,7 @@ package token import ( + "context" "fmt" "os" "testing" @@ -135,7 +136,7 @@ func TestNewServicePrincipalTokenProvider(t *testing.T) { func TestMissingLoginMethods(t *testing.T) { p := &servicePrincipalToken{} expectedErr := "service principal token requires either client secret or certificate" - _, err := p.Token() + _, err := p.Token(context.TODO()) if !testutils.ErrorContains(err, expectedErr) { t.Errorf("expected error %s, but got %s", expectedErr, err) } @@ -228,7 +229,7 @@ func TestServicePrincipalTokenVCR(t *testing.T) { Transport: httpClient, } - token, err := tc.p.TokenWithOptions(&clientOpts) + token, err := tc.p.TokenWithOptions(context.TODO(), &clientOpts) defer vcrRecorder.Stop() if err != nil { if !testutils.ErrorContains(err, tc.expectedError.Error()) { @@ -349,7 +350,7 @@ func TestServicePrincipalPoPTokenVCR(t *testing.T) { Transport: httpClient, } - token, err = tc.p.TokenWithOptions(&clientOpts) + token, err = tc.p.TokenWithOptions(context.TODO(), &clientOpts) defer vcrRecorder.Stop() if err != nil { if !testutils.ErrorContains(err, tc.expectedError.Error()) { diff --git a/pkg/internal/token/serviceprincipaltokencertificate_test.go b/pkg/internal/token/serviceprincipaltokencertificate_test.go index 68f64ed3..c8f1e197 100644 --- a/pkg/internal/token/serviceprincipaltokencertificate_test.go +++ b/pkg/internal/token/serviceprincipaltokencertificate_test.go @@ -1,6 +1,7 @@ package token import ( + "context" "testing" "github.com/Azure/kubelogin/pkg/internal/testutils" @@ -12,7 +13,7 @@ func TestMissingCertFile(t *testing.T) { } expectedErr := "failed to read the certificate file" - _, err := p.Token() + _, err := p.Token(context.TODO()) if !testutils.ErrorContains(err, expectedErr) { t.Errorf("expected error %s, but got %s", expectedErr, err) } @@ -25,7 +26,7 @@ func TestBadCertPassword(t *testing.T) { } expectedErr := "failed to decode pkcs12 certificate while creating spt: pkcs12: decryption password incorrect" - _, err := p.Token() + _, err := p.Token(context.TODO()) if !testutils.ErrorContains(err, expectedErr) { t.Errorf("expected error %s, but got %s", expectedErr, err) }