Skip to content

Commit

Permalink
Revert "perf(storage): remove protobuf's copy of data on unmarshalling (
Browse files Browse the repository at this point in the history
googleapis#9526)"

This reverts commit 81281c0.
Also updates grpc-go to use new default codec
  • Loading branch information
tritone committed Aug 22, 2024
1 parent 27f6ae6 commit a4beeb9
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 405 deletions.
2 changes: 1 addition & 1 deletion storage/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ require (
google.golang.org/api v0.193.0
google.golang.org/genproto v0.0.0-20240814211410-ddb44dafa142
google.golang.org/genproto/googleapis/api v0.0.0-20240814211410-ddb44dafa142
google.golang.org/grpc v1.65.0
google.golang.org/grpc v1.67.0-dev.0.20240822173259-e4b09f111dd3
google.golang.org/protobuf v1.34.2
)

Expand Down
262 changes: 10 additions & 252 deletions storage/grpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,8 @@ import (
"google.golang.org/api/option/internaloption"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/proto"
fieldmaskpb "google.golang.org/protobuf/types/known/fieldmaskpb"
)

Expand Down Expand Up @@ -959,50 +956,12 @@ func (c *grpcStorageClient) RewriteObject(ctx context.Context, req *rewriteObjec
return r, nil
}

// bytesCodec is a grpc codec which permits receiving messages as either
// protobuf messages, or as raw []bytes.
type bytesCodec struct {
encoding.Codec
}

func (bytesCodec) Marshal(v any) ([]byte, error) {
vv, ok := v.(proto.Message)
if !ok {
return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v)
}
return proto.Marshal(vv)
}

func (bytesCodec) Unmarshal(data []byte, v any) error {
switch v := v.(type) {
case *[]byte:
// If gRPC could recycle the data []byte after unmarshaling (through
// buffer pools), we would need to make a copy here.
*v = data
return nil
case proto.Message:
return proto.Unmarshal(data, v)
default:
return fmt.Errorf("can not unmarshal type %T", v)
}
}

func (bytesCodec) Name() string {
// If this isn't "", then gRPC sets the content-subtype of the call to this
// value and we get errors.
return ""
}

func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRangeReaderParams, opts ...storageOption) (r *Reader, err error) {
ctx = trace.StartSpan(ctx, "cloud.google.com/go/storage.grpcStorageClient.NewRangeReader")
defer func() { trace.EndSpan(ctx, err) }()

s := callSettings(c.settings, opts...)

s.gax = append(s.gax, gax.WithGRPCOptions(
grpc.ForceCodec(bytesCodec{}),
))

if s.userProject != "" {
ctx = setUserProjectMetadata(ctx, s.userProject)
}
Expand All @@ -1018,8 +977,6 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange
req.Generation = params.gen
}

var databuf []byte

// Define a function that initiates a Read with offset and length, assuming
// we have already read seen bytes.
reopen := func(seen int64) (*readStreamResponse, context.CancelFunc, error) {
Expand Down Expand Up @@ -1054,23 +1011,12 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange
return err
}

// Receive the message into databuf as a wire-encoded message so we can
// use a custom decoder to avoid an extra copy at the protobuf layer.
err := stream.RecvMsg(&databuf)
msg, err = stream.Recv()
// These types of errors show up on the Recv call, rather than the
// initialization of the stream via ReadObject above.
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
return ErrObjectNotExist
}
if err != nil {
return err
}
// Use a custom decoder that uses protobuf unmarshalling for all
// fields except the checksummed data.
// Subsequent receives in Read calls will skip all protobuf
// unmarshalling and directly read the content from the gRPC []byte
// response, since only the first call will contain other fields.
msg, err = readFullObjectResponse(databuf)

return err
}, s.retry, s.idempotent)
Expand Down Expand Up @@ -1129,7 +1075,6 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange
leftovers: msg.GetChecksummedData().GetContent(),
settings: s,
zeroRange: params.length == 0,
databuf: databuf,
wantCRC: wantCRC,
checkCRC: checkCRC,
},
Expand Down Expand Up @@ -1525,7 +1470,6 @@ type gRPCReader struct {
stream storagepb.Storage_ReadObjectClient
reopen func(seen int64) (*readStreamResponse, context.CancelFunc, error)
leftovers []byte
databuf []byte
cancel context.CancelFunc
settings *settings
checkCRC bool // should we check the CRC?
Expand Down Expand Up @@ -1579,7 +1523,7 @@ func (r *gRPCReader) Read(p []byte) (int, error) {
}

// Attempt to Recv the next message on the stream.
content, err := r.recv()
msg, err := r.recv()
if err != nil {
return 0, err
}
Expand All @@ -1591,6 +1535,7 @@ func (r *gRPCReader) Read(p []byte) (int, error) {
// present in the response here.
// TODO: Figure out if we need to support decompressive transcoding
// https://cloud.google.com/storage/docs/transcoding.
content := msg.GetChecksummedData().GetContent()
n = copy(p[n:], content)
leftover := len(content) - n
if leftover > 0 {
Expand Down Expand Up @@ -1681,20 +1626,18 @@ func (r *gRPCReader) Close() error {
return nil
}

// recv attempts to Recv the next message on the stream and extract the object
// data that it contains. In the event that a retryable error is encountered,
// the stream will be closed, reopened, and RecvMsg again.
// This will attempt to Recv until one of the following is true:
// recv attempts to Recv the next message on the stream. In the event
// that a retryable error is encountered, the stream will be closed, reopened,
// and Recv again. This will attempt to Recv until one of the following is true:
//
// * Recv is successful
// * A non-retryable error is encountered
// * The Reader's context is canceled
//
// The last error received is the one that is returned, which could be from
// an attempt to reopen the stream.
func (r *gRPCReader) recv() ([]byte, error) {
err := r.stream.RecvMsg(&r.databuf)

func (r *gRPCReader) recv() (*storagepb.ReadObjectResponse, error) {
msg, err := r.stream.Recv()
var shouldRetry = ShouldRetry
if r.settings.retry != nil && r.settings.retry.shouldRetry != nil {
shouldRetry = r.settings.retry.shouldRetry
Expand All @@ -1704,195 +1647,10 @@ func (r *gRPCReader) recv() ([]byte, error) {
// reopen the stream, but will backoff if further attempts are necessary.
// Reopening the stream Recvs the first message, so if retrying is
// successful, the next logical chunk will be returned.
msg, err := r.reopenStream()
return msg.GetChecksummedData().GetContent(), err
msg, err = r.reopenStream()
}

if err != nil {
return nil, err
}

return readObjectResponseContent(r.databuf)
}

// ReadObjectResponse field and subfield numbers.
const (
checksummedDataField = protowire.Number(1)
checksummedDataContentField = protowire.Number(1)
checksummedDataCRC32CField = protowire.Number(2)
objectChecksumsField = protowire.Number(2)
contentRangeField = protowire.Number(3)
metadataField = protowire.Number(4)
)

// readObjectResponseContent returns the checksummed_data.content field of a
// ReadObjectResponse message, or an error if the message is invalid.
// This can be used on recvs of objects after the first recv, since only the
// first message will contain non-data fields.
func readObjectResponseContent(b []byte) ([]byte, error) {
checksummedData, err := readProtoBytes(b, checksummedDataField)
if err != nil {
return b, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData: %v", err)
}
content, err := readProtoBytes(checksummedData, checksummedDataContentField)
if err != nil {
return content, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Content: %v", err)
}

return content, nil
}

// readFullObjectResponse returns the ReadObjectResponse that is encoded in the
// wire-encoded message buffer b, or an error if the message is invalid.
// This must be used on the first recv of an object as it may contain all fields
// of ReadObjectResponse, and we use or pass on those fields to the user.
// This function is essentially identical to proto.Unmarshal, except it aliases
// the data in the input []byte. If the proto library adds a feature to
// Unmarshal that does that, this function can be dropped.
func readFullObjectResponse(b []byte) (*storagepb.ReadObjectResponse, error) {
msg := &storagepb.ReadObjectResponse{}

// Loop over the entire message, extracting fields as we go. This does not
// handle field concatenation, in which the contents of a single field
// are split across multiple protobuf tags.
off := 0
for off < len(b) {
// Consume the next tag. This will tell us which field is next in the
// buffer, its type, and how much space it takes up.
fieldNum, fieldType, fieldLength := protowire.ConsumeTag(b[off:])
if fieldLength < 0 {
return nil, protowire.ParseError(fieldLength)
}
off += fieldLength

// Unmarshal the field according to its type. Only fields that are not
// nil will be present.
switch {
case fieldNum == checksummedDataField && fieldType == protowire.BytesType:
// The ChecksummedData field was found. Initialize the struct.
msg.ChecksummedData = &storagepb.ChecksummedData{}

// Get the bytes corresponding to the checksummed data.
fieldContent, n := protowire.ConsumeBytes(b[off:])
if n < 0 {
return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData: %v", protowire.ParseError(n))
}
off += n

// Get the nested fields. We need to do this manually as it contains
// the object content bytes.
contentOff := 0
for contentOff < len(fieldContent) {
gotNum, gotTyp, n := protowire.ConsumeTag(fieldContent[contentOff:])
if n < 0 {
return nil, protowire.ParseError(n)
}
contentOff += n

switch {
case gotNum == checksummedDataContentField && gotTyp == protowire.BytesType:
// Get the content bytes.
bytes, n := protowire.ConsumeBytes(fieldContent[contentOff:])
if n < 0 {
return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Content: %v", protowire.ParseError(n))
}
msg.ChecksummedData.Content = bytes
contentOff += n
case gotNum == checksummedDataCRC32CField && gotTyp == protowire.Fixed32Type:
v, n := protowire.ConsumeFixed32(fieldContent[contentOff:])
if n < 0 {
return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Crc32C: %v", protowire.ParseError(n))
}
msg.ChecksummedData.Crc32C = &v
contentOff += n
default:
n = protowire.ConsumeFieldValue(gotNum, gotTyp, fieldContent[contentOff:])
if n < 0 {
return nil, protowire.ParseError(n)
}
contentOff += n
}
}
case fieldNum == objectChecksumsField && fieldType == protowire.BytesType:
// The field was found. Initialize the struct.
msg.ObjectChecksums = &storagepb.ObjectChecksums{}

// Get the bytes corresponding to the checksums.
bytes, n := protowire.ConsumeBytes(b[off:])
if n < 0 {
return nil, fmt.Errorf("invalid ReadObjectResponse.ObjectChecksums: %v", protowire.ParseError(n))
}
off += n

// Unmarshal.
if err := proto.Unmarshal(bytes, msg.ObjectChecksums); err != nil {
return nil, err
}
case fieldNum == contentRangeField && fieldType == protowire.BytesType:
msg.ContentRange = &storagepb.ContentRange{}

bytes, n := protowire.ConsumeBytes(b[off:])
if n < 0 {
return nil, fmt.Errorf("invalid ReadObjectResponse.ContentRange: %v", protowire.ParseError(n))
}
off += n

if err := proto.Unmarshal(bytes, msg.ContentRange); err != nil {
return nil, err
}
case fieldNum == metadataField && fieldType == protowire.BytesType:
msg.Metadata = &storagepb.Object{}

bytes, n := protowire.ConsumeBytes(b[off:])
if n < 0 {
return nil, fmt.Errorf("invalid ReadObjectResponse.Metadata: %v", protowire.ParseError(n))
}
off += n

if err := proto.Unmarshal(bytes, msg.Metadata); err != nil {
return nil, err
}
default:
fieldLength = protowire.ConsumeFieldValue(fieldNum, fieldType, b[off:])
if fieldLength < 0 {
return nil, fmt.Errorf("default: %v", protowire.ParseError(fieldLength))
}
off += fieldLength
}
}

return msg, nil
}

// readProtoBytes returns the contents of the protobuf field with number num
// and type bytes from a wire-encoded message. If the field cannot be found,
// the returned slice will be nil and no error will be returned.
//
// It does not handle field concatenation, in which the contents of a single field
// are split across multiple protobuf tags. Encoded data containing split fields
// of this form is technically permissable, but uncommon.
func readProtoBytes(b []byte, num protowire.Number) ([]byte, error) {
off := 0
for off < len(b) {
gotNum, gotTyp, n := protowire.ConsumeTag(b[off:])
if n < 0 {
return nil, protowire.ParseError(n)
}
off += n
if gotNum == num && gotTyp == protowire.BytesType {
b, n := protowire.ConsumeBytes(b[off:])
if n < 0 {
return nil, protowire.ParseError(n)
}
return b, nil
}
n = protowire.ConsumeFieldValue(gotNum, gotTyp, b[off:])
if n < 0 {
return nil, protowire.ParseError(n)
}
off += n
}
return nil, nil
return msg, err
}

// reopenStream "closes" the existing stream and attempts to reopen a stream and
Expand Down
Loading

0 comments on commit a4beeb9

Please sign in to comment.