diff --git a/pkg/authn/keychain.go b/pkg/authn/keychain.go index a4a88b3d5..4e32500cd 100644 --- a/pkg/authn/keychain.go +++ b/pkg/authn/keychain.go @@ -18,6 +18,7 @@ import ( "os" "path/filepath" "sync" + "time" "github.com/docker/cli/cli/config" "github.com/docker/cli/cli/config/configfile" @@ -52,7 +53,7 @@ type defaultKeychain struct { var ( // DefaultKeychain implements Keychain by interpreting the docker config file. - DefaultKeychain Keychain = &defaultKeychain{} + DefaultKeychain = RefreshingKeychain(&defaultKeychain{}, 5*time.Minute) ) const ( @@ -178,3 +179,71 @@ func (w wrapper) Resolve(r Resource) (Authenticator, error) { } return FromConfig(AuthConfig{Username: u, Password: p}), nil } + +func RefreshingKeychain(inner Keychain, duration time.Duration) Keychain { + return &refreshingKeychain{ + keychain: inner, + duration: duration, + } +} + +type refreshingKeychain struct { + keychain Keychain + duration time.Duration + clock func() time.Time +} + +func (r *refreshingKeychain) Resolve(target Resource) (Authenticator, error) { + last := time.Now() + auth, err := r.keychain.Resolve(target) + if err != nil || auth == Anonymous { + return auth, err + } + return &refreshing{ + target: target, + keychain: r.keychain, + last: last, + cached: auth, + duration: r.duration, + clock: r.clock, + }, nil +} + +type refreshing struct { + sync.Mutex + target Resource + keychain Keychain + + duration time.Duration + + last time.Time + cached Authenticator + + // for testing + clock func() time.Time +} + +func (r *refreshing) Authorization() (*AuthConfig, error) { + r.Lock() + defer r.Unlock() + if r.cached == nil || r.expired() { + r.last = r.now() + auth, err := r.keychain.Resolve(r.target) + if err != nil { + return nil, err + } + r.cached = auth + } + return r.cached.Authorization() +} + +func (r *refreshing) now() time.Time { + if r.clock == nil { + return time.Now() + } + return r.clock() +} + +func (r *refreshing) expired() bool { + return r.now().Sub(r.last) > r.duration +} diff --git a/pkg/authn/keychain_test.go b/pkg/authn/keychain_test.go index 9dfbad14a..f983a4dec 100644 --- a/pkg/authn/keychain_test.go +++ b/pkg/authn/keychain_test.go @@ -24,6 +24,7 @@ import ( "path/filepath" "reflect" "testing" + "time" "github.com/google/go-containerregistry/pkg/name" ) @@ -390,3 +391,55 @@ func TestConfigFileIsADir(t *testing.T) { t.Errorf("expected Anonymous, got %v", auth) } } + +type fakeKeychain struct { + auth Authenticator + err error + + count int +} + +func (k *fakeKeychain) Resolve(target Resource) (Authenticator, error) { + k.count++ + return k.auth, k.err +} + +func TestRefreshingAuth(t *testing.T) { + repo := name.MustParseReference("example.com/my/repo").Context() + last := time.Now() + + // Increments by 1 minute each invocation. + clock := func() time.Time { + last = last.Add(1 * time.Minute) + return last + } + + want := AuthConfig{ + Username: "foo", + Password: "secret", + } + + keychain := &fakeKeychain{FromConfig(want), nil, 0} + rk := RefreshingKeychain(keychain, 5*time.Minute) + rk.(*refreshingKeychain).clock = clock + + auth, err := rk.Resolve(repo) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 10; i++ { + got, err := auth.Authorization() + if err != nil { + t.Fatal(err) + } + + if *got != want { + t.Errorf("got %+v, want %+v", got, want) + } + } + + if got, want := keychain.count, 2; got != want { + t.Errorf("refreshed %d times, wanted %d", got, want) + } +}