Skip to content

Commit

Permalink
feat(storage/transfermanager): checksum full object downloads (#10569)
Browse files Browse the repository at this point in the history
  • Loading branch information
BrennaEpp authored Jul 30, 2024
1 parent 123c886 commit c366c90
Show file tree
Hide file tree
Showing 3 changed files with 267 additions and 71 deletions.
281 changes: 217 additions & 64 deletions storage/transfermanager/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"context"
"errors"
"fmt"
"hash"
"hash/crc32"
"io"
"io/fs"
"math"
Expand All @@ -31,6 +33,12 @@ import (
"google.golang.org/api/iterator"
)

// maxChecksumZeroArraySize is the maximum amount of memory (4 MiB) to allocate
// for updating the checksum. A larger size will occupy more memory but will
// require fewer updates when computing the crc32c of a full object.
// TODO: test the performance of smaller values for this.
const maxChecksumZeroArraySize = 4 * 1024 * 1024

// Downloader manages a set of parallelized downloads.
type Downloader struct {
client *storage.Client
Expand Down Expand Up @@ -288,7 +296,7 @@ func (d *Downloader) addNewInputs(inputs []DownloadObjectInput) {
}

func (d *Downloader) addResult(input *DownloadObjectInput, result *DownloadOutput) {
copiedResult := *result // make a copy so that callbacks do not affect the result
copiedResult := *result // make a copy so that callbacks do not affect the result

if input.directory {
f := input.Destination.(*os.File)
Expand All @@ -305,7 +313,6 @@ func (d *Downloader) addResult(input *DownloadObjectInput, result *DownloadOutpu
input.directoryObjectOutputs <- copiedResult
}
}
// TODO: check checksum if full object

if d.config.asynchronous || input.directory {
input.Callback(result)
Expand Down Expand Up @@ -337,27 +344,10 @@ func (d *Downloader) downloadWorker() {
break // no more work; exit
}

out := input.downloadShard(d.client, d.config.perOperationTimeout, d.config.partSize)

if input.shard == 0 {
if out.Err != nil {
// Don't queue more shards if the first failed.
d.addResult(input, out)
} else {
numShards := numShards(out.Attrs, input.Range, d.config.partSize)

if numShards <= 1 {
// Download completed with a single shard.
d.addResult(input, out)
} else {
// Queue more shards.
outs := d.queueShards(input, out.Attrs.Generation, numShards)
// Start a goroutine that gathers shards sent to the output
// channel and adds the result once it has received all shards.
go d.gatherShards(input, outs, numShards)
}
}
d.startDownload(input)
} else {
out := input.downloadShard(d.client, d.config.perOperationTimeout, d.config.partSize)
// If this isn't the first shard, send to the output channel specific to the object.
// This should never block since the channel is buffered to exactly the number of shards.
input.shardOutputs <- out
Expand All @@ -366,6 +356,47 @@ func (d *Downloader) downloadWorker() {
d.workers.Done()
}

// startDownload downloads the first shard of the object and, when the object
// spans more than one shard, queues the remaining shards and gathers their
// results asynchronously.
func (d *Downloader) startDownload(input *DownloadObjectInput) {
	var first *DownloadOutput

	// For a full-object read, request the whole object but read only the
	// first partSize bytes. This lets us pick up the object's CRC32C from
	// the read response, avoiding a separate metadata call for JSON
	// downloads, so the download can be checksummed.
	if fullObjectRead(input.Range) {
		input.checkCRC = true
		first = input.downloadFirstShard(d.client, d.config.perOperationTimeout, d.config.partSize)
	} else {
		first = input.downloadShard(d.client, d.config.perOperationTimeout, d.config.partSize)
	}

	// A failed first shard aborts the object download; don't queue more shards.
	if first.Err != nil {
		d.addResult(input, first)
		return
	}

	// Do not checksum if the object was decompressed in transit; the stored
	// CRC32C covers the compressed bytes.
	input.checkCRC = input.checkCRC && !first.Attrs.Decompressed

	shards := numShards(first.Attrs, input.Range, d.config.partSize)
	if shards <= 1 {
		// Completed with a single shard: validate (if enabled) and report.
		if input.checkCRC {
			if err := checksumObject(first.crc32c, first.Attrs.CRC32C); err != nil {
				first.Err = err
			}
		}
		d.addResult(input, first)
		return
	}

	// Queue the remaining shards, then gather their outputs in a goroutine
	// that adds the combined result once every shard has been received.
	outs := d.queueShards(input, first.Attrs.Generation, shards)
	go d.gatherShards(input, first, outs, shards, first.crc32c)
}

// queueShards queues all subsequent shards of an object after the first.
// The results should be forwarded to the returned channel.
func (d *Downloader) queueShards(in *DownloadObjectInput, gen int64, shards int) <-chan *DownloadOutput {
Expand Down Expand Up @@ -397,12 +428,12 @@ var errCancelAllShards = errors.New("cancelled because another shard failed")
// It will add the result to the Downloader once it has received all shards.
// gatherShards cancels remaining shards if any shard errored.
// It does not do any checking to verify that shards are for the same object.
func (d *Downloader) gatherShards(in *DownloadObjectInput, outs <-chan *DownloadOutput, shards int) {
func (d *Downloader) gatherShards(in *DownloadObjectInput, out *DownloadOutput, outs <-chan *DownloadOutput, shards int, firstPieceCRC uint32) {
errs := []error{}
var shardOut *DownloadOutput
orderedChecksums := make([]crc32cPiece, shards-1)

for i := 1; i < shards; i++ {
// Add monitoring here? This could hang if any individual piece does.
shardOut = <-outs
shardOut := <-outs

// We can ignore errors that resulted from a previous error.
// Note that we may still get some cancel errors if they
Expand All @@ -412,20 +443,30 @@ func (d *Downloader) gatherShards(in *DownloadObjectInput, outs <-chan *Download
errs = append(errs, shardOut.Err)
in.cancelCtx(errCancelAllShards)
}

orderedChecksums[shardOut.shard-1] = crc32cPiece{sum: shardOut.crc32c, length: shardOut.shardLength}
}

// All pieces gathered.
if len(errs) == 0 && in.checkCRC && out.Attrs != nil {
fullCrc := joinCRC32C(firstPieceCRC, orderedChecksums)
if err := checksumObject(fullCrc, out.Attrs.CRC32C); err != nil {
errs = append(errs, err)
}
}

// All pieces gathered; return output. Any shard output will do.
shardOut.Range = in.Range
// Prepare output.
out.Range = in.Range
if len(errs) != 0 {
shardOut.Err = fmt.Errorf("download shard errors:\n%w", errors.Join(errs...))
out.Err = fmt.Errorf("download shard errors:\n%w", errors.Join(errs...))
}
if shardOut.Attrs != nil {
shardOut.Attrs.StartOffset = 0
if out.Attrs != nil {
out.Attrs.StartOffset = 0
if in.Range != nil {
shardOut.Attrs.StartOffset = in.Range.Offset
out.Attrs.StartOffset = in.Range.Offset
}
}
d.addResult(in, shardOut)
d.addResult(in, out)
}

// gatherObjectOutputs receives from the given channel exactly numObjects times.
Expand Down Expand Up @@ -563,45 +604,18 @@ type DownloadObjectInput struct {
shardOutputs chan<- *DownloadOutput
directory bool // input was queued by calling DownloadDirectory
directoryObjectOutputs chan<- DownloadOutput
checkCRC bool
}

// downloadShard will read a specific object piece into in.Destination.
// If timeout is less than 0, no timeout is set.
func (in *DownloadObjectInput) downloadShard(client *storage.Client, timeout time.Duration, partSize int64) (out *DownloadOutput) {
out = &DownloadOutput{Bucket: in.Bucket, Object: in.Object, Range: in.Range}

// Set timeout.
ctx := in.ctx
if timeout > 0 {
c, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
ctx = c
}

// The first shard will be sent as download many, since we do not know yet
// if it will be sharded.
method := downloadMany
if in.shard != 0 {
method = downloadSharded
}
ctx = setUsageMetricHeader(ctx, method)

// Set options on the object.
o := client.Bucket(in.Bucket).Object(in.Object)

if in.Conditions != nil {
o = o.If(*in.Conditions)
}
if in.Generation != nil {
o = o.Generation(*in.Generation)
}
if len(in.EncryptionKey) > 0 {
o = o.Key(in.EncryptionKey)
}

objRange := shardRange(in.Range, partSize, in.shard)
ctx := in.setOptionsOnContext(timeout)
o := in.setOptionsOnObject(client)

// Read.
r, err := o.NewRangeReader(ctx, objRange.Offset, objRange.Length)
if err != nil {
out.Err = err
Expand All @@ -618,9 +632,63 @@ func (in *DownloadObjectInput) downloadShard(client *storage.Client, timeout tim
}
}

w := io.NewOffsetWriter(in.Destination, offset)
_, err = io.Copy(w, r)
var w io.Writer
w = io.NewOffsetWriter(in.Destination, offset)

var crcHash hash.Hash32
if in.checkCRC {
crcHash = crc32.New(crc32.MakeTable(crc32.Castagnoli))
w = io.MultiWriter(w, crcHash)
}

n, err := io.Copy(w, r)
if err != nil {
out.Err = err
r.Close()
return
}

if err = r.Close(); err != nil {
out.Err = err
return
}

out.Attrs = &r.Attrs
out.shard = in.shard
out.shardLength = n
if in.checkCRC {
out.crc32c = crcHash.Sum32()
}
return
}

// downloadFirstShard will read the first object piece into in.Destination.
// If timeout is less than 0, no timeout is set.
func (in *DownloadObjectInput) downloadFirstShard(client *storage.Client, timeout time.Duration, partSize int64) (out *DownloadOutput) {
out = &DownloadOutput{Bucket: in.Bucket, Object: in.Object, Range: in.Range}

ctx := in.setOptionsOnContext(timeout)
o := in.setOptionsOnObject(client)

r, err := o.NewReader(ctx)
if err != nil {
out.Err = err
return
}

var w io.Writer
w = io.NewOffsetWriter(in.Destination, 0)

var crcHash hash.Hash32
if in.checkCRC {
crcHash = crc32.New(crc32.MakeTable(crc32.Castagnoli))
w = io.MultiWriter(w, crcHash)
}

// Copy only the first partSize bytes before closing the reader.
// If we encounter an EOF, the file was smaller than partSize.
n, err := io.CopyN(w, r, partSize)
if err != nil && err != io.EOF {
out.Err = err
r.Close()
return
Expand All @@ -632,9 +700,45 @@ func (in *DownloadObjectInput) downloadShard(client *storage.Client, timeout tim
}

out.Attrs = &r.Attrs
out.shard = in.shard
out.shardLength = n
if in.checkCRC {
out.crc32c = crcHash.Sum32()
}
return
}

// setOptionsOnContext returns the context to use for this download, derived
// from in.ctx with the per-operation timeout (if any) and the usage-metric
// header applied.
func (in *DownloadObjectInput) setOptionsOnContext(timeout time.Duration) context.Context {
	ctx := in.ctx
	if timeout > 0 {
		c, cancel := context.WithTimeout(ctx, timeout)
		// Do not defer cancel() here: that would cancel c as soon as this
		// helper returns, before the caller has a chance to use it.
		// Instead, release the timer's resources once c is done (deadline
		// reached or parent context canceled).
		go func() {
			<-c.Done()
			cancel()
		}()
		ctx = c
	}

	// The first shard will be sent as download many, since we do not know yet
	// if it will be sharded.
	method := downloadMany
	if in.shard != 0 {
		method = downloadSharded
	}
	return setUsageMetricHeader(ctx, method)
}

// setOptionsOnObject builds the handle for the object to download, applying
// any conditions, generation, and encryption key set on the input.
func (in *DownloadObjectInput) setOptionsOnObject(client *storage.Client) *storage.ObjectHandle {
	obj := client.Bucket(in.Bucket).Object(in.Object)
	if conds := in.Conditions; conds != nil {
		obj = obj.If(*conds)
	}
	if gen := in.Generation; gen != nil {
		obj = obj.Generation(*gen)
	}
	if key := in.EncryptionKey; len(key) > 0 {
		obj = obj.Key(key)
	}
	return obj
}

// DownloadDirectoryInput is the input for a directory to download.
type DownloadDirectoryInput struct {
// Bucket is the bucket in GCS to download from. Required.
Expand Down Expand Up @@ -686,6 +790,10 @@ type DownloadOutput struct {
Range *DownloadRange // requested range, if it was specified
Err error // error occurring during download
Attrs *storage.ReaderObjectAttrs // attributes of downloaded object, if successful

shard int
shardLength int64
crc32c uint32
}

// TODO: use built-in after go < 1.21 is dropped.
Expand Down Expand Up @@ -784,3 +892,48 @@ func setUsageMetricHeader(ctx context.Context, method string) context.Context {
header := fmt.Sprintf("%s/%s", usageMetricKey, method)
return callctx.SetHeaders(ctx, xGoogHeaderKey, header)
}

// crc32cPiece holds the CRC32C of one contiguous piece of an object, along
// with the piece's length; pieces are combined by joinCRC32C to recover the
// checksum of the whole object.
type crc32cPiece struct {
	sum uint32 // crc32c checksum of the piece
	length int64 // number of bytes in this piece
}

// joinCRC32C pieces together the initial checksum with the orderedChecksums
// provided to calculate the checksum of the whole.
//
// For each piece, the running checksum is zero-extended by the piece's
// length (with the standard pre/post bit inversion) and then XORed with the
// piece's checksum.
func joinCRC32C(initialChecksum uint32, orderedChecksums []crc32cPiece) uint32 {
	// Hoisted out of the loop; the Castagnoli table never changes.
	table := crc32.MakeTable(crc32.Castagnoli)

	base := initialChecksum
	zeroes := make([]byte, maxChecksumZeroArraySize)
	for _, part := range orderedChecksums {
		// Precondition base (flip every bit).
		base ^= 0xFFFFFFFF

		// Zero-pad base crc32c. To conserve memory, do so with at most
		// maxChecksumZeroArraySize bytes at a time, reusing the zeroes
		// array where possible.
		var padded int64
		for padded < part.length {
			desiredZeroes := min(part.length-padded, maxChecksumZeroArraySize)
			base = crc32.Update(base, table, zeroes[:desiredZeroes])
			padded += desiredZeroes
		}

		// Postcondition base (same as precondition; this flips the bits back).
		base ^= 0xFFFFFFFF

		// XOR base with the piece's checksum to produce the new base.
		base ^= part.sum
	}
	return base
}

// fullObjectRead reports whether r requests the entire object: either no
// range was specified, or the range starts at 0 with a negative
// (read-to-end) length.
func fullObjectRead(r *DownloadRange) bool {
	if r == nil {
		return true
	}
	return r.Offset == 0 && r.Length < 0
}

// checksumObject compares the computed CRC32C of a download against the
// expected value. A want of 0 is treated as "no valid CRC32C available",
// in which case validation is skipped and nil is returned.
func checksumObject(got, want uint32) error {
	if want == 0 || got == want {
		return nil
	}
	return fmt.Errorf("bad CRC on read: got %d, want %d", got, want)
}
Loading

0 comments on commit c366c90

Please sign in to comment.