diff --git a/codec.go b/codec.go index d3c2b35bf7e0..e840858b77b1 100644 --- a/codec.go +++ b/codec.go @@ -39,17 +39,11 @@ type baseCodec interface { // with encoding.GetCodec and if it is registered wraps it with newCodecV1Bridge // to turn it into an encoding.CodecV2. Returns nil otherwise. func getCodec(name string) encoding.CodecV2 { - codecV2 := encoding.GetCodecV2(name) - if codecV2 != nil { - return codecV2 - } - - codecV1 := encoding.GetCodec(name) - if codecV1 != nil { + if codecV1 := encoding.GetCodec(name); codecV1 != nil { return newCodecV1Bridge(codecV1) } - return nil + return encoding.GetCodecV2(name) } func newCodecV0Bridge(c Codec) baseCodec { diff --git a/codec_test.go b/codec_test.go index 4e1c625eacc5..f2f88e9636a8 100644 --- a/codec_test.go +++ b/codec_test.go @@ -26,7 +26,7 @@ import ( ) func (s) TestGetCodecForProtoIsNotNil(t *testing.T) { - if encoding.GetCodec(proto.Name) == nil { + if encoding.GetCodecV2(proto.Name) == nil { t.Fatalf("encoding.GetCodec(%q) must not be nil by default", proto.Name) } } diff --git a/encoding/encoding.go b/encoding/encoding.go index 5ebf88d7147f..11d0ae142c42 100644 --- a/encoding/encoding.go +++ b/encoding/encoding.go @@ -94,7 +94,7 @@ type Codec interface { Name() string } -var registeredCodecs = make(map[string]Codec) +var registeredCodecs = make(map[string]any) // RegisterCodec registers the provided Codec for use with all gRPC clients and // servers. @@ -126,5 +126,6 @@ func RegisterCodec(codec Codec) { // // The content-subtype is expected to be lowercase. func GetCodec(contentSubtype string) Codec { - return registeredCodecs[contentSubtype] + c, _ := registeredCodecs[contentSubtype].(Codec) + return c } diff --git a/encoding/encoding_test.go b/encoding/encoding_test.go index 4a4fec33498a..9ac59e461633 100644 --- a/encoding/encoding_test.go +++ b/encoding/encoding_test.go @@ -36,6 +36,7 @@ import ( "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/mem" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -90,18 +91,18 @@ type errProtoCodec struct { decodingErr error } -func (c *errProtoCodec) Marshal(v any) ([]byte, error) { +func (c *errProtoCodec) Marshal(v any) (mem.BufferSlice, error) { if c.encodingErr != nil { return nil, c.encodingErr } - return encoding.GetCodec(proto.Name).Marshal(v) + return encoding.GetCodecV2(proto.Name).Marshal(v) } -func (c *errProtoCodec) Unmarshal(data []byte, v any) error { +func (c *errProtoCodec) Unmarshal(data mem.BufferSlice, v any) error { if c.decodingErr != nil { return c.decodingErr } - return encoding.GetCodec(proto.Name).Unmarshal(data, v) + return encoding.GetCodecV2(proto.Name).Unmarshal(data, v) } func (c *errProtoCodec) Name() string { @@ -118,7 +119,7 @@ func (s) TestEncodeDoesntPanicOnServer(t *testing.T) { ec := &errProtoCodec{name: t.Name(), encodingErr: encodingErr} // Start a server with the above codec. - backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodec(ec)) + backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodecV2(ec)) defer backend.Stop() // Create a channel to the above server. @@ -154,7 +155,7 @@ func (s) TestDecodeDoesntPanicOnServer(t *testing.T) { ec := &errProtoCodec{name: t.Name(), decodingErr: decodingErr} // Start a server with the above codec. - backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodec(ec)) + backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodecV2(ec)) defer backend.Stop() // Create a channel to the above server. Since we do not specify any codec @@ -206,7 +207,7 @@ func (s) TestEncodeDoesntPanicOnClient(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() client := testgrpc.NewTestServiceClient(cc) - _, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec)) + _, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(ec)) if err == nil || !strings.Contains(err.Error(), encodingErr.Error()) { t.Fatalf("RPC failed with error: %v, want: %v", err, encodingErr) } @@ -214,7 +215,7 @@ func (s) TestEncodeDoesntPanicOnClient(t *testing.T) { // Configure the codec on the client to not return errors anymore and expect // the RPC to succeed. ec.encodingErr = nil - if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec)); err != nil { + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(ec)); err != nil { t.Fatalf("RPC failed with error: %v", err) } } @@ -242,7 +243,7 @@ func (s) TestDecodeDoesntPanicOnClient(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() client := testgrpc.NewTestServiceClient(cc) - _, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec)) + _, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(ec)) if err == nil || !strings.Contains(err.Error(), decodingErr.Error()) { t.Fatalf("RPC failed with error: %v, want: %v", err, decodingErr) } @@ -250,7 +251,7 @@ func (s) TestDecodeDoesntPanicOnClient(t *testing.T) { // Configure the codec on the client to not return errors anymore and expect // the RPC to succeed. ec.decodingErr = nil - if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(ec)); err != nil { + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(ec)); err != nil { t.Fatalf("RPC failed with error: %v", err) } } @@ -265,14 +266,14 @@ type countingProtoCodec struct { unmarshalCount int32 } -func (p *countingProtoCodec) Marshal(v any) ([]byte, error) { +func (p *countingProtoCodec) Marshal(v any) (mem.BufferSlice, error) { atomic.AddInt32(&p.marshalCount, 1) - return encoding.GetCodec(proto.Name).Marshal(v) + return encoding.GetCodecV2(proto.Name).Marshal(v) } -func (p *countingProtoCodec) Unmarshal(data []byte, v any) error { +func (p *countingProtoCodec) Unmarshal(data mem.BufferSlice, v any) error { atomic.AddInt32(&p.unmarshalCount, 1) - return encoding.GetCodec(proto.Name).Unmarshal(data, v) + return encoding.GetCodecV2(proto.Name).Unmarshal(data, v) } func (p *countingProtoCodec) Name() string { @@ -284,7 +285,7 @@ func (p *countingProtoCodec) Name() string { func (s) TestForceServerCodec(t *testing.T) { // Create an server with the counting proto codec. codec := &countingProtoCodec{name: t.Name()} - backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodec(codec)) + backend := stubserver.StartTestService(t, nil, grpc.ForceServerCodecV2(codec)) defer backend.Stop() // Create a channel to the above server. @@ -317,7 +318,7 @@ func (s) TestForceServerCodec(t *testing.T) { // renameProtoCodec wraps the proto codec and allows customizing the Name(). type renameProtoCodec struct { - encoding.Codec + encoding.CodecV2 name string } @@ -356,9 +357,9 @@ func (s) TestForceCodecName(t *testing.T) { // Force the use of the custom codec on the client with the ForceCodec call // option. Confirm the name is converted to lowercase before transmitting. - codec := &renameProtoCodec{Codec: encoding.GetCodec(proto.Name), name: t.Name()} + codec := &renameProtoCodec{CodecV2: encoding.GetCodecV2(proto.Name), name: t.Name()} wantContentTypeCh <- []string{fmt.Sprintf("application/grpc+%s", strings.ToLower(t.Name()))} - if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodec(codec)); err != nil { + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}, grpc.ForceCodecV2(codec)); err != nil { t.Fatalf("ss.Client.EmptyCall(_, _) = _, %v; want _, nil", err) } } diff --git a/encoding/encoding_v2.go b/encoding/encoding_v2.go index e209f6f1ab62..074c5e234a7b 100644 --- a/encoding/encoding_v2.go +++ b/encoding/encoding_v2.go @@ -43,8 +43,6 @@ type CodecV2 interface { Name() string } -var registeredV2Codecs = make(map[string]CodecV2) - // RegisterCodecV2 registers the provided CodecV2 for use with all gRPC clients and // servers. // @@ -70,7 +68,7 @@ func RegisterCodecV2(codec CodecV2) { panic("cannot register CodecV2 with empty string result for Name()") } contentSubtype := strings.ToLower(codec.Name()) - registeredV2Codecs[contentSubtype] = codec + registeredCodecs[contentSubtype] = codec } // GetCodecV2 gets a registered CodecV2 by content-subtype, or nil if no CodecV2 is @@ -78,5 +76,6 @@ func RegisterCodecV2(codec CodecV2) { // // The content-subtype is expected to be lowercase. func GetCodecV2(contentSubtype string) CodecV2 { - return registeredV2Codecs[contentSubtype] + c, _ := registeredCodecs[contentSubtype].(CodecV2) + return c } diff --git a/encoding/proto/proto.go b/encoding/proto/proto.go index 66d5cdf03ec5..ceec319dd2fb 100644 --- a/encoding/proto/proto.go +++ b/encoding/proto/proto.go @@ -1,6 +1,6 @@ /* * - * Copyright 2018 gRPC authors. + * Copyright 2024 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import ( "fmt" "google.golang.org/grpc/encoding" + "google.golang.org/grpc/mem" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/protoadapt" ) @@ -32,28 +33,51 @@ import ( const Name = "proto" func init() { - encoding.RegisterCodec(codec{}) + encoding.RegisterCodecV2(&codecV2{}) } -// codec is a Codec implementation with protobuf. It is the default codec for gRPC. -type codec struct{} +// codec is a CodecV2 implementation with protobuf. It is the default codec for +// gRPC. +type codecV2 struct{} -func (codec) Marshal(v any) ([]byte, error) { +func (c *codecV2) Marshal(v any) (data mem.BufferSlice, err error) { vv := messageV2Of(v) if vv == nil { - return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v) + return nil, fmt.Errorf("proto: failed to marshal, message is %T, want proto.Message", v) } - return proto.Marshal(vv) + size := proto.Size(vv) + if mem.IsBelowBufferPoolingThreshold(size) { + buf, err := proto.Marshal(vv) + if err != nil { + return nil, err + } + data = append(data, mem.SliceBuffer(buf)) + } else { + pool := mem.DefaultBufferPool() + buf := pool.Get(size) + if _, err := (proto.MarshalOptions{}).MarshalAppend((*buf)[:0], vv); err != nil { + pool.Put(buf) + return nil, err + } + data = append(data, mem.NewBuffer(buf, pool)) + } + + return data, nil } -func (codec) Unmarshal(data []byte, v any) error { +func (c *codecV2) Unmarshal(data mem.BufferSlice, v any) (err error) { vv := messageV2Of(v) if vv == nil { return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v) } - return proto.Unmarshal(data, vv) + buf := data.MaterializeToBuffer(mem.DefaultBufferPool()) + defer buf.Free() + // TODO: Upgrade proto.Unmarshal to support mem.BufferSlice. Right now, it's not + // really possible without a major overhaul of the proto package, but the + // vtprotobuf library may be able to support this. + return proto.Unmarshal(buf.ReadOnlyData(), vv) } func messageV2Of(v any) proto.Message { @@ -67,6 +91,6 @@ func messageV2Of(v any) proto.Message { return nil } -func (codec) Name() string { +func (c *codecV2) Name() string { return Name } diff --git a/encoding/proto/proto_benchmark_test.go b/encoding/proto/proto_benchmark_test.go index 8b4cec756342..58007dcbfecc 100644 --- a/encoding/proto/proto_benchmark_test.go +++ b/encoding/proto/proto_benchmark_test.go @@ -68,7 +68,7 @@ func BenchmarkProtoCodec(b *testing.B) { protoStructs := setupBenchmarkProtoCodecInputs(s) name := fmt.Sprintf("MinPayloadSize:%v/SetParallelism(%v)", s, p) b.Run(name, func(b *testing.B) { - codec := &codec{} + codec := &codecV2{} b.SetParallelism(p) b.RunParallel(func(pb *testing.PB) { benchmarkProtoCodec(codec, protoStructs, pb, b) @@ -78,7 +78,7 @@ func BenchmarkProtoCodec(b *testing.B) { } } -func benchmarkProtoCodec(codec *codec, protoStructs []proto.Message, pb *testing.PB, b *testing.B) { +func benchmarkProtoCodec(codec *codecV2, protoStructs []proto.Message, pb *testing.PB, b *testing.B) { counter := 0 for pb.Next() { counter++ @@ -87,7 +87,7 @@ func benchmarkProtoCodec(codec *codec, protoStructs []proto.Message, pb *testing } } -func fastMarshalAndUnmarshal(codec encoding.Codec, protoStruct proto.Message, b *testing.B) { +func fastMarshalAndUnmarshal(codec encoding.CodecV2, protoStruct proto.Message, b *testing.B) { marshaledBytes, err := codec.Marshal(protoStruct) if err != nil { b.Errorf("codec.Marshal(_) returned an error") diff --git a/encoding/proto/proto_test.go b/encoding/proto/proto_test.go index d017eb8ec30d..117f3cb97fe8 100644 --- a/encoding/proto/proto_test.go +++ b/encoding/proto/proto_test.go @@ -25,10 +25,11 @@ import ( "google.golang.org/grpc/encoding" "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/mem" pb "google.golang.org/grpc/test/codec_perf" ) -func marshalAndUnmarshal(t *testing.T, codec encoding.Codec, expectedBody []byte) { +func marshalAndUnmarshal(t *testing.T, codec encoding.CodecV2, expectedBody []byte) { p := &pb.Buffer{} p.Body = expectedBody @@ -55,7 +56,7 @@ func Test(t *testing.T) { } func (s) TestBasicProtoCodecMarshalAndUnmarshal(t *testing.T) { - marshalAndUnmarshal(t, codec{}, []byte{1, 2, 3}) + marshalAndUnmarshal(t, &codecV2{}, []byte{1, 2, 3}) } // Try to catch possible race conditions around use of pools @@ -75,7 +76,7 @@ func (s) TestConcurrentUsage(t *testing.T) { } var wg sync.WaitGroup - codec := codec{} + codec := &codecV2{} for i := 0; i < numGoRoutines; i++ { wg.Add(1) @@ -93,8 +94,8 @@ func (s) TestConcurrentUsage(t *testing.T) { // TestStaggeredMarshalAndUnmarshalUsingSamePool tries to catch potential errors in which slices get // stomped on during reuse of a proto.Buffer. func (s) TestStaggeredMarshalAndUnmarshalUsingSamePool(t *testing.T) { - codec1 := codec{} - codec2 := codec{} + codec1 := &codecV2{} + codec2 := &codecV2{} expectedBody1 := []byte{1, 2, 3} expectedBody2 := []byte{4, 5, 6} @@ -102,7 +103,7 @@ func (s) TestStaggeredMarshalAndUnmarshalUsingSamePool(t *testing.T) { proto1 := pb.Buffer{Body: expectedBody1} proto2 := pb.Buffer{Body: expectedBody2} - var m1, m2 []byte + var m1, m2 mem.BufferSlice var err error if m1, err = codec1.Marshal(&proto1); err != nil { diff --git a/encoding/proto/proto_v2.go b/encoding/proto/proto_v2.go deleted file mode 100644 index 367a3cd66832..000000000000 --- a/encoding/proto/proto_v2.go +++ /dev/null @@ -1,81 +0,0 @@ -/* - * - * Copyright 2024 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package proto - -import ( - "fmt" - - "google.golang.org/grpc/encoding" - "google.golang.org/grpc/mem" - "google.golang.org/protobuf/proto" -) - -func init() { - encoding.RegisterCodecV2(&codecV2{}) -} - -// codec is a CodecV2 implementation with protobuf. It is the default codec for -// gRPC. -type codecV2 struct{} - -var _ encoding.CodecV2 = (*codecV2)(nil) - -func (c *codecV2) Marshal(v any) (data mem.BufferSlice, err error) { - vv := messageV2Of(v) - if vv == nil { - return nil, fmt.Errorf("proto: failed to marshal, message is %T, want proto.Message", v) - } - - size := proto.Size(vv) - if mem.IsBelowBufferPoolingThreshold(size) { - buf, err := proto.Marshal(vv) - if err != nil { - return nil, err - } - data = append(data, mem.SliceBuffer(buf)) - } else { - pool := mem.DefaultBufferPool() - buf := pool.Get(size) - if _, err := (proto.MarshalOptions{}).MarshalAppend((*buf)[:0], vv); err != nil { - pool.Put(buf) - return nil, err - } - data = append(data, mem.NewBuffer(buf, pool)) - } - - return data, nil -} - -func (c *codecV2) Unmarshal(data mem.BufferSlice, v any) (err error) { - vv := messageV2Of(v) - if vv == nil { - return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v) - } - - buf := data.MaterializeToBuffer(mem.DefaultBufferPool()) - defer buf.Free() - // TODO: Upgrade proto.Unmarshal to support mem.BufferSlice. Right now, it's not - // really possible without a major overhaul of the proto package, but the - // vtprotobuf library may be able to support this. - return proto.Unmarshal(buf.ReadOnlyData(), vv) -} - -func (c *codecV2) Name() string { - return Name -}