From c5040f9d5bb185dda254e89d9898abae37357c40 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 6 Nov 2023 17:21:32 -0500 Subject: [PATCH 1/6] Add Schema field to Spec for introspection New field Schema of type any on Spec objects. For proto based schemas the type will be of protoreflect.MethodDescriptor. This allows for easy introspection to interceptors. --- client.go | 2 + client_ext_test.go | 121 ++++++++++++++++++ cmd/protoc-gen-connect-go/main.go | 45 +++---- connect.go | 8 +- handler.go | 2 + .../v1/collidev1connect/collide.connect.go | 8 +- .../v1/importv1connect/import.connect.go | 2 +- .../ping/v1/pingv1connect/ping.connect.go | 28 ++-- option.go | 22 ++++ 9 files changed, 194 insertions(+), 44 deletions(-) diff --git a/client.go b/client.go index 38bb541b..2abcaa5e 100644 --- a/client.go +++ b/client.go @@ -189,6 +189,7 @@ type clientConfig struct { URL *url.URL Protocol protocol Procedure string + Schema any CompressMinBytes int Interceptor Interceptor CompressionPools map[string]*compressionPool @@ -251,6 +252,7 @@ func (c *clientConfig) newSpec(t StreamType) Spec { return Spec{ StreamType: t, Procedure: c.Procedure, + Schema: c.Schema, IsClient: true, IdempotencyLevel: c.IdempotencyLevel, } diff --git a/client_ext_test.go b/client_ext_test.go index ce799958..e5f6995a 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -17,6 +17,7 @@ package connect_test import ( "context" "errors" + "fmt" "net/http" "strings" "testing" @@ -26,6 +27,7 @@ import ( pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" "connectrpc.com/connect/internal/memhttp/memhttptest" + "google.golang.org/protobuf/reflect/protoreflect" ) func TestNewClient_InitFailure(t *testing.T) { @@ -186,6 +188,81 @@ func TestGetNotModified(t *testing.T) { assert.Equal(t, http.MethodGet, unaryReq.HTTPMethod()) } +func TestSpecSchema(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler( + pingServer{}, + connect.WithInterceptors(&assertSchemaInterceptor{t}), + )) + server := memhttptest.NewServer(t, mux) + testcases := []struct { + name string + opts []connect.ClientOption + }{{ + name: connect.ProtocolConnect, + }, { + name: connect.ProtocolGRPC, + opts: []connect.ClientOption{ + connect.WithGRPC(), + }, + }, { + name: connect.ProtocolGRPCWeb, + opts: []connect.ClientOption{ + connect.WithGRPC(), + }, + }} + for _, testcase := range testcases { + testcase := testcase + t.Run(testcase.name, func(t *testing.T) { + ctx := context.Background() + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + connect.WithClientOptions(testcase.opts...), + connect.WithInterceptors(&assertSchemaInterceptor{t}), + ) + t.Parallel() + t.Run("unary", func(t *testing.T) { + unaryReq := connect.NewRequest[pingv1.PingRequest](nil) + _, err := client.Ping(ctx, unaryReq) + assert.Nil(t, err) + text := strings.Repeat(".", 256) + r, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{Text: text})) + assert.Nil(t, err) + assert.Equal(t, r.Msg.Text, text) + }) + t.Run("client_stream", func(t *testing.T) { + clientStream := client.Sum(ctx) + t.Cleanup(func() { + _, closeErr := clientStream.CloseAndReceive() + assert.Nil(t, closeErr) + }) + assert.NotZero(t, clientStream.Spec().Schema) + err := clientStream.Send(&pingv1.SumRequest{}) + assert.Nil(t, err) + }) + t.Run("server_stream", func(t *testing.T) { + serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{})) + t.Cleanup(func() { + assert.Nil(t, serverStream.Close()) + }) + assert.Nil(t, err) + }) + t.Run("bidi_stream", func(t *testing.T) { + bidiStream := client.CumSum(ctx) + t.Cleanup(func() { + assert.Nil(t, bidiStream.CloseRequest()) + assert.Nil(t, bidiStream.CloseResponse()) + }) + assert.NotZero(t, bidiStream.Spec().Schema) + err := bidiStream.Send(&pingv1.CumSumRequest{}) + assert.Nil(t, err) + }) + }) + } +} + type notModifiedPingServer struct { pingv1connect.UnimplementedPingServiceHandler @@ -233,3 +310,47 @@ func (a *assertPeerInterceptor) WrapStreamingHandler(next connect.StreamingHandl return next(ctx, conn) } } + +type assertSchemaInterceptor struct { + tb testing.TB +} + +func (a *assertSchemaInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + if !assert.NotNil(a.tb, req.Spec().Schema) { + return nil, fmt.Errorf("nil spec") + } + methodDesc, ok := req.Spec().Schema.(protoreflect.MethodDescriptor) + assert.True(a.tb, ok) + procedure := fmt.Sprintf("/%s/%s", methodDesc.Parent().FullName(), methodDesc.Name()) + assert.Equal(a.tb, procedure, req.Spec().Procedure) + return next(ctx, req) + } +} + +func (a *assertSchemaInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { + return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { + conn := next(ctx, spec) + if !assert.NotNil(a.tb, spec.Schema) { + return conn + } + methodDescriptor, ok := spec.Schema.(protoreflect.MethodDescriptor) + assert.True(a.tb, ok) + procedure := fmt.Sprintf("/%s/%s", methodDescriptor.Parent().FullName(), methodDescriptor.Name()) + assert.Equal(a.tb, procedure, spec.Procedure) + return conn + } +} + +func (a *assertSchemaInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return func(ctx context.Context, conn connect.StreamingHandlerConn) error { + if !assert.NotNil(a.tb, conn.Spec().Schema) { + return fmt.Errorf("nil spec") + } + methodDesc, ok := conn.Spec().Schema.(protoreflect.MethodDescriptor) + assert.True(a.tb, ok) + procedure := fmt.Sprintf("/%s/%s", methodDesc.Parent().FullName(), methodDesc.Name()) + assert.Equal(a.tb, procedure, conn.Spec().Procedure) + return next(ctx, conn) + } +} diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index bde19f51..4fda5b2a 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -102,17 +102,6 @@ func main() { ) } -func needsWithIdempotency(file *protogen.File) bool { - for _, service := range file.Services { - for _, method := range service.Methods { - if methodIdempotency(method) != connect.IdempotencyUnknown { - return true - } - } - } - return false -} - func generate(plugin *protogen.Plugin, file *protogen.File) { if len(file.Services) == 0 { return @@ -136,7 +125,7 @@ func generate(plugin *protogen.Plugin, file *protogen.File) { generatePreamble(generatedFile, file) generateServiceNameConstants(generatedFile, file.Services) for _, service := range file.Services { - generateService(generatedFile, service) + generateService(generatedFile, file, service) } } @@ -180,11 +169,7 @@ func generatePreamble(g *protogen.GeneratedFile, file *protogen.File) { "is not defined, this code was generated with a version of connect newer than the one ", "compiled into your binary. You can fix the problem by either regenerating this code ", "with an older version of connect or updating the connect version compiled into your binary.") - if needsWithIdempotency(file) { - g.P("const _ = ", connectPackage.Ident("IsAtLeastVersion1_7_0")) - } else { - g.P("const _ = ", connectPackage.Ident("IsAtLeastVersion0_1_0")) - } + g.P("const _ = ", connectPackage.Ident("IsAtLeastVersion1_13_0")) g.P() } @@ -225,12 +210,12 @@ func generateServiceNameConstants(g *protogen.GeneratedFile, services []*protoge g.P() } -func generateService(g *protogen.GeneratedFile, service *protogen.Service) { +func generateService(g *protogen.GeneratedFile, file *protogen.File, service *protogen.Service) { names := newNames(service) generateClientInterface(g, service, names) - generateClientImplementation(g, service, names) + generateClientImplementation(g, file, service, names) generateServerInterface(g, service, names) - generateServerConstructor(g, service, names) + generateServerConstructor(g, file, service, names) generateUnimplementedServerImplementation(g, service, names) } @@ -255,7 +240,7 @@ func generateClientInterface(g *protogen.GeneratedFile, service *protogen.Servic g.P() } -func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.Service, names names) { +func generateClientImplementation(g *protogen.GeneratedFile, file *protogen.File, service *protogen.Service, names names) { clientOption := connectPackage.Ident("ClientOption") // Client constructor. @@ -283,17 +268,19 @@ func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.S ) g.P("httpClient,") g.P(`baseURL + `, procedureConstName(method), `,`) + g.P(connectPackage.Ident("WithSchema"), "(", + g.QualifiedGoIdent(file.GoDescriptorIdent), + `.Services().ByName("`, service.Desc.Name(), `")`, + `.Methods().ByName("`, method.Desc.Name(), `")),`) idempotency := methodIdempotency(method) switch idempotency { case connect.IdempotencyNoSideEffects: g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyNoSideEffects"), "),") - g.P(connectPackage.Ident("WithClientOptions"), "(opts...),") case connect.IdempotencyIdempotent: g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyIdempotent"), "),") - g.P(connectPackage.Ident("WithClientOptions"), "(opts...),") case connect.IdempotencyUnknown: - g.P("opts...,") } + g.P(connectPackage.Ident("WithClientOptions"), "(opts...),") g.P("),") } g.P("}") @@ -390,7 +377,7 @@ func generateServerInterface(g *protogen.GeneratedFile, service *protogen.Servic g.P() } -func generateServerConstructor(g *protogen.GeneratedFile, service *protogen.Service, names names) { +func generateServerConstructor(g *protogen.GeneratedFile, file *protogen.File, service *protogen.Service, names names) { wrapComments(g, names.ServerConstructor, " builds an HTTP handler from the service implementation.", " It returns the path on which to mount the handler and the handler itself.") g.P("//") @@ -419,16 +406,18 @@ func generateServerConstructor(g *protogen.GeneratedFile, service *protogen.Serv } g.P(procedureConstName(method), `,`) g.P("svc.", method.GoName, ",") + g.P(connectPackage.Ident("WithSchema"), "(", + g.QualifiedGoIdent(file.GoDescriptorIdent), + `.Services().ByName("`, service.Desc.Name(), `")`, + `.Methods().ByName("`, method.Desc.Name(), `")),`) switch idempotency { case connect.IdempotencyNoSideEffects: g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyNoSideEffects"), "),") - g.P(connectPackage.Ident("WithHandlerOptions"), "(opts...),") case connect.IdempotencyIdempotent: g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyIdempotent"), "),") - g.P(connectPackage.Ident("WithHandlerOptions"), "(opts...),") case connect.IdempotencyUnknown: - g.P("opts...,") } + g.P(connectPackage.Ident("WithHandlerOptions"), "(opts...),") g.P(")") } g.P(`return "/`, service.Desc.FullName(), `/", `, httpPackage.Ident("HandlerFunc"), `(func(w `, httpPackage.Ident("ResponseWriter"), `, r *`, httpPackage.Ident("Request"), `){`) diff --git a/connect.go b/connect.go index c7c41d38..622852fb 100644 --- a/connect.go +++ b/connect.go @@ -38,9 +38,10 @@ const Version = "1.13.0-dev" // These constants are used in compile-time handshakes with connect's generated // code. const ( - IsAtLeastVersion0_0_1 = true - IsAtLeastVersion0_1_0 = true - IsAtLeastVersion1_7_0 = true + IsAtLeastVersion0_0_1 = true + IsAtLeastVersion0_1_0 = true + IsAtLeastVersion1_7_0 = true + IsAtLeastVersion1_13_0 = true ) // StreamType describes whether the client, server, neither, or both is @@ -314,6 +315,7 @@ type HTTPClient interface { // fully-qualified Procedure corresponding to each RPC in your schema. type Spec struct { StreamType StreamType + Schema any // for protobuf RPCs, a protoreflect.MethodDescriptor Procedure string // for example, "/acme.foo.v1.FooService/Bar" IsClient bool // otherwise we're in a handler IdempotencyLevel IdempotencyLevel diff --git a/handler.go b/handler.go index 43bfe973..b207ac66 100644 --- a/handler.go +++ b/handler.go @@ -246,6 +246,7 @@ type handlerConfig struct { CompressMinBytes int Interceptor Interceptor Procedure string + Schema any HandleGRPC bool HandleGRPCWeb bool RequireConnectProtocolHeader bool @@ -279,6 +280,7 @@ func newHandlerConfig(procedure string, streamType StreamType, options []Handler func (c *handlerConfig) newSpec() Spec { return Spec{ Procedure: c.Procedure, + Schema: c.Schema, StreamType: c.StreamType, IdempotencyLevel: c.IdempotencyLevel, } diff --git a/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go b/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go index 20223410..4b21ff23 100644 --- a/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go +++ b/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go @@ -32,7 +32,7 @@ import ( // generated with a version of connect newer than the one compiled into your binary. You can fix the // problem by either regenerating this code with an older version of connect or updating the connect // version compiled into your binary. -const _ = connect.IsAtLeastVersion0_1_0 +const _ = connect.IsAtLeastVersion1_13_0 const ( // CollideServiceName is the fully-qualified name of the CollideService service. @@ -69,7 +69,8 @@ func NewCollideServiceClient(httpClient connect.HTTPClient, baseURL string, opts _import: connect.NewClient[v1.ImportRequest, v1.ImportResponse]( httpClient, baseURL+CollideServiceImportProcedure, - opts..., + connect.WithSchema(v1.File_connect_collide_v1_collide_proto.Services().ByName("CollideService").Methods().ByName("Import")), + connect.WithClientOptions(opts...), ), } } @@ -98,7 +99,8 @@ func NewCollideServiceHandler(svc CollideServiceHandler, opts ...connect.Handler collideServiceImportHandler := connect.NewUnaryHandler( CollideServiceImportProcedure, svc.Import, - opts..., + connect.WithSchema(v1.File_connect_collide_v1_collide_proto.Services().ByName("CollideService").Methods().ByName("Import")), + connect.WithHandlerOptions(opts...), ) return "/connect.collide.v1.CollideService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { diff --git a/internal/gen/connect/import/v1/importv1connect/import.connect.go b/internal/gen/connect/import/v1/importv1connect/import.connect.go index d56d3efa..cb3f2803 100644 --- a/internal/gen/connect/import/v1/importv1connect/import.connect.go +++ b/internal/gen/connect/import/v1/importv1connect/import.connect.go @@ -30,7 +30,7 @@ import ( // generated with a version of connect newer than the one compiled into your binary. You can fix the // problem by either regenerating this code with an older version of connect or updating the connect // version compiled into your binary. -const _ = connect.IsAtLeastVersion0_1_0 +const _ = connect.IsAtLeastVersion1_13_0 const ( // ImportServiceName is the fully-qualified name of the ImportService service. diff --git a/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go b/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go index 036539a8..84d01aa6 100644 --- a/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go +++ b/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go @@ -37,7 +37,7 @@ import ( // generated with a version of connect newer than the one compiled into your binary. You can fix the // problem by either regenerating this code with an older version of connect or updating the connect // version compiled into your binary. -const _ = connect.IsAtLeastVersion1_7_0 +const _ = connect.IsAtLeastVersion1_13_0 const ( // PingServiceName is the fully-qualified name of the PingService service. @@ -91,28 +91,33 @@ func NewPingServiceClient(httpClient connect.HTTPClient, baseURL string, opts .. ping: connect.NewClient[v1.PingRequest, v1.PingResponse]( httpClient, baseURL+PingServicePingProcedure, + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Ping")), connect.WithIdempotency(connect.IdempotencyNoSideEffects), connect.WithClientOptions(opts...), ), fail: connect.NewClient[v1.FailRequest, v1.FailResponse]( httpClient, baseURL+PingServiceFailProcedure, - opts..., + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Fail")), + connect.WithClientOptions(opts...), ), sum: connect.NewClient[v1.SumRequest, v1.SumResponse]( httpClient, baseURL+PingServiceSumProcedure, - opts..., + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Sum")), + connect.WithClientOptions(opts...), ), countUp: connect.NewClient[v1.CountUpRequest, v1.CountUpResponse]( httpClient, baseURL+PingServiceCountUpProcedure, - opts..., + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("CountUp")), + connect.WithClientOptions(opts...), ), cumSum: connect.NewClient[v1.CumSumRequest, v1.CumSumResponse]( httpClient, baseURL+PingServiceCumSumProcedure, - opts..., + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("CumSum")), + connect.WithClientOptions(opts...), ), } } @@ -174,28 +179,33 @@ func NewPingServiceHandler(svc PingServiceHandler, opts ...connect.HandlerOption pingServicePingHandler := connect.NewUnaryHandler( PingServicePingProcedure, svc.Ping, + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Ping")), connect.WithIdempotency(connect.IdempotencyNoSideEffects), connect.WithHandlerOptions(opts...), ) pingServiceFailHandler := connect.NewUnaryHandler( PingServiceFailProcedure, svc.Fail, - opts..., + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Fail")), + connect.WithHandlerOptions(opts...), ) pingServiceSumHandler := connect.NewClientStreamHandler( PingServiceSumProcedure, svc.Sum, - opts..., + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Sum")), + connect.WithHandlerOptions(opts...), ) pingServiceCountUpHandler := connect.NewServerStreamHandler( PingServiceCountUpProcedure, svc.CountUp, - opts..., + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("CountUp")), + connect.WithHandlerOptions(opts...), ) pingServiceCumSumHandler := connect.NewBidiStreamHandler( PingServiceCumSumProcedure, svc.CumSum, - opts..., + connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("CumSum")), + connect.WithHandlerOptions(opts...), ) return "/connect.ping.v1.PingService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { diff --git a/option.go b/option.go index 1eb8d7c3..c5d15f05 100644 --- a/option.go +++ b/option.go @@ -184,6 +184,16 @@ type Option interface { HandlerOption } +// WithSchema provides a parsed representation of the schema for an RPC to a +// client or handler. The supplied schema is exposed as [Spec.Schema]. This +// option is typically added by generated code. +// +// For services using protobuf schemas, the supplied schema should be a +// [protoreflect.MethodDescriptor]. +func WithSchema(schema any) Option { + return &schemaOption{Schema: schema} +} + // WithCodec registers a serialization method with a client or handler. // Handlers may have multiple codecs registered, and use whichever the client // chooses. Clients may only have a single codec. @@ -328,6 +338,18 @@ func WithOptions(options ...Option) Option { return &optionsOption{options} } +type schemaOption struct { + Schema any +} + +func (o *schemaOption) applyToClient(config *clientConfig) { + config.Schema = o.Schema +} + +func (o *schemaOption) applyToHandler(config *handlerConfig) { + config.Schema = o.Schema +} + type clientOptionsOption struct { options []ClientOption } From 817b5142fe78adf1fab0b5d9dca80e78e5cff930 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 6 Nov 2023 17:46:14 -0500 Subject: [PATCH 2/6] Use serviceDescriptor var --- cmd/protoc-gen-connect-go/main.go | 18 +++++++++------ .../v1/collidev1connect/collide.connect.go | 6 +++-- .../v1/importv1connect/import.connect.go | 2 -- .../ping/v1/pingv1connect/ping.connect.go | 22 ++++++++++--------- 4 files changed, 27 insertions(+), 21 deletions(-) diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index 4fda5b2a..fe4d0f6c 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -258,7 +258,11 @@ func generateClientImplementation(g *protogen.GeneratedFile, file *protogen.File } g.P("func ", names.ClientConstructor, " (httpClient ", connectPackage.Ident("HTTPClient"), ", baseURL string, opts ...", clientOption, ") ", names.Client, " {") - g.P("baseURL = ", stringsPackage.Ident("TrimRight"), `(baseURL, "/")`) + if len(service.Methods) > 0 { + g.P("baseURL = ", stringsPackage.Ident("TrimRight"), `(baseURL, "/")`) + g.P("serviceDescriptor := ", g.QualifiedGoIdent(file.GoDescriptorIdent), + `.Services().ByName("`, service.Desc.Name(), `")`) + } g.P("return &", names.ClientImpl, "{") for _, method := range service.Methods { g.P(unexport(method.GoName), ": ", @@ -269,9 +273,7 @@ func generateClientImplementation(g *protogen.GeneratedFile, file *protogen.File g.P("httpClient,") g.P(`baseURL + `, procedureConstName(method), `,`) g.P(connectPackage.Ident("WithSchema"), "(", - g.QualifiedGoIdent(file.GoDescriptorIdent), - `.Services().ByName("`, service.Desc.Name(), `")`, - `.Methods().ByName("`, method.Desc.Name(), `")),`) + `serviceDescriptor.Methods().ByName("`, method.Desc.Name(), `")),`) idempotency := methodIdempotency(method) switch idempotency { case connect.IdempotencyNoSideEffects: @@ -390,6 +392,10 @@ func generateServerConstructor(g *protogen.GeneratedFile, file *protogen.File, s handlerOption := connectPackage.Ident("HandlerOption") g.P("func ", names.ServerConstructor, "(svc ", names.Server, ", opts ...", handlerOption, ") (string, ", httpPackage.Ident("Handler"), ") {") + if len(service.Methods) > 0 { + g.P("serviceDescriptor := ", g.QualifiedGoIdent(file.GoDescriptorIdent), + `.Services().ByName("`, service.Desc.Name(), `")`) + } for _, method := range service.Methods { isStreamingServer := method.Desc.IsStreamingServer() isStreamingClient := method.Desc.IsStreamingClient() @@ -407,9 +413,7 @@ func generateServerConstructor(g *protogen.GeneratedFile, file *protogen.File, s g.P(procedureConstName(method), `,`) g.P("svc.", method.GoName, ",") g.P(connectPackage.Ident("WithSchema"), "(", - g.QualifiedGoIdent(file.GoDescriptorIdent), - `.Services().ByName("`, service.Desc.Name(), `")`, - `.Methods().ByName("`, method.Desc.Name(), `")),`) + `serviceDescriptor.Methods().ByName("`, method.Desc.Name(), `")),`) switch idempotency { case connect.IdempotencyNoSideEffects: g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyNoSideEffects"), "),") diff --git a/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go b/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go index 4b21ff23..9bee99e4 100644 --- a/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go +++ b/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go @@ -65,11 +65,12 @@ type CollideServiceClient interface { // http://api.acme.com or https://acme.com/grpc). func NewCollideServiceClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) CollideServiceClient { baseURL = strings.TrimRight(baseURL, "/") + serviceDescriptor := v1.File_connect_collide_v1_collide_proto.Services().ByName("CollideService") return &collideServiceClient{ _import: connect.NewClient[v1.ImportRequest, v1.ImportResponse]( httpClient, baseURL+CollideServiceImportProcedure, - connect.WithSchema(v1.File_connect_collide_v1_collide_proto.Services().ByName("CollideService").Methods().ByName("Import")), + connect.WithSchema(serviceDescriptor.Methods().ByName("Import")), connect.WithClientOptions(opts...), ), } @@ -96,10 +97,11 @@ type CollideServiceHandler interface { // By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf // and JSON codecs. They also support gzip compression. func NewCollideServiceHandler(svc CollideServiceHandler, opts ...connect.HandlerOption) (string, http.Handler) { + serviceDescriptor := v1.File_connect_collide_v1_collide_proto.Services().ByName("CollideService") collideServiceImportHandler := connect.NewUnaryHandler( CollideServiceImportProcedure, svc.Import, - connect.WithSchema(v1.File_connect_collide_v1_collide_proto.Services().ByName("CollideService").Methods().ByName("Import")), + connect.WithSchema(serviceDescriptor.Methods().ByName("Import")), connect.WithHandlerOptions(opts...), ) return "/connect.collide.v1.CollideService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/gen/connect/import/v1/importv1connect/import.connect.go b/internal/gen/connect/import/v1/importv1connect/import.connect.go index cb3f2803..043ff35f 100644 --- a/internal/gen/connect/import/v1/importv1connect/import.connect.go +++ b/internal/gen/connect/import/v1/importv1connect/import.connect.go @@ -22,7 +22,6 @@ import ( connect "connectrpc.com/connect" _ "connectrpc.com/connect/internal/gen/connect/import/v1" http "net/http" - strings "strings" ) // This is a compile-time assertion to ensure that this generated file and the connect package are @@ -49,7 +48,6 @@ type ImportServiceClient interface { // The URL supplied here should be the base URL for the Connect or gRPC server (for example, // http://api.acme.com or https://acme.com/grpc). func NewImportServiceClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) ImportServiceClient { - baseURL = strings.TrimRight(baseURL, "/") return &importServiceClient{} } diff --git a/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go b/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go index 84d01aa6..e60a15d1 100644 --- a/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go +++ b/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go @@ -87,36 +87,37 @@ type PingServiceClient interface { // http://api.acme.com or https://acme.com/grpc). func NewPingServiceClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) PingServiceClient { baseURL = strings.TrimRight(baseURL, "/") + serviceDescriptor := v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService") return &pingServiceClient{ ping: connect.NewClient[v1.PingRequest, v1.PingResponse]( httpClient, baseURL+PingServicePingProcedure, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Ping")), + connect.WithSchema(serviceDescriptor.Methods().ByName("Ping")), connect.WithIdempotency(connect.IdempotencyNoSideEffects), connect.WithClientOptions(opts...), ), fail: connect.NewClient[v1.FailRequest, v1.FailResponse]( httpClient, baseURL+PingServiceFailProcedure, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Fail")), + connect.WithSchema(serviceDescriptor.Methods().ByName("Fail")), connect.WithClientOptions(opts...), ), sum: connect.NewClient[v1.SumRequest, v1.SumResponse]( httpClient, baseURL+PingServiceSumProcedure, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Sum")), + connect.WithSchema(serviceDescriptor.Methods().ByName("Sum")), connect.WithClientOptions(opts...), ), countUp: connect.NewClient[v1.CountUpRequest, v1.CountUpResponse]( httpClient, baseURL+PingServiceCountUpProcedure, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("CountUp")), + connect.WithSchema(serviceDescriptor.Methods().ByName("CountUp")), connect.WithClientOptions(opts...), ), cumSum: connect.NewClient[v1.CumSumRequest, v1.CumSumResponse]( httpClient, baseURL+PingServiceCumSumProcedure, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("CumSum")), + connect.WithSchema(serviceDescriptor.Methods().ByName("CumSum")), connect.WithClientOptions(opts...), ), } @@ -176,35 +177,36 @@ type PingServiceHandler interface { // By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf // and JSON codecs. They also support gzip compression. func NewPingServiceHandler(svc PingServiceHandler, opts ...connect.HandlerOption) (string, http.Handler) { + serviceDescriptor := v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService") pingServicePingHandler := connect.NewUnaryHandler( PingServicePingProcedure, svc.Ping, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Ping")), + connect.WithSchema(serviceDescriptor.Methods().ByName("Ping")), connect.WithIdempotency(connect.IdempotencyNoSideEffects), connect.WithHandlerOptions(opts...), ) pingServiceFailHandler := connect.NewUnaryHandler( PingServiceFailProcedure, svc.Fail, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Fail")), + connect.WithSchema(serviceDescriptor.Methods().ByName("Fail")), connect.WithHandlerOptions(opts...), ) pingServiceSumHandler := connect.NewClientStreamHandler( PingServiceSumProcedure, svc.Sum, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("Sum")), + connect.WithSchema(serviceDescriptor.Methods().ByName("Sum")), connect.WithHandlerOptions(opts...), ) pingServiceCountUpHandler := connect.NewServerStreamHandler( PingServiceCountUpProcedure, svc.CountUp, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("CountUp")), + connect.WithSchema(serviceDescriptor.Methods().ByName("CountUp")), connect.WithHandlerOptions(opts...), ) pingServiceCumSumHandler := connect.NewBidiStreamHandler( PingServiceCumSumProcedure, svc.CumSum, - connect.WithSchema(v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService").Methods().ByName("CumSum")), + connect.WithSchema(serviceDescriptor.Methods().ByName("CumSum")), connect.WithHandlerOptions(opts...), ) return "/connect.ping.v1.PingService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { From 741312f944d76f960ab301a1cfd6375af00e3ff7 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 6 Nov 2023 17:50:58 -0500 Subject: [PATCH 3/6] Remove test invariants --- client_ext_test.go | 91 ++++++++++++++-------------------------------- 1 file changed, 27 insertions(+), 64 deletions(-) diff --git a/client_ext_test.go b/client_ext_test.go index e5f6995a..d04f20b3 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -196,71 +196,34 @@ func TestSpecSchema(t *testing.T) { connect.WithInterceptors(&assertSchemaInterceptor{t}), )) server := memhttptest.NewServer(t, mux) - testcases := []struct { - name string - opts []connect.ClientOption - }{{ - name: connect.ProtocolConnect, - }, { - name: connect.ProtocolGRPC, - opts: []connect.ClientOption{ - connect.WithGRPC(), - }, - }, { - name: connect.ProtocolGRPCWeb, - opts: []connect.ClientOption{ - connect.WithGRPC(), - }, - }} - for _, testcase := range testcases { - testcase := testcase - t.Run(testcase.name, func(t *testing.T) { - ctx := context.Background() - client := pingv1connect.NewPingServiceClient( - server.Client(), - server.URL(), - connect.WithClientOptions(testcase.opts...), - connect.WithInterceptors(&assertSchemaInterceptor{t}), - ) - t.Parallel() - t.Run("unary", func(t *testing.T) { - unaryReq := connect.NewRequest[pingv1.PingRequest](nil) - _, err := client.Ping(ctx, unaryReq) - assert.Nil(t, err) - text := strings.Repeat(".", 256) - r, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{Text: text})) - assert.Nil(t, err) - assert.Equal(t, r.Msg.Text, text) - }) - t.Run("client_stream", func(t *testing.T) { - clientStream := client.Sum(ctx) - t.Cleanup(func() { - _, closeErr := clientStream.CloseAndReceive() - assert.Nil(t, closeErr) - }) - assert.NotZero(t, clientStream.Spec().Schema) - err := clientStream.Send(&pingv1.SumRequest{}) - assert.Nil(t, err) - }) - t.Run("server_stream", func(t *testing.T) { - serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{})) - t.Cleanup(func() { - assert.Nil(t, serverStream.Close()) - }) - assert.Nil(t, err) - }) - t.Run("bidi_stream", func(t *testing.T) { - bidiStream := client.CumSum(ctx) - t.Cleanup(func() { - assert.Nil(t, bidiStream.CloseRequest()) - assert.Nil(t, bidiStream.CloseResponse()) - }) - assert.NotZero(t, bidiStream.Spec().Schema) - err := bidiStream.Send(&pingv1.CumSumRequest{}) - assert.Nil(t, err) - }) + ctx := context.Background() + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + connect.WithInterceptors(&assertSchemaInterceptor{t}), + ) + t.Run("unary", func(t *testing.T) { + t.Parallel() + unaryReq := connect.NewRequest[pingv1.PingRequest](nil) + _, err := client.Ping(ctx, unaryReq) + assert.NotNil(t, unaryReq.Spec().Schema) + assert.Nil(t, err) + text := strings.Repeat(".", 256) + r, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{Text: text})) + assert.Nil(t, err) + assert.Equal(t, r.Msg.Text, text) + }) + t.Run("bidi_stream", func(t *testing.T) { + t.Parallel() + bidiStream := client.CumSum(ctx) + t.Cleanup(func() { + assert.Nil(t, bidiStream.CloseRequest()) + assert.Nil(t, bidiStream.CloseResponse()) }) - } + assert.NotZero(t, bidiStream.Spec().Schema) + err := bidiStream.Send(&pingv1.CumSumRequest{}) + assert.Nil(t, err) + }) } type notModifiedPingServer struct { From 9f34b83380b22a0a9308c449db7161534c100838 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 6 Nov 2023 17:53:26 -0500 Subject: [PATCH 4/6] Cleanup schema assert test --- client_ext_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client_ext_test.go b/client_ext_test.go index d04f20b3..16225491 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -281,7 +281,7 @@ type assertSchemaInterceptor struct { func (a *assertSchemaInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { if !assert.NotNil(a.tb, req.Spec().Schema) { - return nil, fmt.Errorf("nil spec") + return next(ctx, req) } methodDesc, ok := req.Spec().Schema.(protoreflect.MethodDescriptor) assert.True(a.tb, ok) @@ -308,7 +308,7 @@ func (a *assertSchemaInterceptor) WrapStreamingClient(next connect.StreamingClie func (a *assertSchemaInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { if !assert.NotNil(a.tb, conn.Spec().Schema) { - return fmt.Errorf("nil spec") + return next(ctx, conn) } methodDesc, ok := conn.Spec().Schema.(protoreflect.MethodDescriptor) assert.True(a.tb, ok) From 4db7b14ab99c3efd5a28c29c9de4b3ee088071dd Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Wed, 8 Nov 2023 12:30:36 -0500 Subject: [PATCH 5/6] Switch to global vars for descriptors --- cmd/protoc-gen-connect-go/main.go | 47 ++++++++++++------- .../v1/collidev1connect/collide.connect.go | 12 +++-- .../v1/importv1connect/import.connect.go | 7 ++- .../ping/v1/pingv1connect/ping.connect.go | 32 ++++++++----- 4 files changed, 65 insertions(+), 33 deletions(-) diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index fe4d0f6c..b035dee8 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -124,8 +124,9 @@ func generate(plugin *protogen.Plugin, file *protogen.File) { generatedFile.Import(file.GoImportPath) generatePreamble(generatedFile, file) generateServiceNameConstants(generatedFile, file.Services) + generateServiceNameVariables(generatedFile, file) for _, service := range file.Services { - generateService(generatedFile, file, service) + generateService(generatedFile, service) } } @@ -210,12 +211,30 @@ func generateServiceNameConstants(g *protogen.GeneratedFile, services []*protoge g.P() } -func generateService(g *protogen.GeneratedFile, file *protogen.File, service *protogen.Service) { +func generateServiceNameVariables(g *protogen.GeneratedFile, file *protogen.File) { + wrapComments(g, "These variables are the protoreflect.Descriptor objects for the ", file.Desc.Name(), + " service's methods.") + g.P("var (") + for _, service := range file.Services { + serviceDescName := unexport(fmt.Sprintf("%sServiceDescriptor", service.Desc.Name())) + g.P(serviceDescName, ` = `, + g.QualifiedGoIdent(file.GoDescriptorIdent), + `.Services().ByName("`, service.Desc.Name(), `")`) + for _, method := range service.Methods { + g.P(procedureVarMethodDescriptor(method), ` = `, + serviceDescName, + `.Methods().ByName("`, method.Desc.Name(), `")`) + } + } + g.P(")") +} + +func generateService(g *protogen.GeneratedFile, service *protogen.Service) { names := newNames(service) generateClientInterface(g, service, names) - generateClientImplementation(g, file, service, names) + generateClientImplementation(g, service, names) generateServerInterface(g, service, names) - generateServerConstructor(g, file, service, names) + generateServerConstructor(g, service, names) generateUnimplementedServerImplementation(g, service, names) } @@ -240,7 +259,7 @@ func generateClientInterface(g *protogen.GeneratedFile, service *protogen.Servic g.P() } -func generateClientImplementation(g *protogen.GeneratedFile, file *protogen.File, service *protogen.Service, names names) { +func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.Service, names names) { clientOption := connectPackage.Ident("ClientOption") // Client constructor. @@ -260,8 +279,6 @@ func generateClientImplementation(g *protogen.GeneratedFile, file *protogen.File ", baseURL string, opts ...", clientOption, ") ", names.Client, " {") if len(service.Methods) > 0 { g.P("baseURL = ", stringsPackage.Ident("TrimRight"), `(baseURL, "/")`) - g.P("serviceDescriptor := ", g.QualifiedGoIdent(file.GoDescriptorIdent), - `.Services().ByName("`, service.Desc.Name(), `")`) } g.P("return &", names.ClientImpl, "{") for _, method := range service.Methods { @@ -272,8 +289,7 @@ func generateClientImplementation(g *protogen.GeneratedFile, file *protogen.File ) g.P("httpClient,") g.P(`baseURL + `, procedureConstName(method), `,`) - g.P(connectPackage.Ident("WithSchema"), "(", - `serviceDescriptor.Methods().ByName("`, method.Desc.Name(), `")),`) + g.P(connectPackage.Ident("WithSchema"), "(", procedureVarMethodDescriptor(method), "),") idempotency := methodIdempotency(method) switch idempotency { case connect.IdempotencyNoSideEffects: @@ -379,7 +395,7 @@ func generateServerInterface(g *protogen.GeneratedFile, service *protogen.Servic g.P() } -func generateServerConstructor(g *protogen.GeneratedFile, file *protogen.File, service *protogen.Service, names names) { +func generateServerConstructor(g *protogen.GeneratedFile, service *protogen.Service, names names) { wrapComments(g, names.ServerConstructor, " builds an HTTP handler from the service implementation.", " It returns the path on which to mount the handler and the handler itself.") g.P("//") @@ -392,10 +408,6 @@ func generateServerConstructor(g *protogen.GeneratedFile, file *protogen.File, s handlerOption := connectPackage.Ident("HandlerOption") g.P("func ", names.ServerConstructor, "(svc ", names.Server, ", opts ...", handlerOption, ") (string, ", httpPackage.Ident("Handler"), ") {") - if len(service.Methods) > 0 { - g.P("serviceDescriptor := ", g.QualifiedGoIdent(file.GoDescriptorIdent), - `.Services().ByName("`, service.Desc.Name(), `")`) - } for _, method := range service.Methods { isStreamingServer := method.Desc.IsStreamingServer() isStreamingClient := method.Desc.IsStreamingClient() @@ -412,8 +424,7 @@ func generateServerConstructor(g *protogen.GeneratedFile, file *protogen.File, s } g.P(procedureConstName(method), `,`) g.P("svc.", method.GoName, ",") - g.P(connectPackage.Ident("WithSchema"), "(", - `serviceDescriptor.Methods().ByName("`, method.Desc.Name(), `")),`) + g.P(connectPackage.Ident("WithSchema"), "(", procedureVarMethodDescriptor(method), "),") switch idempotency { case connect.IdempotencyNoSideEffects: g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyNoSideEffects"), "),") @@ -509,6 +520,10 @@ func procedureHandlerName(m *protogen.Method) string { return fmt.Sprintf("%s%sHandler", unexport(m.Parent.GoName), m.GoName) } +func procedureVarMethodDescriptor(m *protogen.Method) string { + return unexport(fmt.Sprintf("%s%sMethodDescriptor", m.Parent.GoName, m.GoName)) +} + func isDeprecatedService(service *protogen.Service) bool { serviceOptions, ok := service.Desc.Options().(*descriptorpb.ServiceOptions) return ok && serviceOptions.GetDeprecated() diff --git a/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go b/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go index 9bee99e4..4d5edb74 100644 --- a/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go +++ b/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go @@ -51,6 +51,12 @@ const ( CollideServiceImportProcedure = "/connect.collide.v1.CollideService/Import" ) +// These variables are the protoreflect.Descriptor objects for the v1 service's methods. +var ( + collideServiceServiceDescriptor = v1.File_connect_collide_v1_collide_proto.Services().ByName("CollideService") + collideServiceImportMethodDescriptor = collideServiceServiceDescriptor.Methods().ByName("Import") +) + // CollideServiceClient is a client for the connect.collide.v1.CollideService service. type CollideServiceClient interface { Import(context.Context, *connect.Request[v1.ImportRequest]) (*connect.Response[v1.ImportResponse], error) @@ -65,12 +71,11 @@ type CollideServiceClient interface { // http://api.acme.com or https://acme.com/grpc). func NewCollideServiceClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) CollideServiceClient { baseURL = strings.TrimRight(baseURL, "/") - serviceDescriptor := v1.File_connect_collide_v1_collide_proto.Services().ByName("CollideService") return &collideServiceClient{ _import: connect.NewClient[v1.ImportRequest, v1.ImportResponse]( httpClient, baseURL+CollideServiceImportProcedure, - connect.WithSchema(serviceDescriptor.Methods().ByName("Import")), + connect.WithSchema(collideServiceImportMethodDescriptor), connect.WithClientOptions(opts...), ), } @@ -97,11 +102,10 @@ type CollideServiceHandler interface { // By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf // and JSON codecs. They also support gzip compression. func NewCollideServiceHandler(svc CollideServiceHandler, opts ...connect.HandlerOption) (string, http.Handler) { - serviceDescriptor := v1.File_connect_collide_v1_collide_proto.Services().ByName("CollideService") collideServiceImportHandler := connect.NewUnaryHandler( CollideServiceImportProcedure, svc.Import, - connect.WithSchema(serviceDescriptor.Methods().ByName("Import")), + connect.WithSchema(collideServiceImportMethodDescriptor), connect.WithHandlerOptions(opts...), ) return "/connect.collide.v1.CollideService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/gen/connect/import/v1/importv1connect/import.connect.go b/internal/gen/connect/import/v1/importv1connect/import.connect.go index 043ff35f..47536c27 100644 --- a/internal/gen/connect/import/v1/importv1connect/import.connect.go +++ b/internal/gen/connect/import/v1/importv1connect/import.connect.go @@ -20,7 +20,7 @@ package importv1connect import ( connect "connectrpc.com/connect" - _ "connectrpc.com/connect/internal/gen/connect/import/v1" + v1 "connectrpc.com/connect/internal/gen/connect/import/v1" http "net/http" ) @@ -36,6 +36,11 @@ const ( ImportServiceName = "connect.import.v1.ImportService" ) +// These variables are the protoreflect.Descriptor objects for the v1 service's methods. +var ( + importServiceServiceDescriptor = v1.File_connect_import_v1_import_proto.Services().ByName("ImportService") +) + // ImportServiceClient is a client for the connect.import.v1.ImportService service. type ImportServiceClient interface { } diff --git a/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go b/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go index e60a15d1..954ff97d 100644 --- a/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go +++ b/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go @@ -64,6 +64,16 @@ const ( PingServiceCumSumProcedure = "/connect.ping.v1.PingService/CumSum" ) +// These variables are the protoreflect.Descriptor objects for the v1 service's methods. +var ( + pingServiceServiceDescriptor = v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService") + pingServicePingMethodDescriptor = pingServiceServiceDescriptor.Methods().ByName("Ping") + pingServiceFailMethodDescriptor = pingServiceServiceDescriptor.Methods().ByName("Fail") + pingServiceSumMethodDescriptor = pingServiceServiceDescriptor.Methods().ByName("Sum") + pingServiceCountUpMethodDescriptor = pingServiceServiceDescriptor.Methods().ByName("CountUp") + pingServiceCumSumMethodDescriptor = pingServiceServiceDescriptor.Methods().ByName("CumSum") +) + // PingServiceClient is a client for the connect.ping.v1.PingService service. type PingServiceClient interface { // Ping sends a ping to the server to determine if it's reachable. @@ -87,37 +97,36 @@ type PingServiceClient interface { // http://api.acme.com or https://acme.com/grpc). func NewPingServiceClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) PingServiceClient { baseURL = strings.TrimRight(baseURL, "/") - serviceDescriptor := v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService") return &pingServiceClient{ ping: connect.NewClient[v1.PingRequest, v1.PingResponse]( httpClient, baseURL+PingServicePingProcedure, - connect.WithSchema(serviceDescriptor.Methods().ByName("Ping")), + connect.WithSchema(pingServicePingMethodDescriptor), connect.WithIdempotency(connect.IdempotencyNoSideEffects), connect.WithClientOptions(opts...), ), fail: connect.NewClient[v1.FailRequest, v1.FailResponse]( httpClient, baseURL+PingServiceFailProcedure, - connect.WithSchema(serviceDescriptor.Methods().ByName("Fail")), + connect.WithSchema(pingServiceFailMethodDescriptor), connect.WithClientOptions(opts...), ), sum: connect.NewClient[v1.SumRequest, v1.SumResponse]( httpClient, baseURL+PingServiceSumProcedure, - connect.WithSchema(serviceDescriptor.Methods().ByName("Sum")), + connect.WithSchema(pingServiceSumMethodDescriptor), connect.WithClientOptions(opts...), ), countUp: connect.NewClient[v1.CountUpRequest, v1.CountUpResponse]( httpClient, baseURL+PingServiceCountUpProcedure, - connect.WithSchema(serviceDescriptor.Methods().ByName("CountUp")), + connect.WithSchema(pingServiceCountUpMethodDescriptor), connect.WithClientOptions(opts...), ), cumSum: connect.NewClient[v1.CumSumRequest, v1.CumSumResponse]( httpClient, baseURL+PingServiceCumSumProcedure, - connect.WithSchema(serviceDescriptor.Methods().ByName("CumSum")), + connect.WithSchema(pingServiceCumSumMethodDescriptor), connect.WithClientOptions(opts...), ), } @@ -177,36 +186,35 @@ type PingServiceHandler interface { // By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf // and JSON codecs. They also support gzip compression. func NewPingServiceHandler(svc PingServiceHandler, opts ...connect.HandlerOption) (string, http.Handler) { - serviceDescriptor := v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService") pingServicePingHandler := connect.NewUnaryHandler( PingServicePingProcedure, svc.Ping, - connect.WithSchema(serviceDescriptor.Methods().ByName("Ping")), + connect.WithSchema(pingServicePingMethodDescriptor), connect.WithIdempotency(connect.IdempotencyNoSideEffects), connect.WithHandlerOptions(opts...), ) pingServiceFailHandler := connect.NewUnaryHandler( PingServiceFailProcedure, svc.Fail, - connect.WithSchema(serviceDescriptor.Methods().ByName("Fail")), + connect.WithSchema(pingServiceFailMethodDescriptor), connect.WithHandlerOptions(opts...), ) pingServiceSumHandler := connect.NewClientStreamHandler( PingServiceSumProcedure, svc.Sum, - connect.WithSchema(serviceDescriptor.Methods().ByName("Sum")), + connect.WithSchema(pingServiceSumMethodDescriptor), connect.WithHandlerOptions(opts...), ) pingServiceCountUpHandler := connect.NewServerStreamHandler( PingServiceCountUpProcedure, svc.CountUp, - connect.WithSchema(serviceDescriptor.Methods().ByName("CountUp")), + connect.WithSchema(pingServiceCountUpMethodDescriptor), connect.WithHandlerOptions(opts...), ) pingServiceCumSumHandler := connect.NewBidiStreamHandler( PingServiceCumSumProcedure, svc.CumSum, - connect.WithSchema(serviceDescriptor.Methods().ByName("CumSum")), + connect.WithSchema(pingServiceCumSumMethodDescriptor), connect.WithHandlerOptions(opts...), ) return "/connect.ping.v1.PingService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { From 3e40b5a76734e17f4b7a3bd13e83282d5c008ceb Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Wed, 8 Nov 2023 17:47:46 -0500 Subject: [PATCH 6/6] Simplify var docs and fix test assert --- client_ext_test.go | 21 +++++++++++-------- cmd/protoc-gen-connect-go/main.go | 3 +-- .../v1/collidev1connect/collide.connect.go | 2 +- .../v1/importv1connect/import.connect.go | 2 +- .../ping/v1/pingv1connect/ping.connect.go | 2 +- 5 files changed, 16 insertions(+), 14 deletions(-) diff --git a/client_ext_test.go b/client_ext_test.go index 16225491..18565f3b 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -284,9 +284,10 @@ func (a *assertSchemaInterceptor) WrapUnary(next connect.UnaryFunc) connect.Unar return next(ctx, req) } methodDesc, ok := req.Spec().Schema.(protoreflect.MethodDescriptor) - assert.True(a.tb, ok) - procedure := fmt.Sprintf("/%s/%s", methodDesc.Parent().FullName(), methodDesc.Name()) - assert.Equal(a.tb, procedure, req.Spec().Procedure) + if assert.True(a.tb, ok) { + procedure := fmt.Sprintf("/%s/%s", methodDesc.Parent().FullName(), methodDesc.Name()) + assert.Equal(a.tb, procedure, req.Spec().Procedure) + } return next(ctx, req) } } @@ -298,9 +299,10 @@ func (a *assertSchemaInterceptor) WrapStreamingClient(next connect.StreamingClie return conn } methodDescriptor, ok := spec.Schema.(protoreflect.MethodDescriptor) - assert.True(a.tb, ok) - procedure := fmt.Sprintf("/%s/%s", methodDescriptor.Parent().FullName(), methodDescriptor.Name()) - assert.Equal(a.tb, procedure, spec.Procedure) + if assert.True(a.tb, ok) { + procedure := fmt.Sprintf("/%s/%s", methodDescriptor.Parent().FullName(), methodDescriptor.Name()) + assert.Equal(a.tb, procedure, spec.Procedure) + } return conn } } @@ -311,9 +313,10 @@ func (a *assertSchemaInterceptor) WrapStreamingHandler(next connect.StreamingHan return next(ctx, conn) } methodDesc, ok := conn.Spec().Schema.(protoreflect.MethodDescriptor) - assert.True(a.tb, ok) - procedure := fmt.Sprintf("/%s/%s", methodDesc.Parent().FullName(), methodDesc.Name()) - assert.Equal(a.tb, procedure, conn.Spec().Procedure) + if assert.True(a.tb, ok) { + procedure := fmt.Sprintf("/%s/%s", methodDesc.Parent().FullName(), methodDesc.Name()) + assert.Equal(a.tb, procedure, conn.Spec().Procedure) + } return next(ctx, conn) } } diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index b035dee8..6f626f35 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -212,8 +212,7 @@ func generateServiceNameConstants(g *protogen.GeneratedFile, services []*protoge } func generateServiceNameVariables(g *protogen.GeneratedFile, file *protogen.File) { - wrapComments(g, "These variables are the protoreflect.Descriptor objects for the ", file.Desc.Name(), - " service's methods.") + wrapComments(g, "These variables are the protoreflect.Descriptor objects for the RPCs defined in this package.") g.P("var (") for _, service := range file.Services { serviceDescName := unexport(fmt.Sprintf("%sServiceDescriptor", service.Desc.Name())) diff --git a/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go b/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go index 4d5edb74..c686fcfc 100644 --- a/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go +++ b/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go @@ -51,7 +51,7 @@ const ( CollideServiceImportProcedure = "/connect.collide.v1.CollideService/Import" ) -// These variables are the protoreflect.Descriptor objects for the v1 service's methods. +// These variables are the protoreflect.Descriptor objects for the RPCs defined in this package. var ( collideServiceServiceDescriptor = v1.File_connect_collide_v1_collide_proto.Services().ByName("CollideService") collideServiceImportMethodDescriptor = collideServiceServiceDescriptor.Methods().ByName("Import") diff --git a/internal/gen/connect/import/v1/importv1connect/import.connect.go b/internal/gen/connect/import/v1/importv1connect/import.connect.go index 47536c27..1f6ba74d 100644 --- a/internal/gen/connect/import/v1/importv1connect/import.connect.go +++ b/internal/gen/connect/import/v1/importv1connect/import.connect.go @@ -36,7 +36,7 @@ const ( ImportServiceName = "connect.import.v1.ImportService" ) -// These variables are the protoreflect.Descriptor objects for the v1 service's methods. +// These variables are the protoreflect.Descriptor objects for the RPCs defined in this package. var ( importServiceServiceDescriptor = v1.File_connect_import_v1_import_proto.Services().ByName("ImportService") ) diff --git a/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go b/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go index 954ff97d..56bf4b2c 100644 --- a/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go +++ b/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go @@ -64,7 +64,7 @@ const ( PingServiceCumSumProcedure = "/connect.ping.v1.PingService/CumSum" ) -// These variables are the protoreflect.Descriptor objects for the v1 service's methods. +// These variables are the protoreflect.Descriptor objects for the RPCs defined in this package. var ( pingServiceServiceDescriptor = v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService") pingServicePingMethodDescriptor = pingServiceServiceDescriptor.Methods().ByName("Ping")