From 6d54e76abe1c9376adb2f5d0469983fc6b1ec9f4 Mon Sep 17 00:00:00 2001
From: Rubens Farias
Date: Thu, 29 Oct 2020 16:03:45 -0400
Subject: [PATCH] Add compressed writes to cas.go.

This follows the current tentative API being worked on in
bazelbuild/remote-apis#168. While there is technically still room for it
to change, it has reached a stable enough point to be worth implementing.
---
 go/pkg/client/bytestream.go |  4 ++
 go/pkg/client/cas.go        | 39 ++++++++++++++--
 go/pkg/client/cas_test.go   | 29 ++++++++----
 go/pkg/client/client.go     | 53 +++++++++++++--------
 go/pkg/fakes/BUILD.bazel    |  1 +
 go/pkg/fakes/cas.go         | 91 +++++++++++++++++++++++++++++--------
 go/pkg/fakes/server.go      |  5 +-
 7 files changed, 169 insertions(+), 53 deletions(-)

diff --git a/go/pkg/client/bytestream.go b/go/pkg/client/bytestream.go
index 0cad88007..82f388d1c 100644
--- a/go/pkg/client/bytestream.go
+++ b/go/pkg/client/bytestream.go
@@ -37,6 +37,10 @@ func (c *Client) WriteChunked(ctx context.Context, name string, ch *chunker.Chun
 			return err
 		}
 		if chunk.Offset == 0 {
+			// Note that the digest in the chunker can be misleading:
+			// for compressed blob uploads, the resource name must
+			// include the uncompressed blob's digest, while the
+			// chunker reports the digest of the compressed data.
 			req.ResourceName = name
 		}
 		req.WriteOffset = chunk.Offset

diff --git a/go/pkg/client/cas.go b/go/pkg/client/cas.go
index 861b67a83..5140a195c 100644
--- a/go/pkg/client/cas.go
+++ b/go/pkg/client/cas.go
@@ -25,6 +25,9 @@ import (
 	log "github.com/golang/glog"
 )
 
+// DefaultCompressedWritesThreshold is the default size, in bytes, at or above which blobs are written compressed on ByteStream.Write RPCs.
+const DefaultCompressedWritesThreshold = 1024
+
 // UploadIfMissing stores a number of uploadable items.
 // It first queries the CAS to see which items are missing and only uploads those that are.
 // Returns a slice of the missing digests.
@@ -87,8 +90,12 @@ func (c *Client) UploadIfMissing(ctx context.Context, data ...*chunker.Chunker)
 			} else {
 				log.V(3).Infof("Uploading single blob with digest %s", batch[0])
 				ch := chunkers[batch[0]]
-				dg := ch.Digest()
-				if err := c.WriteChunked(eCtx, c.ResourceNameWrite(dg.Hash, dg.Size), ch); err != nil {
+				var rscName string
+				var err error
+				if rscName, err = c.maybeCompressBlob(ch); err != nil {
+					return err
+				}
+				if err = c.WriteChunked(eCtx, rscName, ch); err != nil {
 					return err
 				}
 			}
@@ -135,7 +142,26 @@ func (c *Client) WriteProto(ctx context.Context, msg proto.Message) (digest.Dige
 func (c *Client) WriteBlob(ctx context.Context, blob []byte) (digest.Digest, error) {
 	ch := chunker.NewFromBlob(blob, int(c.ChunkMaxSize))
 	dg := ch.Digest()
-	return dg, c.WriteChunked(ctx, c.ResourceNameWrite(dg.Hash, dg.Size), ch)
+
+	name, err := c.maybeCompressBlob(ch)
+	if err != nil {
+		return dg, err
+	}
+
+	return dg, c.WriteChunked(ctx, name, ch)
+}
+
+// maybeCompressBlob will, depending on the client configuration, compress the blob
+// in the chunker in place. It returns the appropriate write resource name.
+func (c *Client) maybeCompressBlob(ch *chunker.Chunker) (string, error) {
+	dg := ch.Digest()
+	if c.CompressedWritesThreshold < 0 || int64(c.CompressedWritesThreshold) > dg.Size {
+		return c.ResourceNameWrite(dg.Hash, dg.Size), nil
+	}
+	if err := chunker.CompressChunker(ch); err != nil {
+		return "", err
+	}
+	return c.ResourceNameCompressedWrite(dg.Hash, dg.Size), nil
+}
 
 // BatchWriteBlobs uploads a number of blobs to the CAS. They must collectively be below the
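
To pin down the boundary behavior of maybeCompressBlob's gate, here is a minimal, self-contained sketch of the same decision; shouldCompress is an illustrative name, not part of this patch:

package main

import "fmt"

// shouldCompress mirrors the maybeCompressBlob gate above: a negative
// threshold disables compression entirely, and blobs smaller than the
// threshold are written uncompressed.
func shouldCompress(threshold, blobSize int64) bool {
	return threshold >= 0 && threshold <= blobSize
}

func main() {
	for _, tc := range []struct{ threshold, size int64 }{
		{-1, 4096},   // compression disabled
		{0, 0},       // compress everything, even empty blobs
		{1024, 512},  // below threshold: written uncompressed
		{1024, 2048}, // at/above threshold: written compressed
	} {
		fmt.Printf("threshold=%d size=%d compress=%t\n", tc.threshold, tc.size, shouldCompress(tc.threshold, tc.size))
	}
}
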
@@ -514,6 +540,13 @@ func (c *Client) ResourceNameWrite(hash string, sizeBytes int64) string {
 	return fmt.Sprintf("%s/uploads/%s/blobs/%s/%d", c.InstanceName, uuid.New(), hash, sizeBytes)
 }
 
+// ResourceNameCompressedWrite generates a valid resource name for a compressed blob write.
+// TODO(rubensf): Converge compressor to proto in https://github.com/bazelbuild/remote-apis/pull/168 once
+// that gets merged in.
+func (c *Client) ResourceNameCompressedWrite(hash string, sizeBytes int64) string {
+	return fmt.Sprintf("%s/uploads/%s/compressed-blobs/zstd/%s/%d", c.InstanceName, uuid.New(), hash, sizeBytes)
+}
+
 // GetDirectoryTree returns the entire directory tree rooted at the given digest (which must target
 // a Directory stored in the CAS).
 func (c *Client) GetDirectoryTree(ctx context.Context, d *repb.Digest) (result []*repb.Directory, err error) {

diff --git a/go/pkg/client/cas_test.go b/go/pkg/client/cas_test.go
index 12834d8e3..fecce6e9e 100644
--- a/go/pkg/client/cas_test.go
+++ b/go/pkg/client/cas_test.go
@@ -284,7 +284,7 @@ func TestWrite(t *testing.T) {
 	}
 	for _, tc := range tests {
-		t.Run(tc.name, func(t *testing.T) {
+		testFunc := func(t *testing.T) {
 			gotDg, err := c.WriteBlob(ctx, tc.blob)
 			if err != nil {
 				t.Errorf("c.WriteBlob(ctx, blob) gave error %s, wanted nil", err)
@@ -299,7 +299,13 @@ func TestWrite(t *testing.T) {
 			if dg != gotDg {
 				t.Errorf("c.WriteBlob(ctx, blob) had diff on digest returned (want %s, got %s)", dg, gotDg)
 			}
-		})
+		}
+
+		// Harder to write as a for loop, since -1/0 isn't an intuitive "enabled/disabled" toggle.
+		c.CompressedWritesThreshold = -1
+		t.Run(tc.name+" - no compression", testFunc)
+		c.CompressedWritesThreshold = 0
+		t.Run(tc.name+" - with compression", testFunc)
 	}
 }

@@ -442,7 +448,7 @@ func TestUpload(t *testing.T) {
 		t.Run(fmt.Sprintf("UsingBatch:%t", ub), func(t *testing.T) {
 			ub.Apply(c)
 			for _, tc := range tests {
-				t.Run(tc.name, func(t *testing.T) {
+				testFunc := func(t *testing.T) {
 					fake.Clear()
 					if tc.concurrency > 0 {
 						tc.concurrency.Apply(c)
@@ -467,12 +473,9 @@ func TestUpload(t *testing.T) {
 					for _, dg := range missing {
 						missingSet[dg] = struct{}{}
 					}
-					for _, ch := range input {
+					for i, ch := range input {
 						dg := ch.Digest()
-						blob, err := ch.FullData()
-						if err != nil {
-							t.Errorf("ch.FullData() returned an error: %v", err)
-						}
+						blob := tc.input[i]
 						if present[dg] {
 							if fake.BlobWrites(dg) > 0 {
 								t.Errorf("blob %v with digest %s was uploaded even though it was already present in the CAS", blob, dg)
@@ -485,7 +488,7 @@ func TestUpload(t *testing.T) {
 						if gotBlob, ok := fake.Get(dg); !ok {
 							t.Errorf("blob %v with digest %s was not uploaded, expected it to be present in the CAS", blob, dg)
 						} else if !bytes.Equal(blob, gotBlob) {
-							t.Errorf("blob digest %s had diff on uploaded blob: want %v, got %v", dg, blob, gotBlob)
+							t.Errorf("blob digest %s had diff on uploaded blob: want %s, got %s", dg, blob, gotBlob)
 						}
 						if _, ok := missingSet[dg]; !ok {
 							t.Errorf("Stats said that blob %v with digest %s was present in the CAS", blob, dg)
@@ -494,7 +497,13 @@ func TestUpload(t *testing.T) {
 					if fake.MaxConcurrency() > defaultCASConcurrency {
 						t.Errorf("CAS concurrency %v was higher than max %v", fake.MaxConcurrency(), defaultCASConcurrency)
 					}
-				})
+				}
+
+				// Harder to write as a for loop, since -1/0 isn't an intuitive "enabled/disabled" toggle.
+				c.CompressedWritesThreshold = -1
+				t.Run(tc.name+" - no compression", testFunc)
+				c.CompressedWritesThreshold = 0
+				t.Run(tc.name+" - with compression", testFunc)
 			}
 		})
 	}
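
For reference, the resource names produced by the new helper look as follows; a minimal sketch reusing the same pborman/uuid dependency, where the hash (the SHA-256 of the empty blob) and size are just example values:

package main

import (
	"fmt"

	"github.com/pborman/uuid"
)

func main() {
	// Per the scheme in bazelbuild/remote-apis#168, hash and size refer
	// to the *uncompressed* blob even on the compressed-blobs path.
	const emptyHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
	fmt.Printf("instance/uploads/%s/compressed-blobs/zstd/%s/%d\n", uuid.New(), emptyHash, 0)
}
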
diff --git a/go/pkg/client/client.go b/go/pkg/client/client.go
index dc5ae731c..9b63818c6 100644
--- a/go/pkg/client/client.go
+++ b/go/pkg/client/client.go
@@ -70,6 +70,11 @@ type Client struct {
 	StartupCapabilities StartupCapabilities
 	// ChunkMaxSize is maximum chunk size to use for CAS uploads/downloads.
 	ChunkMaxSize ChunkMaxSize
+	// CompressedWritesThreshold is the size, in bytes, at or above which blobs are written compressed.
+	// Use 0 to compress all writes, and a negative number to disable compressed writes entirely.
+	// TODO(rubensf): Make sure this will throw an error if the server doesn't support compression,
+	// pending https://github.com/bazelbuild/remote-apis/pull/168 being submitted.
+	CompressedWritesThreshold CompressedWritesThreshold
 	// MaxBatchDigests is maximum amount of digests to batch in batched operations.
 	MaxBatchDigests MaxBatchDigests
 	// MaxBatchSize is maximum size in bytes of a batch request for batch operations.
@@ -136,6 +141,13 @@ func (s ChunkMaxSize) Apply(c *Client) {
 	c.ChunkMaxSize = s
 }
 
+type CompressedWritesThreshold int
+
+// Apply sets the client's compressed writes threshold to s.
+func (s CompressedWritesThreshold) Apply(c *Client) {
+	c.CompressedWritesThreshold = s
+}
+
 // UtilizeLocality is to specify whether client downloads files utilizing disk access locality.
 type UtilizeLocality bool
 
@@ -427,26 +439,27 @@ func NewClient(ctx context.Context, instanceName string, params DialParams, opts
 		return nil, err
 	}
 	client := &Client{
-		InstanceName:        instanceName,
-		actionCache:         regrpc.NewActionCacheClient(casConn),
-		byteStream:          bsgrpc.NewByteStreamClient(casConn),
-		cas:                 regrpc.NewContentAddressableStorageClient(casConn),
-		execution:           regrpc.NewExecutionClient(conn),
-		operations:          opgrpc.NewOperationsClient(conn),
-		rpcTimeouts:         DefaultRPCTimeouts,
-		Connection:          conn,
-		CASConnection:       casConn,
-		ChunkMaxSize:        chunker.DefaultChunkSize,
-		MaxBatchDigests:     DefaultMaxBatchDigests,
-		MaxBatchSize:        DefaultMaxBatchSize,
-		DirMode:             DefaultDirMode,
-		ExecutableMode:      DefaultExecutableMode,
-		RegularMode:         DefaultRegularMode,
-		useBatchOps:         true,
-		StartupCapabilities: true,
-		casUploaders:        make(chan bool, DefaultCASConcurrency),
-		casDownloaders:      make(chan bool, DefaultCASConcurrency),
-		Retrier:             RetryTransient(),
+		InstanceName:              instanceName,
+		actionCache:               regrpc.NewActionCacheClient(casConn),
+		byteStream:                bsgrpc.NewByteStreamClient(casConn),
+		cas:                       regrpc.NewContentAddressableStorageClient(casConn),
+		execution:                 regrpc.NewExecutionClient(conn),
+		operations:                opgrpc.NewOperationsClient(conn),
+		rpcTimeouts:               DefaultRPCTimeouts,
+		Connection:                conn,
+		CASConnection:             casConn,
+		ChunkMaxSize:              chunker.DefaultChunkSize,
+		CompressedWritesThreshold: DefaultCompressedWritesThreshold,
+		MaxBatchDigests:           DefaultMaxBatchDigests,
+		MaxBatchSize:              DefaultMaxBatchSize,
+		DirMode:                   DefaultDirMode,
+		ExecutableMode:            DefaultExecutableMode,
+		RegularMode:               DefaultRegularMode,
+		useBatchOps:               true,
+		StartupCapabilities:       true,
+		casUploaders:              make(chan bool, DefaultCASConcurrency),
+		casDownloaders:            make(chan bool, DefaultCASConcurrency),
+		Retrier:                   RetryTransient(),
 	}
 	for _, o := range opts {
 		o.Apply(client)
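
Because CompressedWritesThreshold follows the same Opt/Apply pattern as the other knobs above, it can be overridden at construction time. A hypothetical caller-side sketch; the endpoint and threshold values are placeholders, not part of this patch:

package main

import (
	"context"
	"log"

	"github.com/bazelbuild/remote-apis-sdks/go/pkg/client"
)

func main() {
	ctx := context.Background()
	// Compress every ByteStream write of 4 KiB or more; pass -1 instead
	// to disable compressed writes entirely.
	c, err := client.NewClient(ctx, "instance",
		client.DialParams{Service: "example.com:443"},
		client.CompressedWritesThreshold(4096),
	)
	if err != nil {
		log.Fatalf("NewClient: %v", err)
	}
	defer c.Close()
}
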
"@com_github_pborman_uuid//:go_default_library", "@go_googleapis//google/bytestream:bytestream_go_proto", "@go_googleapis//google/longrunning:longrunning_go_proto", diff --git a/go/pkg/fakes/cas.go b/go/pkg/fakes/cas.go index 4756a5093..6f1d4b34e 100644 --- a/go/pkg/fakes/cas.go +++ b/go/pkg/fakes/cas.go @@ -14,6 +14,7 @@ import ( "github.com/bazelbuild/remote-apis-sdks/go/pkg/client" "github.com/bazelbuild/remote-apis-sdks/go/pkg/digest" "github.com/golang/protobuf/proto" + "github.com/klauspost/compress/zstd" "github.com/pborman/uuid" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -127,19 +128,29 @@ func (f *Writer) Write(stream bsgrpc.ByteStream_WriteServer) (err error) { } path := strings.Split(req.ResourceName, "/") - if len(path) != 6 || path[0] != "instance" || path[1] != "uploads" || path[3] != "blobs" { - return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/uploads//blobs//\"") + if (len(path) != 6 && len(path) != 7) || path[0] != "instance" || path[1] != "uploads" || (path[3] != "blobs" && path[3] != "compressed-blobs") { + return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/uploads//blobs|compressed-blobs///\"") + } + // indexOffset for all 4+ paths - `compressed-blobs` paths have one more element. + indexOffset := 0 + if path[3] == "compressed-blobs" { + indexOffset = 1 + // TODO(rubensf): Change this to all the possible compressors in https://github.com/bazelbuild/remote-apis/pull/168. + if path[4] != "zstd" { + return status.Error(codes.InvalidArgument, "test fake expected valid compressor, eg zstd") + } } - size, err := strconv.ParseInt(path[5], 10, 64) + + size, err := strconv.ParseInt(path[5+indexOffset], 10, 64) if err != nil { - return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/uploads//blobs//\"") + return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/uploads//blobs|compressed-blobs///\"") } - dg, e := digest.New(path[4], size) + dg, e := digest.New(path[4+indexOffset], size) if e != nil { - return status.Error(codes.InvalidArgument, "test fake expected valid digest as part of resource name of the form \"instance/uploads//blobs//\"") + return status.Error(codes.InvalidArgument, "test fake expected valid digest as part of resource name of the form \"instance/uploads//blobs|compressed-blobs///\"") } if uuid.Parse(path[2]) == nil { - return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/uploads//blobs//\"") + return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/uploads//blobs|compressed-blobs///\"") } res := req.ResourceName @@ -179,7 +190,23 @@ func (f *Writer) Write(stream bsgrpc.ByteStream_WriteServer) (err error) { return status.Errorf(codes.InvalidArgument, "reached end of stream before the client finished writing") } - f.Buf = buf.Bytes() + if path[3] == "compressed-blobs" { + if path[4] == "zstd" { + decoder, err := zstd.NewReader(nil) + if err != nil { + return status.Errorf(codes.Internal, "failed to initialize internal decoder: %v", err) + } + f.Buf, err = decoder.DecodeAll(buf.Bytes(), nil) + if err != nil { + return status.Errorf(codes.InvalidArgument, "served bytes can't be decompressed: %v", err) + } + } else { + return status.Errorf(codes.InvalidArgument, "%s compressor isn't supported", path[4]) + } + } else { + f.Buf = buf.Bytes() + } + cDg := 
 	if dg != cDg {
 		return status.Errorf(codes.InvalidArgument, "mismatched digest: received %s, computed %s", dg, cDg)
 	}
@@ -210,13 +237,17 @@ type CAS struct {
 	writeReqs   int
 	concReqs    int
 	maxConcReqs int
+	decoder     *zstd.Decoder
 }
 
 // NewCAS returns a new empty fake CAS.
-func NewCAS() *CAS {
+func NewCAS() (*CAS, error) {
 	c := &CAS{BatchSize: client.DefaultMaxBatchSize}
 	c.Clear()
-	return c
+
+	var err error
+	c.decoder, err = zstd.NewReader(nil)
+	return c, err
 }
 
 // Clear removes all results from the cache.
@@ -470,19 +501,28 @@ func (f *CAS) Write(stream bsgrpc.ByteStream_WriteServer) (err error) {
 	}
 
 	path := strings.Split(req.ResourceName, "/")
-	if len(path) != 6 || path[0] != "instance" || path[1] != "uploads" || path[3] != "blobs" {
-		return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/uploads/<uuid>/blobs/<hash>/<size>\"")
+	if (len(path) != 6 && len(path) != 7) || path[0] != "instance" || path[1] != "uploads" || (path[3] != "blobs" && path[3] != "compressed-blobs") {
+		return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/uploads/<uuid>/blobs|compressed-blobs/<compressor?>/<hash>/<size>\"")
+	}
+	// Index offset for path elements 4 and beyond - `compressed-blobs` resource names have one extra element (the compressor).
+	indexOffset := 0
+	if path[3] == "compressed-blobs" {
+		indexOffset = 1
+		// TODO(rubensf): Change this to cover all the possible compressors in https://github.com/bazelbuild/remote-apis/pull/168.
+		if path[4] != "zstd" {
+			return status.Error(codes.InvalidArgument, "test fake expected a valid compressor, e.g. zstd")
+		}
 	}
-	size, err := strconv.ParseInt(path[5], 10, 64)
+	size, err := strconv.ParseInt(path[5+indexOffset], 10, 64)
 	if err != nil {
-		return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/uploads/<uuid>/blobs/<hash>/<size>\"")
+		return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/uploads/<uuid>/blobs|compressed-blobs/<compressor?>/<hash>/<size>\"")
 	}
-	dg, err := digest.New(path[4], size)
+	dg, err := digest.New(path[4+indexOffset], size)
 	if err != nil {
-		return status.Error(codes.InvalidArgument, "test fake expected a valid digest as part of the resource name: \"instance/uploads/<uuid>/blobs/<hash>/<size>\"")
+		return status.Error(codes.InvalidArgument, "test fake expected a valid digest as part of the resource name: \"instance/uploads/<uuid>/blobs|compressed-blobs/<compressor?>/<hash>/<size>\"")
 	}
 	if uuid.Parse(path[2]) == nil {
-		return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/uploads/<uuid>/blobs/<hash>/<size>\"")
+		return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/uploads/<uuid>/blobs|compressed-blobs/<compressor?>/<hash>/<size>\"")
 	}
 
 	res := req.ResourceName
@@ -522,11 +562,24 @@ func (f *CAS) Write(stream bsgrpc.ByteStream_WriteServer) (err error) {
 		return status.Errorf(codes.InvalidArgument, "reached end of stream before the client finished writing")
 	}
 
+	uncompressedBuf := buf.Bytes()
+	if path[3] == "compressed-blobs" {
+		if path[4] == "zstd" {
+			var err error
+			uncompressedBuf, err = f.decoder.DecodeAll(buf.Bytes(), nil)
+			if err != nil {
+				return status.Errorf(codes.InvalidArgument, "uploaded bytes can't be decompressed: %v", err)
+			}
+		} else {
+			return status.Errorf(codes.InvalidArgument, "%s compressor isn't supported", path[4])
+		}
+	}
+
 	f.mu.Lock()
-	f.blobs[dg] = buf.Bytes()
+	f.blobs[dg] = uncompressedBuf
 	f.writes[dg]++
 	f.mu.Unlock()
-	cDg := digest.NewFromBlob(buf.Bytes())
+	cDg := digest.NewFromBlob(uncompressedBuf)
 	if dg != cDg {
 		return status.Errorf(codes.InvalidArgument, "mismatched digest: received %s, computed %s", dg, cDg)
 	}
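
Both fakes recover the original bytes with klauspost/compress before digest-checking them. A minimal, self-contained sketch of the encode/decode round trip they rely on (the payload is illustrative):

package main

import (
	"fmt"
	"log"

	"github.com/klauspost/compress/zstd"
)

func main() {
	enc, err := zstd.NewWriter(nil)
	if err != nil {
		log.Fatalf("NewWriter: %v", err)
	}
	// EncodeAll compresses the whole payload in one call, as a client would
	// before writing to a compressed-blobs resource name.
	compressed := enc.EncodeAll([]byte("hello, cas"), nil)
	enc.Close()

	dec, err := zstd.NewReader(nil)
	if err != nil {
		log.Fatalf("NewReader: %v", err)
	}
	defer dec.Close()
	// DecodeAll is the same call the fakes use to recover the original
	// bytes before comparing them against the uncompressed digest.
	out, err := dec.DecodeAll(compressed, nil)
	if err != nil {
		log.Fatalf("DecodeAll: %v", err)
	}
	fmt.Printf("%s\n", out)
}
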
"mismatched digest: received %s, computed %s", dg, cDg) } diff --git a/go/pkg/fakes/server.go b/go/pkg/fakes/server.go index 32bc57ab4..d2417affe 100644 --- a/go/pkg/fakes/server.go +++ b/go/pkg/fakes/server.go @@ -41,7 +41,10 @@ type Server struct { // NewServer creates a server that is ready to accept requests. func NewServer(t *testing.T) (s *Server, err error) { - cas := NewCAS() + cas, err := NewCAS() + if err != nil { + return nil, err + } ac := NewActionCache() s = &Server{Exec: NewExec(t, ac, cas), CAS: cas, ActionCache: ac} s.listener, err = net.Listen("tcp", ":0")