Skip to content

Commit

Permalink
Add remote.WithContext (#744)
Browse files Browse the repository at this point in the history
This allows passing a context.Context into any remote operations.
  • Loading branch information
jonjohnsonjr committed Jul 17, 2020
1 parent 72597da commit a849933
Show file tree
Hide file tree
Showing 12 changed files with 123 additions and 50 deletions.
11 changes: 10 additions & 1 deletion pkg/v1/remote/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ func CatalogPage(target name.Registry, last string, n int, options ...Option) ([
}

client := http.Client{Transport: tr}
resp, err := client.Get(uri.String())
req, err := http.NewRequest(http.MethodGet, uri.String(), nil)
if err != nil {
return nil, err
}
resp, err := client.Do(req.WithContext(o.context))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -92,6 +96,11 @@ func Catalog(ctx context.Context, target name.Registry, options ...Option) ([]st

client := http.Client{Transport: tr}

// WithContext overrides the ctx passed directly.
if o.context != context.Background() {
ctx = o.context
}

var (
parsed catalog
repoList []string
Expand Down
6 changes: 4 additions & 2 deletions pkg/v1/remote/check.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package remote

import (
"context"
"fmt"
"net/http"

Expand Down Expand Up @@ -34,8 +35,9 @@ func CheckPushPermission(ref name.Reference, kc authn.Keychain, t http.RoundTrip
// authorize a push. Figure out how to return early here when we can,
// to avoid a roundtrip for spec-compliant registries.
w := writer{
repo: ref.Context(),
client: &http.Client{Transport: tr},
repo: ref.Context(),
client: &http.Client{Transport: tr},
context: context.Background(),
}
loc, _, err := w.initiateUpload("", "")
if loc != "" {
Expand Down
2 changes: 1 addition & 1 deletion pkg/v1/remote/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func Delete(ref name.Reference, options ...Option) error {
return err
}

resp, err := c.Do(req)
resp, err := c.Do(req.WithContext(o.context))
if err != nil {
return err
}
Expand Down
37 changes: 22 additions & 15 deletions pkg/v1/remote/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package remote

import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -168,10 +169,7 @@ func (d *Descriptor) ImageIndex() (v1.ImageIndex, error) {

func (d *Descriptor) remoteImage() *remoteImage {
return &remoteImage{
fetcher: fetcher{
Ref: d.Ref,
Client: d.Client,
},
fetcher: d.fetcher,
manifest: d.Manifest,
mediaType: d.MediaType,
descriptor: &d.Descriptor,
Expand All @@ -180,10 +178,7 @@ func (d *Descriptor) remoteImage() *remoteImage {

func (d *Descriptor) remoteIndex() *remoteIndex {
return &remoteIndex{
fetcher: fetcher{
Ref: d.Ref,
Client: d.Client,
},
fetcher: d.fetcher,
manifest: d.Manifest,
mediaType: d.MediaType,
descriptor: &d.Descriptor,
Expand All @@ -192,8 +187,9 @@ func (d *Descriptor) remoteIndex() *remoteIndex {

// fetcher implements methods for reading from a registry.
type fetcher struct {
Ref name.Reference
Client *http.Client
Ref name.Reference
Client *http.Client
context context.Context
}

func makeFetcher(ref name.Reference, o *options) (*fetcher, error) {
Expand All @@ -202,8 +198,9 @@ func makeFetcher(ref name.Reference, o *options) (*fetcher, error) {
return nil, err
}
return &fetcher{
Ref: ref,
Client: &http.Client{Transport: tr},
Ref: ref,
Client: &http.Client{Transport: tr},
context: o.context,
}, nil
}

Expand All @@ -228,7 +225,7 @@ func (f *fetcher) fetchManifest(ref name.Reference, acceptable []types.MediaType
}
req.Header.Set("Accept", strings.Join(accept, ","))

resp, err := f.Client.Do(req)
resp, err := f.Client.Do(req.WithContext(f.context))
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -282,7 +279,12 @@ func (f *fetcher) fetchManifest(ref name.Reference, acceptable []types.MediaType

func (f *fetcher) fetchBlob(h v1.Hash) (io.ReadCloser, error) {
u := f.url("blobs", h.String())
resp, err := f.Client.Get(u.String())
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
if err != nil {
return nil, err
}

resp, err := f.Client.Do(req.WithContext(f.context))
if err != nil {
return nil, err
}
Expand All @@ -297,7 +299,12 @@ func (f *fetcher) fetchBlob(h v1.Hash) (io.ReadCloser, error) {

func (f *fetcher) headBlob(h v1.Hash) (*http.Response, error) {
u := f.url("blobs", h.String())
resp, err := f.Client.Head(u.String())
req, err := http.NewRequest(http.MethodHead, u.String(), nil)
if err != nil {
return nil, err
}

resp, err := f.Client.Do(req.WithContext(f.context))
if err != nil {
return nil, err
}
Expand Down
7 changes: 6 additions & 1 deletion pkg/v1/remote/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,12 @@ func (rl *remoteImageLayer) Compressed() (io.ReadCloser, error) {
// TODO: Maybe we don't want to try pulling from the registry first?
var lastErr error
for _, u := range urls {
resp, err := rl.ri.Client.Get(u.String())
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
if err != nil {
return nil, err
}

resp, err := rl.ri.Client.Do(req.WithContext(rl.ri.context))
if err != nil {
lastErr = err
continue
Expand Down
21 changes: 13 additions & 8 deletions pkg/v1/remote/image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package remote

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -178,8 +179,9 @@ func TestRawManifestDigests(t *testing.T) {

rmt := remoteImage{
fetcher: fetcher{
Ref: ref,
Client: http.DefaultClient,
Ref: ref,
Client: http.DefaultClient,
context: context.Background(),
},
}

Expand Down Expand Up @@ -212,8 +214,9 @@ func TestRawManifestNotFound(t *testing.T) {

img := remoteImage{
fetcher: fetcher{
Ref: mustNewTag(t, fmt.Sprintf("%s/%s:latest", u.Host, expectedRepo)),
Client: http.DefaultClient,
Ref: mustNewTag(t, fmt.Sprintf("%s/%s:latest", u.Host, expectedRepo)),
Client: http.DefaultClient,
context: context.Background(),
},
}

Expand Down Expand Up @@ -251,8 +254,9 @@ func TestRawConfigFileNotFound(t *testing.T) {

rmt := remoteImage{
fetcher: fetcher{
Ref: mustNewTag(t, fmt.Sprintf("%s/%s:latest", u.Host, expectedRepo)),
Client: http.DefaultClient,
Ref: mustNewTag(t, fmt.Sprintf("%s/%s:latest", u.Host, expectedRepo)),
Client: http.DefaultClient,
context: context.Background(),
},
}

Expand Down Expand Up @@ -291,8 +295,9 @@ func TestAcceptHeaders(t *testing.T) {

rmt := &remoteImage{
fetcher: fetcher{
Ref: mustNewTag(t, fmt.Sprintf("%s/%s:latest", u.Host, expectedRepo)),
Client: http.DefaultClient,
Ref: mustNewTag(t, fmt.Sprintf("%s/%s:latest", u.Host, expectedRepo)),
Client: http.DefaultClient,
context: context.Background(),
},
}
manifest, err := rmt.RawManifest()
Expand Down
5 changes: 3 additions & 2 deletions pkg/v1/remote/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,9 @@ func (r *remoteIndex) childDescriptor(child v1.Descriptor, platform v1.Platform)
}
return &Descriptor{
fetcher: fetcher{
Ref: ref,
Client: r.Client,
Ref: ref,
Client: r.Client,
context: r.context,
},
Manifest: manifest,
Descriptor: child,
Expand Down
6 changes: 4 additions & 2 deletions pkg/v1/remote/index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package remote

import (
"bytes"
"context"
"fmt"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -140,8 +141,9 @@ func TestIndexRawManifestDigests(t *testing.T) {

rmt := remoteIndex{
fetcher: fetcher{
Ref: ref,
Client: http.DefaultClient,
Ref: ref,
Client: http.DefaultClient,
context: context.Background(),
},
}

Expand Down
7 changes: 7 additions & 0 deletions pkg/v1/remote/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ func ListWithContext(ctx context.Context, repo name.Repository, options ...Optio
RawQuery: "n=1000",
}

// This is lazy, but I want to make sure List(..., WithContext(ctx)) works
// without calling makeOptions() twice (which can have side effects).
// This means ListWithContext(ctx, ..., WithContext(ctx2)) prefers ctx2.
if o.context != context.Background() {
ctx = o.context
}

client := http.Client{Transport: tr}
tagList := []string{}
parsed := tags{}
Expand Down
17 changes: 17 additions & 0 deletions pkg/v1/remote/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package remote

import (
"context"
"net/http"

"github.com/google/go-containerregistry/pkg/authn"
Expand All @@ -31,13 +32,15 @@ type options struct {
keychain authn.Keychain
transport http.RoundTripper
platform v1.Platform
context context.Context
}

func makeOptions(target authn.Resource, opts ...Option) (*options, error) {
o := &options{
auth: authn.Anonymous,
transport: http.DefaultTransport,
platform: defaultPlatform,
context: context.Background(),
}

for _, option := range opts {
Expand Down Expand Up @@ -114,3 +117,17 @@ func WithPlatform(p v1.Platform) Option {
return nil
}
}

// WithContext is a functional option for setting the context in http requests
// performed by a given function. Note that this context is used for _all_
// http requests, not just the initial volley. E.g., for remote.Image, the
// context will be set on http requests generated by subsequent calls to
// RawConfigFile() and even methods on layers returned by Layers().
//
// The default context is context.Background().
func WithContext(ctx context.Context) Option {
return func(o *options) error {
o.context = ctx
return nil
}
}
Loading

0 comments on commit a849933

Please sign in to comment.