Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Context support to auth methods #1949

Merged
merged 1 commit into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cmd/crane/cmd/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ $ curl -H "$(crane auth token -H ubuntu)" https://index.docker.io/v2/library/ubu
return err
}

auth, err := o.Keychain.Resolve(repo)
auth, err := authn.Resolve(cmd.Context(), o.Keychain, repo)
if err != nil {
return err
}
Expand Down Expand Up @@ -152,7 +152,7 @@ func NewCmdAuthGet(options []crane.Option, argv ...string) *cobra.Command {
Short: "Implements a credential helper",
Example: eg,
Args: cobra.MaximumNArgs(1),
RunE: func(_ *cobra.Command, args []string) error {
RunE: func(cmd *cobra.Command, args []string) error {
registryAddr := ""
if len(args) == 1 {
registryAddr = args[0]
Expand All @@ -168,7 +168,7 @@ func NewCmdAuthGet(options []crane.Option, argv ...string) *cobra.Command {
if err != nil {
return err
}
authorizer, err := crane.GetOptions(options...).Keychain.Resolve(reg)
authorizer, err := authn.Resolve(cmd.Context(), crane.GetOptions(options...).Keychain, reg)
if err != nil {
return err
}
Expand All @@ -182,7 +182,7 @@ func NewCmdAuthGet(options []crane.Option, argv ...string) *cobra.Command {
os.Exit(1)
}

auth, err := authorizer.Authorization()
auth, err := authn.Authorization(cmd.Context(), authorizer)
if err != nil {
return err
}
Expand Down
17 changes: 17 additions & 0 deletions pkg/authn/authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package authn

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
Expand All @@ -27,6 +28,22 @@ type Authenticator interface {
Authorization() (*AuthConfig, error)
}

// ContextAuthenticator is like Authenticator, but allows for context to be passed in.
type ContextAuthenticator interface {
// Authorization returns the value to use in an http transport's Authorization header.
AuthorizationContext(context.Context) (*AuthConfig, error)
}

// Authorization calls AuthorizationContext with ctx if the given [Authenticator] implements [ContextAuthenticator],
// otherwise it calls Resolve with the given [Resource].
func Authorization(ctx context.Context, authn Authenticator) (*AuthConfig, error) {
if actx, ok := authn.(ContextAuthenticator); ok {
return actx.AuthorizationContext(ctx)
}

return authn.Authorization()
}

// AuthConfig contains authorization information for connecting to a Registry
// Inlined what we use from github.com/docker/cli/cli/config/types
type AuthConfig struct {
Expand Down
41 changes: 37 additions & 4 deletions pkg/authn/keychain.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package authn

import (
"context"
"os"
"path/filepath"
"sync"
Expand Down Expand Up @@ -45,6 +46,11 @@ type Keychain interface {
Resolve(Resource) (Authenticator, error)
}

// ContextKeychain is like Keychain, but allows for context to be passed in.
type ContextKeychain interface {
ResolveContext(context.Context, Resource) (Authenticator, error)
}

// defaultKeychain implements Keychain with the semantics of the standard Docker
// credential keychain.
type defaultKeychain struct {
Expand All @@ -62,8 +68,23 @@ const (
DefaultAuthKey = "https://" + name.DefaultRegistry + "/v1/"
)

// Resolve implements Keychain.
// Resolve calls ResolveContext with ctx if the given [Keychain] implements [ContextKeychain],
// otherwise it calls Resolve with the given [Resource].
func Resolve(ctx context.Context, keychain Keychain, target Resource) (Authenticator, error) {
if rctx, ok := keychain.(ContextKeychain); ok {
return rctx.ResolveContext(ctx, target)
}

return keychain.Resolve(target)
}

// ResolveContext implements ContextKeychain.
func (dk *defaultKeychain) Resolve(target Resource) (Authenticator, error) {
return dk.ResolveContext(context.Background(), target)
}

// Resolve implements Keychain.
func (dk *defaultKeychain) ResolveContext(ctx context.Context, target Resource) (Authenticator, error) {
dk.mu.Lock()
defer dk.mu.Unlock()

Expand Down Expand Up @@ -180,6 +201,10 @@ func NewKeychainFromHelper(h Helper) Keychain { return wrapper{h} }
type wrapper struct{ h Helper }

func (w wrapper) Resolve(r Resource) (Authenticator, error) {
return w.ResolveContext(context.Background(), r)
}

func (w wrapper) ResolveContext(ctx context.Context, r Resource) (Authenticator, error) {
u, p, err := w.h.Get(r.RegistryStr())
if err != nil {
return Anonymous, nil
Expand All @@ -206,8 +231,12 @@ type refreshingKeychain struct {
}

func (r *refreshingKeychain) Resolve(target Resource) (Authenticator, error) {
return r.ResolveContext(context.Background(), target)
}

func (r *refreshingKeychain) ResolveContext(ctx context.Context, target Resource) (Authenticator, error) {
last := time.Now()
auth, err := r.keychain.Resolve(target)
auth, err := Resolve(ctx, r.keychain, target)
if err != nil || auth == Anonymous {
return auth, err
}
Expand Down Expand Up @@ -236,17 +265,21 @@ type refreshing struct {
}

func (r *refreshing) Authorization() (*AuthConfig, error) {
return r.AuthorizationContext(context.Background())
}

func (r *refreshing) AuthorizationContext(ctx context.Context) (*AuthConfig, error) {
r.Lock()
defer r.Unlock()
if r.cached == nil || r.expired() {
r.last = r.now()
auth, err := r.keychain.Resolve(r.target)
auth, err := Resolve(ctx, r.keychain, r.target)
if err != nil {
return nil, err
}
r.cached = auth
}
return r.cached.Authorization()
return Authorization(ctx, r.cached)
}

func (r *refreshing) now() time.Time {
Expand Down
8 changes: 7 additions & 1 deletion pkg/authn/multikeychain.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package authn

import "context"

type multiKeychain struct {
keychains []Keychain
}
Expand All @@ -28,8 +30,12 @@ func NewMultiKeychain(kcs ...Keychain) Keychain {

// Resolve implements Keychain.
func (mk *multiKeychain) Resolve(target Resource) (Authenticator, error) {
return mk.ResolveContext(context.Background(), target)
}

func (mk *multiKeychain) ResolveContext(ctx context.Context, target Resource) (Authenticator, error) {
for _, kc := range mk.keychains {
auth, err := kc.Resolve(target)
auth, err := Resolve(ctx, kc, target)
if err != nil {
return nil, err
}
Expand Down
18 changes: 10 additions & 8 deletions pkg/v1/google/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,23 @@ import (
const cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"

// GetGcloudCmd is exposed so we can test this.
var GetGcloudCmd = func() *exec.Cmd {
var GetGcloudCmd = func(ctx context.Context) *exec.Cmd {
// This is odd, but basically what docker-credential-gcr does.
//
// config-helper is undocumented, but it's purportedly the only supported way
// of accessing tokens (`gcloud auth print-access-token` is discouraged).
//
// --force-auth-refresh means we are getting a token that is valid for about
// an hour (we reuse it until it's expired).
return exec.Command("gcloud", "config", "config-helper", "--force-auth-refresh", "--format=json(credential)")
return exec.CommandContext(ctx, "gcloud", "config", "config-helper", "--force-auth-refresh", "--format=json(credential)")
}

// NewEnvAuthenticator returns an authn.Authenticator that generates access
// tokens from the environment we're running in.
//
// See: https://godoc.org/golang.org/x/oauth2/google#FindDefaultCredentials
func NewEnvAuthenticator() (authn.Authenticator, error) {
ts, err := googauth.DefaultTokenSource(context.Background(), cloudPlatformScope)
func NewEnvAuthenticator(ctx context.Context) (authn.Authenticator, error) {
ts, err := googauth.DefaultTokenSource(ctx, cloudPlatformScope)
if err != nil {
return nil, err
}
Expand All @@ -62,14 +62,14 @@ func NewEnvAuthenticator() (authn.Authenticator, error) {

// NewGcloudAuthenticator returns an oauth2.TokenSource that generates access
// tokens by shelling out to the gcloud sdk.
func NewGcloudAuthenticator() (authn.Authenticator, error) {
func NewGcloudAuthenticator(ctx context.Context) (authn.Authenticator, error) {
if _, err := exec.LookPath("gcloud"); err != nil {
// gcloud is not available, fall back to anonymous
logs.Warn.Println("gcloud binary not found")
return authn.Anonymous, nil
}

ts := gcloudSource{GetGcloudCmd}
ts := gcloudSource{ctx, GetGcloudCmd}

// Attempt to fetch a token to ensure gcloud is installed and we can run it.
token, err := ts.Token()
Expand Down Expand Up @@ -143,13 +143,15 @@ type gcloudOutput struct {
}

type gcloudSource struct {
ctx context.Context

// This is passed in so that we mock out gcloud and test Token.
exec func() *exec.Cmd
exec func(ctx context.Context) *exec.Cmd
}

// Token implements oauath2.TokenSource.
func (gs gcloudSource) Token() (*oauth2.Token, error) {
cmd := gs.exec()
cmd := gs.exec(gs.ctx)
var out bytes.Buffer
cmd.Stdout = &out

Expand Down
15 changes: 9 additions & 6 deletions pkg/v1/google/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package google

import (
"bytes"
"context"
"fmt"
"os"
"os/exec"
Expand Down Expand Up @@ -84,15 +85,16 @@ func TestMain(m *testing.M) {
}
}

func newGcloudCmdMock(env string) func() *exec.Cmd {
return func() *exec.Cmd {
cmd := exec.Command(os.Args[0])
func newGcloudCmdMock(env string) func(context.Context) *exec.Cmd {
return func(ctx context.Context) *exec.Cmd {
cmd := exec.CommandContext(ctx, os.Args[0])
cmd.Env = []string{fmt.Sprintf("GO_TEST_MODE=%s", env)}
return cmd
}
}

func TestGcloudErrors(t *testing.T) {
ctx := context.Background()
cases := []struct {
env string

Expand All @@ -113,7 +115,7 @@ func TestGcloudErrors(t *testing.T) {
t.Run(tc.env, func(t *testing.T) {
GetGcloudCmd = newGcloudCmdMock(tc.env)

if _, err := NewGcloudAuthenticator(); err == nil {
if _, err := NewGcloudAuthenticator(ctx); err == nil {
t.Errorf("wanted error, got nil")
} else if got := err.Error(); !strings.HasPrefix(got, tc.wantPrefix) {
t.Errorf("wanted error prefix %q, got %q", tc.wantPrefix, got)
Expand All @@ -123,13 +125,14 @@ func TestGcloudErrors(t *testing.T) {
}

func TestGcloudSuccess(t *testing.T) {
ctx := context.Background()
// Stupid coverage to make sure it doesn't panic.
var b bytes.Buffer
logs.Debug.SetOutput(&b)

GetGcloudCmd = newGcloudCmdMock("success")

auth, err := NewGcloudAuthenticator()
auth, err := NewGcloudAuthenticator(ctx)
if err != nil {
t.Fatalf("NewGcloudAuthenticator got error %v", err)
}
Expand Down Expand Up @@ -263,7 +266,7 @@ func TestNewEnvAuthenticatorFailure(t *testing.T) {
}

// Expect error.
_, err := NewEnvAuthenticator()
_, err := NewEnvAuthenticator(context.Background())
if err == nil {
t.Errorf("expected err, got nil")
}
Expand Down
14 changes: 10 additions & 4 deletions pkg/v1/google/keychain.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package google

import (
"context"
"strings"
"sync"

Expand Down Expand Up @@ -52,26 +53,31 @@ type googleKeychain struct {
// In general, we don't worry about that here because we expect to use the same
// gcloud configuration in the scope of this one process.
func (gk *googleKeychain) Resolve(target authn.Resource) (authn.Authenticator, error) {
return gk.ResolveContext(context.Background(), target)
}

// ResolveContext implements authn.ContextKeychain.
func (gk *googleKeychain) ResolveContext(ctx context.Context, target authn.Resource) (authn.Authenticator, error) {
// Only authenticate GCR and AR so it works with authn.NewMultiKeychain to fallback.
if !isGoogle(target.RegistryStr()) {
return authn.Anonymous, nil
}

gk.once.Do(func() {
gk.auth = resolve()
gk.auth = resolve(ctx)
})

return gk.auth, nil
}

func resolve() authn.Authenticator {
auth, envErr := NewEnvAuthenticator()
func resolve(ctx context.Context) authn.Authenticator {
auth, envErr := NewEnvAuthenticator(ctx)
if envErr == nil && auth != authn.Anonymous {
logs.Debug.Println("google.Keychain: using Application Default Credentials")
return auth
}

auth, gErr := NewGcloudAuthenticator()
auth, gErr := NewGcloudAuthenticator(ctx)
if gErr == nil && auth != authn.Anonymous {
logs.Debug.Println("google.Keychain: using gcloud fallback")
return auth
Expand Down
2 changes: 1 addition & 1 deletion pkg/v1/remote/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type fetcher struct {
func makeFetcher(ctx context.Context, target resource, o *options) (*fetcher, error) {
auth := o.auth
if o.keychain != nil {
kauth, err := o.keychain.Resolve(target)
kauth, err := authn.Resolve(ctx, o.keychain, target)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/v1/remote/transport/basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ var _ http.RoundTripper = (*basicTransport)(nil)
// RoundTrip implements http.RoundTripper
func (bt *basicTransport) RoundTrip(in *http.Request) (*http.Response, error) {
if bt.auth != authn.Anonymous {
auth, err := bt.auth.Authorization()
auth, err := authn.Authorization(in.Context(), bt.auth)
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/v1/remote/transport/bearer.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func Exchange(ctx context.Context, reg name.Registry, auth authn.Authenticator,
if err != nil {
return nil, err
}
authcfg, err := auth.Authorization()
authcfg, err := authn.Authorization(ctx, auth)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -190,7 +190,7 @@ func (bt *bearerTransport) RoundTrip(in *http.Request) (*http.Response, error) {
// The basic token exchange is attempted first, falling back to the oauth flow.
// If the IdentityToken is set, this indicates that we should start with the oauth flow.
func (bt *bearerTransport) refresh(ctx context.Context) error {
auth, err := bt.basic.Authorization()
auth, err := authn.Authorization(ctx, bt.basic)
if err != nil {
return err
}
Expand Down Expand Up @@ -295,7 +295,7 @@ func canonicalAddress(host, scheme string) (address string) {

// https://docs.docker.com/registry/spec/auth/oauth/
func (bt *bearerTransport) refreshOauth(ctx context.Context) ([]byte, error) {
auth, err := bt.basic.Authorization()
auth, err := authn.Authorization(ctx, bt.basic)
if err != nil {
return nil, err
}
Expand Down
Loading
Loading