Skip to content

Commit

Permalink
Plumb contexts through transport.New (#834)
Browse files Browse the repository at this point in the history
Add transport.NewWithContext.

Use the provided context for the initial ping request and token
exchange.

Subsequent token exchanges (to refresh) will use the context of the
incoming request in RoundTrip.

Drop the hardcoded ping timeout.

Use transport.NewWithContext everywhere that's easily plumbable.
  • Loading branch information
jonjohnsonjr committed Nov 18, 2020
1 parent 144defc commit 3904ad8
Show file tree
Hide file tree
Showing 13 changed files with 51 additions and 32 deletions.
2 changes: 2 additions & 0 deletions pkg/internal/legacy/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ func CopySchema1(desc *remote.Descriptor, srcRef, dstRef name.Reference, srcAuth
func putManifest(desc *remote.Descriptor, dstRef name.Reference, dstAuth authn.Authenticator) error {
reg := dstRef.Context().Registry
scopes := []string{dstRef.Scope(transport.PushScope)}

// TODO(jonjohnsonjr): Use NewWithContext.
tr, err := transport.New(reg, dstAuth, http.DefaultTransport, scopes)
if err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion pkg/v1/google/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func newLister(repo name.Repository, options ...ListerOption) (*lister, error) {
l.transport = transport.NewRetry(l.transport)

scopes := []string{repo.Scope(transport.PullScope)}
tr, err := transport.New(repo.Registry, l.auth, l.transport, scopes)
tr, err := transport.NewWithContext(l.ctx, repo.Registry, l.auth, l.transport, scopes)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/v1/remote/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func CatalogPage(target name.Registry, last string, n int, options ...Option) ([
}

scopes := []string{target.Scope(transport.PullScope)}
tr, err := transport.New(target, o.auth, o.transport, scopes)
tr, err := transport.NewWithContext(o.context, target, o.auth, o.transport, scopes)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -82,7 +82,7 @@ func Catalog(ctx context.Context, target name.Registry, options ...Option) ([]st
}

scopes := []string{target.Scope(transport.PullScope)}
tr, err := transport.New(target, o.auth, o.transport, scopes)
tr, err := transport.NewWithContext(o.context, target, o.auth, o.transport, scopes)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/v1/remote/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func Delete(ref name.Reference, options ...Option) error {
return err
}
scopes := []string{ref.Scope(transport.DeleteScope)}
tr, err := transport.New(ref.Context().Registry, o.auth, o.transport, scopes)
tr, err := transport.NewWithContext(o.context, ref.Context().Registry, o.auth, o.transport, scopes)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/v1/remote/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ type fetcher struct {
}

func makeFetcher(ref name.Reference, o *options) (*fetcher, error) {
tr, err := transport.New(ref.Context().Registry, o.auth, o.transport, []string{ref.Scope(transport.PullScope)})
tr, err := transport.NewWithContext(o.context, ref.Context().Registry, o.auth, o.transport, []string{ref.Scope(transport.PullScope)})
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/v1/remote/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func ListWithContext(ctx context.Context, repo name.Repository, options ...Optio
return nil, err
}
scopes := []string{repo.Scope(transport.PullScope)}
tr, err := transport.New(repo.Registry, o.auth, o.transport, scopes)
tr, err := transport.NewWithContext(o.context, repo.Registry, o.auth, o.transport, scopes)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/v1/remote/multi_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func MultiWrite(m map[name.Reference]Taggable, options ...Option) error {
ls = append(ls, l)
}
scopes := scopesForUploadingImage(repo, ls)
tr, err := transport.New(repo.Registry, o.auth, o.transport, scopes)
tr, err := transport.NewWithContext(o.context, repo.Registry, o.auth, o.transport, scopes)
if err != nil {
return err
}
Expand Down
19 changes: 10 additions & 9 deletions pkg/v1/remote/transport/bearer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package transport

import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -75,7 +76,7 @@ func (bt *bearerTransport) RoundTrip(in *http.Request) (*http.Response, error) {

// Perform a token refresh() and retry the request in case the token has expired
if res.StatusCode == http.StatusUnauthorized {
if err = bt.refresh(); err != nil {
if err = bt.refresh(in.Context()); err != nil {
return nil, err
}
return sendRequest()
Expand All @@ -88,7 +89,7 @@ func (bt *bearerTransport) RoundTrip(in *http.Request) (*http.Response, error) {
// so we rely on heuristics and fallbacks to support as many registries as possible.
// The basic token exchange is attempted first, falling back to the oauth flow.
// If the IdentityToken is set, this indicates that we should start with the oauth flow.
func (bt *bearerTransport) refresh() error {
func (bt *bearerTransport) refresh(ctx context.Context) error {
auth, err := bt.basic.Authorization()
if err != nil {
return err
Expand All @@ -104,15 +105,15 @@ func (bt *bearerTransport) refresh() error {
// If the secret being stored is an identity token,
// the Username should be set to <token>, which indicates
// we are using an oauth flow.
content, err = bt.refreshOauth()
content, err = bt.refreshOauth(ctx)
if terr, ok := err.(*Error); ok && terr.StatusCode == http.StatusNotFound {
// Note: Not all token servers implement oauth2.
// If the request to the endpoint returns 404 using the HTTP POST method,
// refer to Token Documentation for using the HTTP GET method supported by all token servers.
content, err = bt.refreshBasic()
content, err = bt.refreshBasic(ctx)
}
} else {
content, err = bt.refreshBasic()
content, err = bt.refreshBasic(ctx)
}
if err != nil {
return err
Expand Down Expand Up @@ -186,7 +187,7 @@ func canonicalAddress(host, scheme string) (address string) {
}

// https://docs.docker.com/registry/spec/auth/oauth/
func (bt *bearerTransport) refreshOauth() ([]byte, error) {
func (bt *bearerTransport) refreshOauth(ctx context.Context) ([]byte, error) {
auth, err := bt.basic.Authorization()
if err != nil {
return nil, err
Expand Down Expand Up @@ -220,7 +221,7 @@ func (bt *bearerTransport) refreshOauth() ([]byte, error) {
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

// We don't want to log credentials.
ctx := redact.NewContext(req.Context(), "oauth token response contains credentials")
ctx = redact.NewContext(ctx, "oauth token response contains credentials")

resp, err := client.Do(req.WithContext(ctx))
if err != nil {
Expand All @@ -236,7 +237,7 @@ func (bt *bearerTransport) refreshOauth() ([]byte, error) {
}

// https://docs.docker.com/registry/spec/auth/token/
func (bt *bearerTransport) refreshBasic() ([]byte, error) {
func (bt *bearerTransport) refreshBasic(ctx context.Context) ([]byte, error) {
u, err := url.Parse(bt.realm)
if err != nil {
return nil, err
Expand All @@ -259,7 +260,7 @@ func (bt *bearerTransport) refreshBasic() ([]byte, error) {
}

// We don't want to log credentials.
ctx := redact.NewContext(req.Context(), "basic token response contains credentials")
ctx = redact.NewContext(ctx, "basic token response contains credentials")

resp, err := client.Do(req.WithContext(ctx))
if err != nil {
Expand Down
3 changes: 2 additions & 1 deletion pkg/v1/remote/transport/bearer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package transport

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -79,7 +80,7 @@ func TestBearerRefresh(t *testing.T) {
scheme: "http",
}

if err := bt.refresh(); (err != nil) != tc.wantErr {
if err := bt.refresh(context.Background()); (err != nil) != tc.wantErr {
t.Errorf("refresh() = %v", err)
}
})
Expand Down
12 changes: 8 additions & 4 deletions pkg/v1/remote/transport/ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
package transport

import (
"context"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"time"

"github.com/google/go-containerregistry/pkg/name"
)
Expand Down Expand Up @@ -66,8 +66,8 @@ func parseChallenge(suffix string) map[string]string {
return kv
}

func ping(reg name.Registry, t http.RoundTripper) (*pingResp, error) {
client := http.Client{Transport: t, Timeout: 120 * time.Second}
func ping(ctx context.Context, reg name.Registry, t http.RoundTripper) (*pingResp, error) {
client := http.Client{Transport: t}

// This first attempts to use "https" for every request, falling back to http
// if the registry matches our localhost heuristic or if it is intentionally
Expand All @@ -80,7 +80,11 @@ func ping(reg name.Registry, t http.RoundTripper) (*pingResp, error) {
var connErr error
for _, scheme := range schemes {
url := fmt.Sprintf("%s://%s/v2/", scheme, reg.Name())
resp, err := client.Get(url)
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := client.Do(req.WithContext(ctx))
if err != nil {
connErr = err
// Potentially retry with http.
Expand Down
11 changes: 6 additions & 5 deletions pkg/v1/remote/transport/ping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package transport

import (
"context"
"net/http"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -83,7 +84,7 @@ func TestPingNoChallenge(t *testing.T) {
},
}

pr, err := ping(testRegistry, tprt)
pr, err := ping(context.Background(), testRegistry, tprt)
if err != nil {
t.Errorf("ping() = %v", err)
}
Expand All @@ -108,7 +109,7 @@ func TestPingBasicChallengeNoParams(t *testing.T) {
},
}

pr, err := ping(testRegistry, tprt)
pr, err := ping(context.Background(), testRegistry, tprt)
if err != nil {
t.Errorf("ping() = %v", err)
}
Expand All @@ -133,7 +134,7 @@ func TestPingBearerChallengeWithParams(t *testing.T) {
},
}

pr, err := ping(testRegistry, tprt)
pr, err := ping(context.Background(), testRegistry, tprt)
if err != nil {
t.Errorf("ping() = %v", err)
}
Expand All @@ -158,7 +159,7 @@ func TestUnsupportedStatus(t *testing.T) {
},
}

pr, err := ping(testRegistry, tprt)
pr, err := ping(context.Background(), testRegistry, tprt)
if err == nil {
t.Errorf("ping() = %v", pr)
}
Expand Down Expand Up @@ -194,7 +195,7 @@ func TestPingHttpFallback(t *testing.T) {
}

for _, test := range tests {
pr, err := ping(test.reg, tprt)
pr, err := ping(context.Background(), test.reg, tprt)
if err == nil {
t.Errorf("ping() = %v", pr)
}
Expand Down
14 changes: 12 additions & 2 deletions pkg/v1/remote/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package transport

import (
"context"
"fmt"
"net/http"

Expand All @@ -25,7 +26,16 @@ import (
// New returns a new RoundTripper based on the provided RoundTripper that has been
// setup to authenticate with the remote registry "reg", in the capacity
// laid out by the specified scopes.
//
// TODO(jonjohnsonjr): Deprecate this.
func New(reg name.Registry, auth authn.Authenticator, t http.RoundTripper, scopes []string) (http.RoundTripper, error) {
return NewWithContext(context.Background(), reg, auth, t, scopes)
}

// NewWithContext returns a new RoundTripper based on the provided RoundTripper that has been
// setup to authenticate with the remote registry "reg", in the capacity
// laid out by the specified scopes.
func NewWithContext(ctx context.Context, reg name.Registry, auth authn.Authenticator, t http.RoundTripper, scopes []string) (http.RoundTripper, error) {
// The handshake:
// 1. Use "t" to ping() the registry for the authentication challenge.
//
Expand All @@ -40,7 +50,7 @@ func New(reg name.Registry, auth authn.Authenticator, t http.RoundTripper, scope

// First we ping the registry to determine the parameters of the authentication handshake
// (if one is even necessary).
pr, err := ping(reg, t)
pr, err := ping(ctx, reg, t)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -81,7 +91,7 @@ func New(reg name.Registry, auth authn.Authenticator, t http.RoundTripper, scope
scopes: scopes,
scheme: pr.scheme,
}
if err := bt.refresh(); err != nil {
if err := bt.refresh(ctx); err != nil {
return nil, err
}
return bt, nil
Expand Down
8 changes: 4 additions & 4 deletions pkg/v1/remote/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func Write(ref name.Reference, img v1.Image, options ...Option) error {
}

scopes := scopesForUploadingImage(ref.Context(), ls)
tr, err := transport.New(ref.Context().Registry, o.auth, o.transport, scopes)
tr, err := transport.NewWithContext(o.context, ref.Context().Registry, o.auth, o.transport, scopes)
if err != nil {
return err
}
Expand Down Expand Up @@ -549,7 +549,7 @@ func WriteIndex(ref name.Reference, ii v1.ImageIndex, options ...Option) error {
return err
}
scopes := []string{ref.Scope(transport.PushScope)}
tr, err := transport.New(ref.Context().Registry, o.auth, o.transport, scopes)
tr, err := transport.NewWithContext(o.context, ref.Context().Registry, o.auth, o.transport, scopes)
if err != nil {
return err
}
Expand All @@ -568,7 +568,7 @@ func WriteLayer(repo name.Repository, layer v1.Layer, options ...Option) error {
return err
}
scopes := scopesForUploadingImage(repo, []v1.Layer{layer})
tr, err := transport.New(repo.Registry, o.auth, o.transport, scopes)
tr, err := transport.NewWithContext(o.context, repo.Registry, o.auth, o.transport, scopes)
if err != nil {
return err
}
Expand All @@ -595,7 +595,7 @@ func Tag(tag name.Tag, t Taggable, options ...Option) error {
// * Allow callers to pass in a transport.Transport, typecheck
// it to allow them to reuse the transport across multiple calls.
// * WithTag option to do multiple manifest PUTs in commitManifest.
tr, err := transport.New(tag.Context().Registry, o.auth, o.transport, scopes)
tr, err := transport.NewWithContext(o.context, tag.Context().Registry, o.auth, o.transport, scopes)
if err != nil {
return err
}
Expand Down

0 comments on commit 3904ad8

Please sign in to comment.