Skip to content

Commit

Permalink
chunked: rework GetBlobAt usage
Browse files Browse the repository at this point in the history
rewrite how the result from GetBlobAt is used, to make sure 1) that
the streams are always closed, and 2) that any error is processed.

Closes: https://issues.redhat.com/browse/OCPBUGS-43968

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
  • Loading branch information
giuseppe committed Nov 6, 2024
1 parent ad5f2a4 commit c782691
Show file tree
Hide file tree
Showing 3 changed files with 303 additions and 75 deletions.
122 changes: 68 additions & 54 deletions pkg/chunked/compression_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,21 @@ func readEstargzChunkedManifest(blobStream ImageSourceSeekable, blobSize int64,
Offset: uint64(blobSize - footerSize),
Length: uint64(footerSize),
}
parts, errs, err := blobStream.GetBlobAt([]ImageSourceChunk{chunk})

footer := make([]byte, footerSize)
streamsOrErrors, err := getBlobAt(blobStream, []ImageSourceChunk{chunk}, 1)
if err != nil {
return nil, 0, err
}
var reader io.ReadCloser
select {
case r := <-parts:
reader = r
case err := <-errs:
return nil, 0, err
}
defer reader.Close()
footer := make([]byte, footerSize)
if _, err := io.ReadFull(reader, footer); err != nil {
return nil, 0, err

for soe := range streamsOrErrors {
if soe.stream != nil {
_, err = io.ReadFull(soe.stream, footer)
_ = soe.stream.Close()
}
if soe.err != nil && err == nil {
err = soe.err
}
}

/* Read the ToC offset:
Expand All @@ -89,44 +89,55 @@ func readEstargzChunkedManifest(blobStream ImageSourceSeekable, blobSize int64,
Offset: uint64(tocOffset),
Length: uint64(size),
}
parts, errs, err = blobStream.GetBlobAt([]ImageSourceChunk{chunk})
if err != nil {
return nil, 0, err
}

var tocReader io.ReadCloser
select {
case r := <-parts:
tocReader = r
case err := <-errs:
return nil, 0, err
}
defer tocReader.Close()

r, err := pgzip.NewReader(tocReader)
if err != nil {
return nil, 0, err
}
defer r.Close()

aTar := archivetar.NewReader(r)

header, err := aTar.Next()
streamsOrErrors, err = getBlobAt(blobStream, []ImageSourceChunk{chunk}, 1)
if err != nil {
return nil, 0, err
}
// set a reasonable limit
if header.Size > (1<<20)*50 {
return nil, 0, errors.New("manifest too big")
}

manifestUncompressed := make([]byte, header.Size)
if _, err := io.ReadFull(aTar, manifestUncompressed); err != nil {
return nil, 0, err
var manifestUncompressed []byte

for soe := range streamsOrErrors {
if soe.stream != nil {
err1 := func() error {
defer soe.stream.Close()

r, err := pgzip.NewReader(soe.stream)
if err != nil {
return err
}
defer r.Close()

aTar := archivetar.NewReader(r)

header, err := aTar.Next()
if err != nil {
return err
}
// set a reasonable limit
if header.Size > (1<<20)*50 {
return errors.New("manifest too big")
}

manifestUncompressed = make([]byte, header.Size)
if _, err := io.ReadFull(aTar, manifestUncompressed); err != nil {
return err
}
return nil
}()
if err == nil {
err = err1
}
}
if soe.err != nil && err == nil {
err = soe.err
}
}

manifestDigester := digest.Canonical.Digester()
manifestChecksum := manifestDigester.Hash()
if manifestUncompressed == nil {
return nil, 0, errors.New("manifest not found")
}
if _, err := manifestChecksum.Write(manifestUncompressed); err != nil {
return nil, 0, err
}
Expand Down Expand Up @@ -175,26 +186,29 @@ func readZstdChunkedManifest(blobStream ImageSourceSeekable, tocDigest digest.Di
if tarSplitChunk.Offset > 0 {
chunks = append(chunks, tarSplitChunk)
}
parts, errs, err := blobStream.GetBlobAt(chunks)

streamsOrErrors, err := getBlobAt(blobStream, chunks, len(chunks))
if err != nil {
return nil, nil, nil, 0, err
}

defer func() {
for soe := range streamsOrErrors {
if soe.stream != nil {
_ = soe.stream.Close()
}
}
}()

readBlob := func(len uint64) ([]byte, error) {
var reader io.ReadCloser
select {
case r := <-parts:
reader = r
case err := <-errs:
return nil, err
soe := <-streamsOrErrors
if soe.err != nil {
return nil, soe.err
}
defer soe.stream.Close()

blob := make([]byte, len)
if _, err := io.ReadFull(reader, blob); err != nil {
reader.Close()
return nil, err
}
if err := reader.Close(); err != nil {
if _, err := io.ReadFull(soe.stream, blob); err != nil {
return nil, err
}
return blob, nil
Expand Down
91 changes: 70 additions & 21 deletions pkg/chunked/storage_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -1141,41 +1141,90 @@ func makeEntriesFlat(mergedEntries []fileMetadata) ([]fileMetadata, error) {
return new, nil
}

func (c *chunkedDiffer) copyAllBlobToFile(destination *os.File) (digest.Digest, error) {
var payload io.ReadCloser
var streams chan io.ReadCloser
var errs chan error
var err error
type streamOrErr struct {
stream io.ReadCloser
err error
}

func getBlobAt(is ImageSourceSeekable, chunksToRequest []ImageSourceChunk, maxStreams int) (chan streamOrErr, error) {
streams, errs, err := is.GetBlobAt(chunksToRequest)
if err != nil {
return nil, err
}
stream := make(chan streamOrErr)
go func() {
defer close(stream)
streamsSoFar := 0
loop:
for {
select {
case p, ok := <-streams:
if !ok {
streams = nil
break loop
}
if maxStreams > 0 && streamsSoFar >= maxStreams {
_ = p.Close()
continue
}
streamsSoFar++
stream <- streamOrErr{stream: p}
case err, ok := <-errs:
if !ok {
errs = nil
break loop
}
stream <- streamOrErr{err: err}
}
}
if streams != nil {
for p := range streams {
if maxStreams > 0 && streamsSoFar >= maxStreams {
_ = p.Close()
continue
}
streamsSoFar++
stream <- streamOrErr{stream: p}
}
}
if errs != nil {
for err := range errs {
stream <- streamOrErr{err: err}
}
}
if maxStreams > 0 && streamsSoFar >= maxStreams {
stream <- streamOrErr{err: fmt.Errorf("too many streams returned")}
}
}()
return stream, nil
}

func (c *chunkedDiffer) copyAllBlobToFile(destination *os.File) (digest.Digest, error) {
chunksToRequest := []ImageSourceChunk{
{
Offset: 0,
Length: uint64(c.blobSize),
},
}

streams, errs, err = c.stream.GetBlobAt(chunksToRequest)
streamsOrErrors, err := getBlobAt(c.stream, chunksToRequest, 1)
if err != nil {
return "", err
}
select {
case p := <-streams:
payload = p
case err := <-errs:
return "", err
}
if payload == nil {
return "", errors.New("invalid stream returned")
}
defer payload.Close()

originalRawDigester := digest.Canonical.Digester()
for soe := range streamsOrErrors {
if soe.stream != nil {
r := io.TeeReader(soe.stream, originalRawDigester.Hash())

r := io.TeeReader(payload, originalRawDigester.Hash())

// copy the entire tarball and compute its digest
_, err = io.CopyBuffer(destination, r, c.copyBuffer)

// copy the entire tarball and compute its digest
_, err = io.CopyBuffer(destination, r, c.copyBuffer)
_ = soe.stream.Close()
}
if soe.err != nil && err == nil {
err = soe.err
}
}
return originalRawDigester.Digest(), err
}

Expand Down
Loading

0 comments on commit c782691

Please sign in to comment.