From de6223d0ad6d59136036e659191651d5f8a3597b Mon Sep 17 00:00:00 2001 From: Jason Hall Date: Fri, 11 Jun 2021 14:24:45 -0400 Subject: [PATCH] Verify size in verify.ReadCloser (#1044) * Verify size in verify.ReadCloser * use resp.ContentLength * use resp.ContentLength more better --- internal/verify/verify.go | 24 +++++++++++++++++------- internal/verify/verify_test.go | 25 ++++++++++++++++++++++--- pkg/v1/remote/descriptor.go | 19 ++++++++++--------- pkg/v1/remote/image.go | 2 +- 4 files changed, 50 insertions(+), 20 deletions(-) diff --git a/internal/verify/verify.go b/internal/verify/verify.go index 444680380..46308dc58 100644 --- a/internal/verify/verify.go +++ b/internal/verify/verify.go @@ -27,19 +27,24 @@ import ( ) type verifyReader struct { - inner io.Reader - hasher hash.Hash - expected v1.Hash + inner io.Reader + hasher hash.Hash + expected v1.Hash + gotSize, wantSize int64 } // Read implements io.Reader func (vc *verifyReader) Read(b []byte) (int, error) { n, err := vc.inner.Read(b) + vc.gotSize += int64(n) if err == io.EOF { + if vc.gotSize != vc.wantSize { + return n, fmt.Errorf("error verifying size; got %d, want %d", vc.gotSize, vc.wantSize) + } got := hex.EncodeToString(vc.hasher.Sum(make([]byte, 0, vc.hasher.Size()))) if want := vc.expected.Hex; got != want { - return n, fmt.Errorf("error verifying %s checksum; got %q, want %q", - vc.expected.Algorithm, got, want) + return n, fmt.Errorf("error verifying %s checksum after reading %d bytes; got %q, want %q", + vc.expected.Algorithm, vc.gotSize, got, want) } } return n, err @@ -47,17 +52,22 @@ func (vc *verifyReader) Read(b []byte) (int, error) { // ReadCloser wraps the given io.ReadCloser to verify that its contents match // the provided v1.Hash before io.EOF is returned. -func ReadCloser(r io.ReadCloser, h v1.Hash) (io.ReadCloser, error) { +// +// The reader will only be read up to size bytes, to prevent resource +// exhaustion. If EOF is returned before size bytes are read, an error is +// returned. +func ReadCloser(r io.ReadCloser, size int64, h v1.Hash) (io.ReadCloser, error) { w, err := v1.Hasher(h.Algorithm) if err != nil { return nil, err } - r2 := io.TeeReader(r, w) + r2 := io.LimitReader(io.TeeReader(r, w), size) return &and.ReadCloser{ Reader: &verifyReader{ inner: r2, hasher: w, expected: h, + wantSize: size, }, CloseFunc: r.Close, }, nil diff --git a/internal/verify/verify_test.go b/internal/verify/verify_test.go index 155fc4cec..f4087f832 100644 --- a/internal/verify/verify_test.go +++ b/internal/verify/verify_test.go @@ -16,6 +16,7 @@ package verify import ( "bytes" + "fmt" "io/ioutil" "strings" "testing" @@ -35,7 +36,7 @@ func TestVerificationFailure(t *testing.T) { want := "This is the input string." buf := bytes.NewBufferString(want) - verified, err := ReadCloser(ioutil.NopCloser(buf), mustHash("not the same", t)) + verified, err := ReadCloser(ioutil.NopCloser(buf), int64(len(want)), mustHash("not the same", t)) if err != nil { t.Fatal("ReadCloser() =", err) } @@ -48,7 +49,7 @@ func TestVerification(t *testing.T) { want := "This is the input string." buf := bytes.NewBufferString(want) - verified, err := ReadCloser(ioutil.NopCloser(buf), mustHash(want, t)) + verified, err := ReadCloser(ioutil.NopCloser(buf), int64(len(want)), mustHash(want, t)) if err != nil { t.Fatal("ReadCloser() =", err) } @@ -62,8 +63,26 @@ func TestBadHash(t *testing.T) { Algorithm: "fake256", Hex: "whatever", } - _, err := ReadCloser(ioutil.NopCloser(strings.NewReader("hi")), h) + _, err := ReadCloser(ioutil.NopCloser(strings.NewReader("hi")), 0, h) if err == nil { t.Errorf("ReadCloser() = %v, wanted err", err) } } + +func TestBadSize(t *testing.T) { + want := "This is the input string." + + // having too much content or expecting too much content returns an error. + for _, size := range []int64{3, 100} { + t.Run(fmt.Sprintf("expecting size %d", size), func(t *testing.T) { + buf := bytes.NewBufferString(want) + rc, err := ReadCloser(ioutil.NopCloser(buf), size, mustHash(want, t)) + if err != nil { + t.Fatal("ReadCloser() =", err) + } + if b, err := ioutil.ReadAll(rc); err == nil { + t.Errorf("ReadAll() = %q; want verification error", string(b)) + } + }) + } +} diff --git a/pkg/v1/remote/descriptor.go b/pkg/v1/remote/descriptor.go index a13f01b68..411c096c7 100644 --- a/pkg/v1/remote/descriptor.go +++ b/pkg/v1/remote/descriptor.go @@ -22,7 +22,6 @@ import ( "io/ioutil" "net/http" "net/url" - "strconv" "strings" "github.com/google/go-containerregistry/internal/verify" @@ -330,13 +329,9 @@ func (f *fetcher) headManifest(ref name.Reference, acceptable []types.MediaType) } mediaType := types.MediaType(mth) - lh := resp.Header.Get("Content-Length") - if lh == "" { - return nil, fmt.Errorf("HEAD %s: response did not include Content-Length header", u.String()) - } - size, err := strconv.ParseInt(lh, 10, 64) - if err != nil { - return nil, err + size := resp.ContentLength + if size == -1 { + return nil, fmt.Errorf("GET %s: response did not include Content-Length header", u.String()) } dh := resp.Header.Get("Docker-Content-Digest") @@ -380,7 +375,13 @@ func (f *fetcher) fetchBlob(ctx context.Context, h v1.Hash) (io.ReadCloser, erro return nil, err } - return verify.ReadCloser(resp.Body, h) + // Verify up to the content-length header value. + size := resp.ContentLength + if size == -1 { + return nil, fmt.Errorf("GET %s: response did not include Content-Length header", u.String()) + } + + return verify.ReadCloser(resp.Body, size, h) } func (f *fetcher) headBlob(h v1.Hash) (*http.Response, error) { diff --git a/pkg/v1/remote/image.go b/pkg/v1/remote/image.go index 71739fee3..c5dd84182 100644 --- a/pkg/v1/remote/image.go +++ b/pkg/v1/remote/image.go @@ -177,7 +177,7 @@ func (rl *remoteImageLayer) Compressed() (io.ReadCloser, error) { continue } - return verify.ReadCloser(resp.Body, rl.digest) + return verify.ReadCloser(resp.Body, d.Size, rl.digest) } return nil, lastErr