diff --git a/pkg/chunked/compression_linux.go b/pkg/chunked/compression_linux.go index 2dac463543..c15d63250b 100644 --- a/pkg/chunked/compression_linux.go +++ b/pkg/chunked/compression_linux.go @@ -48,21 +48,21 @@ func readEstargzChunkedManifest(blobStream ImageSourceSeekable, blobSize int64, Offset: uint64(blobSize - footerSize), Length: uint64(footerSize), } - parts, errs, err := blobStream.GetBlobAt([]ImageSourceChunk{chunk}) + + footer := make([]byte, footerSize) + streamsOrErrors, err := getBlobAt(blobStream, []ImageSourceChunk{chunk}, 1) if err != nil { return nil, 0, err } - var reader io.ReadCloser - select { - case r := <-parts: - reader = r - case err := <-errs: - return nil, 0, err - } - defer reader.Close() - footer := make([]byte, footerSize) - if _, err := io.ReadFull(reader, footer); err != nil { - return nil, 0, err + + for soe := range streamsOrErrors { + if soe.stream != nil { + _, err = io.ReadFull(soe.stream, footer) + _ = soe.stream.Close() + } + if soe.err != nil && err == nil { + err = soe.err + } } /* Read the ToC offset: @@ -89,44 +89,55 @@ func readEstargzChunkedManifest(blobStream ImageSourceSeekable, blobSize int64, Offset: uint64(tocOffset), Length: uint64(size), } - parts, errs, err = blobStream.GetBlobAt([]ImageSourceChunk{chunk}) - if err != nil { - return nil, 0, err - } - - var tocReader io.ReadCloser - select { - case r := <-parts: - tocReader = r - case err := <-errs: - return nil, 0, err - } - defer tocReader.Close() - - r, err := pgzip.NewReader(tocReader) - if err != nil { - return nil, 0, err - } - defer r.Close() - - aTar := archivetar.NewReader(r) - - header, err := aTar.Next() + streamsOrErrors, err = getBlobAt(blobStream, []ImageSourceChunk{chunk}, 1) if err != nil { return nil, 0, err } - // set a reasonable limit - if header.Size > (1<<20)*50 { - return nil, 0, errors.New("manifest too big") - } - manifestUncompressed := make([]byte, header.Size) - if _, err := io.ReadFull(aTar, manifestUncompressed); err != nil { - return nil, 0, err + var manifestUncompressed []byte + + for soe := range streamsOrErrors { + if soe.stream != nil { + err1 := func() error { + defer soe.stream.Close() + + r, err := pgzip.NewReader(soe.stream) + if err != nil { + return err + } + defer r.Close() + + aTar := archivetar.NewReader(r) + + header, err := aTar.Next() + if err != nil { + return err + } + // set a reasonable limit + if header.Size > (1<<20)*50 { + return errors.New("manifest too big") + } + + manifestUncompressed = make([]byte, header.Size) + if _, err := io.ReadFull(aTar, manifestUncompressed); err != nil { + return err + } + return nil + }() + if err == nil { + err = err1 + } + } + if soe.err != nil && err == nil { + err = soe.err + } } manifestDigester := digest.Canonical.Digester() manifestChecksum := manifestDigester.Hash() + if manifestUncompressed == nil { + return nil, 0, errors.New("manifest not found") + } if _, err := manifestChecksum.Write(manifestUncompressed); err != nil { return nil, 0, err } @@ -175,26 +186,29 @@ func readZstdChunkedManifest(blobStream ImageSourceSeekable, tocDigest digest.Di if tarSplitChunk.Offset > 0 { chunks = append(chunks, tarSplitChunk) } - parts, errs, err := blobStream.GetBlobAt(chunks) + + streamsOrErrors, err := getBlobAt(blobStream, chunks, len(chunks)) if err != nil { return nil, nil, nil, 0, err } + defer func() { + for soe := range streamsOrErrors { + if soe.stream != nil { + _ = soe.stream.Close() + } + } + }() + readBlob := func(len uint64) ([]byte, error) { - var reader io.ReadCloser - select { - case r := <-parts: - reader = r - case err := <-errs: - return nil, err + soe := <-streamsOrErrors + if soe.err != nil { + return nil, soe.err } + defer soe.stream.Close() blob := make([]byte, len) - if _, err := io.ReadFull(reader, blob); err != nil { - reader.Close() - return nil, err - } - if err := reader.Close(); err != nil { + if _, err := io.ReadFull(soe.stream, blob); err != nil { return nil, err } return blob, nil diff --git a/pkg/chunked/storage_linux.go b/pkg/chunked/storage_linux.go index 8ecbfb9826..7d60e68b39 100644 --- a/pkg/chunked/storage_linux.go +++ b/pkg/chunked/storage_linux.go @@ -1141,12 +1141,68 @@ func makeEntriesFlat(mergedEntries []fileMetadata) ([]fileMetadata, error) { return new, nil } -func (c *chunkedDiffer) copyAllBlobToFile(destination *os.File) (digest.Digest, error) { - var payload io.ReadCloser - var streams chan io.ReadCloser - var errs chan error - var err error +type streamOrErr struct { + stream io.ReadCloser + err error +} + +func getBlobAt(is ImageSourceSeekable, chunksToRequest []ImageSourceChunk, maxStreams int) (chan streamOrErr, error) { + streams, errs, err := is.GetBlobAt(chunksToRequest) + if err != nil { + return nil, err + } + stream := make(chan streamOrErr) + go func() { + tooManyStreams := false + defer close(stream) + streamsSoFar := 0 + loop: + for { + select { + case p, ok := <-streams: + if !ok { + streams = nil + break loop + } + if maxStreams > 0 && streamsSoFar >= maxStreams { + tooManyStreams = true + _ = p.Close() + continue + } + streamsSoFar++ + stream <- streamOrErr{stream: p} + case err, ok := <-errs: + if !ok { + errs = nil + break loop + } + stream <- streamOrErr{err: err} + } + } + if streams != nil { + for p := range streams { + if maxStreams > 0 && streamsSoFar >= maxStreams { + tooManyStreams = true + _ = p.Close() + continue + } + streamsSoFar++ + stream <- streamOrErr{stream: p} + } + } + if errs != nil { + for err := range errs { + stream <- streamOrErr{err: err} + } + } + if tooManyStreams { + stream <- streamOrErr{err: fmt.Errorf("too many streams returned, requested %d, got %d", len(chunksToRequest), streamsSoFar)} + } + }() + return stream, nil +} +func (c *chunkedDiffer) copyAllBlobToFile(destination *os.File) (digest.Digest, error) { chunksToRequest := []ImageSourceChunk{ { Offset: 0, @@ -1154,28 +1210,24 @@ func (c *chunkedDiffer) copyAllBlobToFile(destination *os.File) (digest.Digest, }, } - streams, errs, err = c.stream.GetBlobAt(chunksToRequest) + streamsOrErrors, err := getBlobAt(c.stream, chunksToRequest, 1) if err != nil { return "", err } - select { - case p := <-streams: - payload = p - case err := <-errs: - return "", err - } - if payload == nil { - return "", errors.New("invalid stream returned") - } - defer payload.Close() originalRawDigester := digest.Canonical.Digester() + for soe := range streamsOrErrors { + if soe.stream != nil { + r := io.TeeReader(soe.stream, originalRawDigester.Hash()) - r := io.TeeReader(payload, originalRawDigester.Hash()) - - // copy the entire tarball and compute its digest - _, err = io.CopyBuffer(destination, r, c.copyBuffer) - + // copy the entire tarball and compute its digest + _, err = io.CopyBuffer(destination, r, c.copyBuffer) + _ = soe.stream.Close() + } + if soe.err != nil && err == nil { + err = soe.err + } + } return originalRawDigester.Digest(), err } diff --git a/pkg/chunked/storage_linux_test.go b/pkg/chunked/storage_linux_test.go new file mode 100644 index 0000000000..b840ba57c8 --- /dev/null +++ b/pkg/chunked/storage_linux_test.go @@ -0,0 +1,165 @@ +package chunked + +import ( + "bytes" + "errors" + "fmt" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Mock for ImageSourceSeekable +type mockImageSource struct { + streams chan io.ReadCloser + errors chan error +} + +func (m *mockImageSource) GetBlobAt(chunks []ImageSourceChunk) (chan io.ReadCloser, chan error, error) { + return m.streams, m.errors, nil +} + +type mockReadCloser struct { + reader io.Reader + closed bool +} + +func (m *mockReadCloser) Read(p []byte) (int, error) { + return m.reader.Read(p) +} + +func (m *mockReadCloser) Close() error { + m.closed = true + return nil +} + +func mockReadCloserFromContent(content string) *mockReadCloser { + return &mockReadCloser{reader: bytes.NewBufferString(content), closed: false} +} + +func TestGetBlobAtNormalOperation(t *testing.T) { + errors := make(chan error, 1) + expectedStreams := []string{"stream1", "stream2"} + streamsObjs := []*mockReadCloser{ + mockReadCloserFromContent(expectedStreams[0]), + mockReadCloserFromContent(expectedStreams[1]), + } + streams := make(chan io.ReadCloser, len(streamsObjs)) + + for _, s := range streamsObjs { + streams <- s + } + close(streams) + close(errors) + + is := &mockImageSource{streams: streams, errors: errors} + + resultChan, err := getBlobAt(is, nil, 0) + require.NoError(t, err) + + i := 0 + for result := range resultChan { + assert.NoError(t, result.err) + buf := new(bytes.Buffer) + _, _ = buf.ReadFrom(result.stream) + result.stream.Close() + assert.Equal(t, expectedStreams[i], buf.String()) + i++ + } + assert.Equal(t, len(expectedStreams), i) + for _, s := range streamsObjs { + assert.True(t, s.closed) + } +} + +func TestGetBlobAtMaxStreams(t *testing.T) { + streams := make(chan io.ReadCloser, 5) + errors := make(chan error) + + streamsObjs := []*mockReadCloser{} + + for i := 1; i <= 5; i++ { + s := mockReadCloserFromContent(fmt.Sprintf("stream%d", i)) + streamsObjs = append(streamsObjs, s) + streams <- s + } + close(streams) + close(errors) + + is := &mockImageSource{streams: streams, errors: errors} + + resultChan, err := getBlobAt(is, nil, 3) + require.NoError(t, err) + + count := 0 + receivedErr := false + for result := range resultChan { + if result.err != nil { + receivedErr = true + } else { + result.stream.Close() + count++ + } + } + assert.True(t, receivedErr) + assert.Equal(t, 3, count) + for _, s := range streamsObjs { + assert.True(t, s.closed) + } +} + +func TestGetBlobAtWithErrors(t *testing.T) { + streams := make(chan io.ReadCloser) + errorsC := make(chan error, 2) + + errorsC <- errors.New("error1") + errorsC <- errors.New("error2") + close(streams) + close(errorsC) + + is := &mockImageSource{streams: streams, errors: errorsC} + + resultChan, err := getBlobAt(is, nil, 0) + require.NoError(t, err) + + expectedErrors := []string{"error1", "error2"} + i := 0 + for result := range resultChan { + assert.Nil(t, result.stream) + assert.NotNil(t, result.err) + if result.err != nil { + assert.Equal(t, expectedErrors[i], result.err.Error()) + } + i++ + } + assert.Equal(t, len(expectedErrors), i) +} + +func TestGetBlobAtMixedStreamsAndErrors(t *testing.T) { + streams := make(chan io.ReadCloser, 2) + errorsC := make(chan error, 1) + + streams <- mockReadCloserFromContent("stream1") + errorsC <- errors.New("error1") + close(streams) + close(errorsC) + + is := &mockImageSource{streams: streams, errors: errorsC} + + resultChan, err := getBlobAt(is, nil, 1) + require.NoError(t, err) + + var receivedStreams int + var receivedErrors int + for result := range resultChan { + if result.err != nil { + receivedErrors++ + } else { + receivedStreams++ + } + } + assert.Equal(t, 1, receivedStreams) + assert.Equal(t, 1, receivedErrors) +}