From 39d1148136d6d225d58aa2697aa2f23dc8d45927 Mon Sep 17 00:00:00 2001 From: Jon Johnson Date: Thu, 16 May 2024 14:24:54 -0700 Subject: [PATCH] Add Context support to auth methods (#1949) Signed-off-by: Jon Johnson --- cmd/crane/cmd/auth.go | 8 +++--- pkg/authn/authn.go | 17 +++++++++++++ pkg/authn/keychain.go | 41 ++++++++++++++++++++++++++++--- pkg/authn/multikeychain.go | 8 +++++- pkg/v1/google/auth.go | 18 ++++++++------ pkg/v1/google/auth_test.go | 15 ++++++----- pkg/v1/google/keychain.go | 14 ++++++++--- pkg/v1/remote/fetcher.go | 2 +- pkg/v1/remote/transport/basic.go | 2 +- pkg/v1/remote/transport/bearer.go | 6 ++--- pkg/v1/remote/write.go | 2 +- 11 files changed, 100 insertions(+), 33 deletions(-) diff --git a/cmd/crane/cmd/auth.go b/cmd/crane/cmd/auth.go index a01e6235e..433d4195b 100644 --- a/cmd/crane/cmd/auth.go +++ b/cmd/crane/cmd/auth.go @@ -73,7 +73,7 @@ $ curl -H "$(crane auth token -H ubuntu)" https://index.docker.io/v2/library/ubu return err } - auth, err := o.Keychain.Resolve(repo) + auth, err := authn.Resolve(cmd.Context(), o.Keychain, repo) if err != nil { return err } @@ -152,7 +152,7 @@ func NewCmdAuthGet(options []crane.Option, argv ...string) *cobra.Command { Short: "Implements a credential helper", Example: eg, Args: cobra.MaximumNArgs(1), - RunE: func(_ *cobra.Command, args []string) error { + RunE: func(cmd *cobra.Command, args []string) error { registryAddr := "" if len(args) == 1 { registryAddr = args[0] @@ -168,7 +168,7 @@ func NewCmdAuthGet(options []crane.Option, argv ...string) *cobra.Command { if err != nil { return err } - authorizer, err := crane.GetOptions(options...).Keychain.Resolve(reg) + authorizer, err := authn.Resolve(cmd.Context(), crane.GetOptions(options...).Keychain, reg) if err != nil { return err } @@ -182,7 +182,7 @@ func NewCmdAuthGet(options []crane.Option, argv ...string) *cobra.Command { os.Exit(1) } - auth, err := authorizer.Authorization() + auth, err := authn.Authorization(cmd.Context(), authorizer) if err != nil { return err } diff --git a/pkg/authn/authn.go b/pkg/authn/authn.go index 172d218e4..1555efae0 100644 --- a/pkg/authn/authn.go +++ b/pkg/authn/authn.go @@ -15,6 +15,7 @@ package authn import ( + "context" "encoding/base64" "encoding/json" "fmt" @@ -27,6 +28,22 @@ type Authenticator interface { Authorization() (*AuthConfig, error) } +// ContextAuthenticator is like Authenticator, but allows for context to be passed in. +type ContextAuthenticator interface { + // Authorization returns the value to use in an http transport's Authorization header. + AuthorizationContext(context.Context) (*AuthConfig, error) +} + +// Authorization calls AuthorizationContext with ctx if the given [Authenticator] implements [ContextAuthenticator], +// otherwise it calls Resolve with the given [Resource]. +func Authorization(ctx context.Context, authn Authenticator) (*AuthConfig, error) { + if actx, ok := authn.(ContextAuthenticator); ok { + return actx.AuthorizationContext(ctx) + } + + return authn.Authorization() +} + // AuthConfig contains authorization information for connecting to a Registry // Inlined what we use from github.com/docker/cli/cli/config/types type AuthConfig struct { diff --git a/pkg/authn/keychain.go b/pkg/authn/keychain.go index c16bb1611..f4c452bdc 100644 --- a/pkg/authn/keychain.go +++ b/pkg/authn/keychain.go @@ -15,6 +15,7 @@ package authn import ( + "context" "os" "path/filepath" "sync" @@ -45,6 +46,11 @@ type Keychain interface { Resolve(Resource) (Authenticator, error) } +// ContextKeychain is like Keychain, but allows for context to be passed in. +type ContextKeychain interface { + ResolveContext(context.Context, Resource) (Authenticator, error) +} + // defaultKeychain implements Keychain with the semantics of the standard Docker // credential keychain. type defaultKeychain struct { @@ -62,8 +68,23 @@ const ( DefaultAuthKey = "https://" + name.DefaultRegistry + "/v1/" ) -// Resolve implements Keychain. +// Resolve calls ResolveContext with ctx if the given [Keychain] implements [ContextKeychain], +// otherwise it calls Resolve with the given [Resource]. +func Resolve(ctx context.Context, keychain Keychain, target Resource) (Authenticator, error) { + if rctx, ok := keychain.(ContextKeychain); ok { + return rctx.ResolveContext(ctx, target) + } + + return keychain.Resolve(target) +} + +// ResolveContext implements ContextKeychain. func (dk *defaultKeychain) Resolve(target Resource) (Authenticator, error) { + return dk.ResolveContext(context.Background(), target) +} + +// Resolve implements Keychain. +func (dk *defaultKeychain) ResolveContext(ctx context.Context, target Resource) (Authenticator, error) { dk.mu.Lock() defer dk.mu.Unlock() @@ -180,6 +201,10 @@ func NewKeychainFromHelper(h Helper) Keychain { return wrapper{h} } type wrapper struct{ h Helper } func (w wrapper) Resolve(r Resource) (Authenticator, error) { + return w.ResolveContext(context.Background(), r) +} + +func (w wrapper) ResolveContext(ctx context.Context, r Resource) (Authenticator, error) { u, p, err := w.h.Get(r.RegistryStr()) if err != nil { return Anonymous, nil @@ -206,8 +231,12 @@ type refreshingKeychain struct { } func (r *refreshingKeychain) Resolve(target Resource) (Authenticator, error) { + return r.ResolveContext(context.Background(), target) +} + +func (r *refreshingKeychain) ResolveContext(ctx context.Context, target Resource) (Authenticator, error) { last := time.Now() - auth, err := r.keychain.Resolve(target) + auth, err := Resolve(ctx, r.keychain, target) if err != nil || auth == Anonymous { return auth, err } @@ -236,17 +265,21 @@ type refreshing struct { } func (r *refreshing) Authorization() (*AuthConfig, error) { + return r.AuthorizationContext(context.Background()) +} + +func (r *refreshing) AuthorizationContext(ctx context.Context) (*AuthConfig, error) { r.Lock() defer r.Unlock() if r.cached == nil || r.expired() { r.last = r.now() - auth, err := r.keychain.Resolve(r.target) + auth, err := Resolve(ctx, r.keychain, r.target) if err != nil { return nil, err } r.cached = auth } - return r.cached.Authorization() + return Authorization(ctx, r.cached) } func (r *refreshing) now() time.Time { diff --git a/pkg/authn/multikeychain.go b/pkg/authn/multikeychain.go index 3b1804f5d..fe241a0fd 100644 --- a/pkg/authn/multikeychain.go +++ b/pkg/authn/multikeychain.go @@ -14,6 +14,8 @@ package authn +import "context" + type multiKeychain struct { keychains []Keychain } @@ -28,8 +30,12 @@ func NewMultiKeychain(kcs ...Keychain) Keychain { // Resolve implements Keychain. func (mk *multiKeychain) Resolve(target Resource) (Authenticator, error) { + return mk.ResolveContext(context.Background(), target) +} + +func (mk *multiKeychain) ResolveContext(ctx context.Context, target Resource) (Authenticator, error) { for _, kc := range mk.keychains { - auth, err := kc.Resolve(target) + auth, err := Resolve(ctx, kc, target) if err != nil { return nil, err } diff --git a/pkg/v1/google/auth.go b/pkg/v1/google/auth.go index 11ae39796..4e64eda43 100644 --- a/pkg/v1/google/auth.go +++ b/pkg/v1/google/auth.go @@ -31,7 +31,7 @@ import ( const cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform" // GetGcloudCmd is exposed so we can test this. -var GetGcloudCmd = func() *exec.Cmd { +var GetGcloudCmd = func(ctx context.Context) *exec.Cmd { // This is odd, but basically what docker-credential-gcr does. // // config-helper is undocumented, but it's purportedly the only supported way @@ -39,15 +39,15 @@ var GetGcloudCmd = func() *exec.Cmd { // // --force-auth-refresh means we are getting a token that is valid for about // an hour (we reuse it until it's expired). - return exec.Command("gcloud", "config", "config-helper", "--force-auth-refresh", "--format=json(credential)") + return exec.CommandContext(ctx, "gcloud", "config", "config-helper", "--force-auth-refresh", "--format=json(credential)") } // NewEnvAuthenticator returns an authn.Authenticator that generates access // tokens from the environment we're running in. // // See: https://godoc.org/golang.org/x/oauth2/google#FindDefaultCredentials -func NewEnvAuthenticator() (authn.Authenticator, error) { - ts, err := googauth.DefaultTokenSource(context.Background(), cloudPlatformScope) +func NewEnvAuthenticator(ctx context.Context) (authn.Authenticator, error) { + ts, err := googauth.DefaultTokenSource(ctx, cloudPlatformScope) if err != nil { return nil, err } @@ -62,14 +62,14 @@ func NewEnvAuthenticator() (authn.Authenticator, error) { // NewGcloudAuthenticator returns an oauth2.TokenSource that generates access // tokens by shelling out to the gcloud sdk. -func NewGcloudAuthenticator() (authn.Authenticator, error) { +func NewGcloudAuthenticator(ctx context.Context) (authn.Authenticator, error) { if _, err := exec.LookPath("gcloud"); err != nil { // gcloud is not available, fall back to anonymous logs.Warn.Println("gcloud binary not found") return authn.Anonymous, nil } - ts := gcloudSource{GetGcloudCmd} + ts := gcloudSource{ctx, GetGcloudCmd} // Attempt to fetch a token to ensure gcloud is installed and we can run it. token, err := ts.Token() @@ -143,13 +143,15 @@ type gcloudOutput struct { } type gcloudSource struct { + ctx context.Context + // This is passed in so that we mock out gcloud and test Token. - exec func() *exec.Cmd + exec func(ctx context.Context) *exec.Cmd } // Token implements oauath2.TokenSource. func (gs gcloudSource) Token() (*oauth2.Token, error) { - cmd := gs.exec() + cmd := gs.exec(gs.ctx) var out bytes.Buffer cmd.Stdout = &out diff --git a/pkg/v1/google/auth_test.go b/pkg/v1/google/auth_test.go index d2974ff13..3b949e9f8 100644 --- a/pkg/v1/google/auth_test.go +++ b/pkg/v1/google/auth_test.go @@ -19,6 +19,7 @@ package google import ( "bytes" + "context" "fmt" "os" "os/exec" @@ -84,15 +85,16 @@ func TestMain(m *testing.M) { } } -func newGcloudCmdMock(env string) func() *exec.Cmd { - return func() *exec.Cmd { - cmd := exec.Command(os.Args[0]) +func newGcloudCmdMock(env string) func(context.Context) *exec.Cmd { + return func(ctx context.Context) *exec.Cmd { + cmd := exec.CommandContext(ctx, os.Args[0]) cmd.Env = []string{fmt.Sprintf("GO_TEST_MODE=%s", env)} return cmd } } func TestGcloudErrors(t *testing.T) { + ctx := context.Background() cases := []struct { env string @@ -113,7 +115,7 @@ func TestGcloudErrors(t *testing.T) { t.Run(tc.env, func(t *testing.T) { GetGcloudCmd = newGcloudCmdMock(tc.env) - if _, err := NewGcloudAuthenticator(); err == nil { + if _, err := NewGcloudAuthenticator(ctx); err == nil { t.Errorf("wanted error, got nil") } else if got := err.Error(); !strings.HasPrefix(got, tc.wantPrefix) { t.Errorf("wanted error prefix %q, got %q", tc.wantPrefix, got) @@ -123,13 +125,14 @@ func TestGcloudErrors(t *testing.T) { } func TestGcloudSuccess(t *testing.T) { + ctx := context.Background() // Stupid coverage to make sure it doesn't panic. var b bytes.Buffer logs.Debug.SetOutput(&b) GetGcloudCmd = newGcloudCmdMock("success") - auth, err := NewGcloudAuthenticator() + auth, err := NewGcloudAuthenticator(ctx) if err != nil { t.Fatalf("NewGcloudAuthenticator got error %v", err) } @@ -263,7 +266,7 @@ func TestNewEnvAuthenticatorFailure(t *testing.T) { } // Expect error. - _, err := NewEnvAuthenticator() + _, err := NewEnvAuthenticator(context.Background()) if err == nil { t.Errorf("expected err, got nil") } diff --git a/pkg/v1/google/keychain.go b/pkg/v1/google/keychain.go index 5472b1879..0645768c8 100644 --- a/pkg/v1/google/keychain.go +++ b/pkg/v1/google/keychain.go @@ -15,6 +15,7 @@ package google import ( + "context" "strings" "sync" @@ -52,26 +53,31 @@ type googleKeychain struct { // In general, we don't worry about that here because we expect to use the same // gcloud configuration in the scope of this one process. func (gk *googleKeychain) Resolve(target authn.Resource) (authn.Authenticator, error) { + return gk.ResolveContext(context.Background(), target) +} + +// ResolveContext implements authn.ContextKeychain. +func (gk *googleKeychain) ResolveContext(ctx context.Context, target authn.Resource) (authn.Authenticator, error) { // Only authenticate GCR and AR so it works with authn.NewMultiKeychain to fallback. if !isGoogle(target.RegistryStr()) { return authn.Anonymous, nil } gk.once.Do(func() { - gk.auth = resolve() + gk.auth = resolve(ctx) }) return gk.auth, nil } -func resolve() authn.Authenticator { - auth, envErr := NewEnvAuthenticator() +func resolve(ctx context.Context) authn.Authenticator { + auth, envErr := NewEnvAuthenticator(ctx) if envErr == nil && auth != authn.Anonymous { logs.Debug.Println("google.Keychain: using Application Default Credentials") return auth } - auth, gErr := NewGcloudAuthenticator() + auth, gErr := NewGcloudAuthenticator(ctx) if gErr == nil && auth != authn.Anonymous { logs.Debug.Println("google.Keychain: using gcloud fallback") return auth diff --git a/pkg/v1/remote/fetcher.go b/pkg/v1/remote/fetcher.go index 4e61002be..d77b37c0c 100644 --- a/pkg/v1/remote/fetcher.go +++ b/pkg/v1/remote/fetcher.go @@ -47,7 +47,7 @@ type fetcher struct { func makeFetcher(ctx context.Context, target resource, o *options) (*fetcher, error) { auth := o.auth if o.keychain != nil { - kauth, err := o.keychain.Resolve(target) + kauth, err := authn.Resolve(ctx, o.keychain, target) if err != nil { return nil, err } diff --git a/pkg/v1/remote/transport/basic.go b/pkg/v1/remote/transport/basic.go index fdb362b76..f2452469d 100644 --- a/pkg/v1/remote/transport/basic.go +++ b/pkg/v1/remote/transport/basic.go @@ -33,7 +33,7 @@ var _ http.RoundTripper = (*basicTransport)(nil) // RoundTrip implements http.RoundTripper func (bt *basicTransport) RoundTrip(in *http.Request) (*http.Response, error) { if bt.auth != authn.Anonymous { - auth, err := bt.auth.Authorization() + auth, err := authn.Authorization(in.Context(), bt.auth) if err != nil { return nil, err } diff --git a/pkg/v1/remote/transport/bearer.go b/pkg/v1/remote/transport/bearer.go index cb1567496..be3bec9c3 100644 --- a/pkg/v1/remote/transport/bearer.go +++ b/pkg/v1/remote/transport/bearer.go @@ -49,7 +49,7 @@ func Exchange(ctx context.Context, reg name.Registry, auth authn.Authenticator, if err != nil { return nil, err } - authcfg, err := auth.Authorization() + authcfg, err := authn.Authorization(ctx, auth) if err != nil { return nil, err } @@ -190,7 +190,7 @@ func (bt *bearerTransport) RoundTrip(in *http.Request) (*http.Response, error) { // The basic token exchange is attempted first, falling back to the oauth flow. // If the IdentityToken is set, this indicates that we should start with the oauth flow. func (bt *bearerTransport) refresh(ctx context.Context) error { - auth, err := bt.basic.Authorization() + auth, err := authn.Authorization(ctx, bt.basic) if err != nil { return err } @@ -295,7 +295,7 @@ func canonicalAddress(host, scheme string) (address string) { // https://docs.docker.com/registry/spec/auth/oauth/ func (bt *bearerTransport) refreshOauth(ctx context.Context) ([]byte, error) { - auth, err := bt.basic.Authorization() + auth, err := authn.Authorization(ctx, bt.basic) if err != nil { return nil, err } diff --git a/pkg/v1/remote/write.go b/pkg/v1/remote/write.go index 04a3989a6..b730dbb05 100644 --- a/pkg/v1/remote/write.go +++ b/pkg/v1/remote/write.go @@ -76,7 +76,7 @@ type writer struct { func makeWriter(ctx context.Context, repo name.Repository, ls []v1.Layer, o *options) (*writer, error) { auth := o.auth if o.keychain != nil { - kauth, err := o.keychain.Resolve(repo) + kauth, err := authn.Resolve(ctx, o.keychain, repo) if err != nil { return nil, err }