diff --git a/client.go b/client.go index 38bb541b..2abcaa5e 100644 --- a/client.go +++ b/client.go @@ -189,6 +189,7 @@ type clientConfig struct { URL *url.URL Protocol protocol Procedure string + Schema any CompressMinBytes int Interceptor Interceptor CompressionPools map[string]*compressionPool @@ -251,6 +252,7 @@ func (c *clientConfig) newSpec(t StreamType) Spec { return Spec{ StreamType: t, Procedure: c.Procedure, + Schema: c.Schema, IsClient: true, IdempotencyLevel: c.IdempotencyLevel, } diff --git a/client_ext_test.go b/client_ext_test.go index ce799958..18565f3b 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -17,6 +17,7 @@ package connect_test import ( "context" "errors" + "fmt" "net/http" "strings" "testing" @@ -26,6 +27,7 @@ import ( pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" "connectrpc.com/connect/internal/memhttp/memhttptest" + "google.golang.org/protobuf/reflect/protoreflect" ) func TestNewClient_InitFailure(t *testing.T) { @@ -186,6 +188,44 @@ func TestGetNotModified(t *testing.T) { assert.Equal(t, http.MethodGet, unaryReq.HTTPMethod()) } +func TestSpecSchema(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler( + pingServer{}, + connect.WithInterceptors(&assertSchemaInterceptor{t}), + )) + server := memhttptest.NewServer(t, mux) + ctx := context.Background() + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL(), + connect.WithInterceptors(&assertSchemaInterceptor{t}), + ) + t.Run("unary", func(t *testing.T) { + t.Parallel() + unaryReq := connect.NewRequest[pingv1.PingRequest](nil) + _, err := client.Ping(ctx, unaryReq) + assert.NotNil(t, unaryReq.Spec().Schema) + assert.Nil(t, err) + text := strings.Repeat(".", 256) + r, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{Text: text})) + assert.Nil(t, err) + assert.Equal(t, r.Msg.Text, text) + }) + t.Run("bidi_stream", func(t *testing.T) { + t.Parallel() + bidiStream := client.CumSum(ctx) + t.Cleanup(func() { + assert.Nil(t, bidiStream.CloseRequest()) + assert.Nil(t, bidiStream.CloseResponse()) + }) + assert.NotZero(t, bidiStream.Spec().Schema) + err := bidiStream.Send(&pingv1.CumSumRequest{}) + assert.Nil(t, err) + }) +} + type notModifiedPingServer struct { pingv1connect.UnimplementedPingServiceHandler @@ -233,3 +273,50 @@ func (a *assertPeerInterceptor) WrapStreamingHandler(next connect.StreamingHandl return next(ctx, conn) } } + +type assertSchemaInterceptor struct { + tb testing.TB +} + +func (a *assertSchemaInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + if !assert.NotNil(a.tb, req.Spec().Schema) { + return next(ctx, req) + } + methodDesc, ok := req.Spec().Schema.(protoreflect.MethodDescriptor) + if assert.True(a.tb, ok) { + procedure := fmt.Sprintf("/%s/%s", methodDesc.Parent().FullName(), methodDesc.Name()) + assert.Equal(a.tb, procedure, req.Spec().Procedure) + } + return next(ctx, req) + } +} + +func (a *assertSchemaInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { + return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { + conn := next(ctx, spec) + if !assert.NotNil(a.tb, spec.Schema) { + return conn + } + methodDescriptor, ok := spec.Schema.(protoreflect.MethodDescriptor) + if assert.True(a.tb, ok) { + procedure := fmt.Sprintf("/%s/%s", methodDescriptor.Parent().FullName(), methodDescriptor.Name()) + assert.Equal(a.tb, procedure, spec.Procedure) + } + return conn + } +} + +func (a *assertSchemaInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return func(ctx context.Context, conn connect.StreamingHandlerConn) error { + if !assert.NotNil(a.tb, conn.Spec().Schema) { + return next(ctx, conn) + } + methodDesc, ok := conn.Spec().Schema.(protoreflect.MethodDescriptor) + if assert.True(a.tb, ok) { + procedure := fmt.Sprintf("/%s/%s", methodDesc.Parent().FullName(), methodDesc.Name()) + assert.Equal(a.tb, procedure, conn.Spec().Procedure) + } + return next(ctx, conn) + } +} diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index bde19f51..6f626f35 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -102,17 +102,6 @@ func main() { ) } -func needsWithIdempotency(file *protogen.File) bool { - for _, service := range file.Services { - for _, method := range service.Methods { - if methodIdempotency(method) != connect.IdempotencyUnknown { - return true - } - } - } - return false -} - func generate(plugin *protogen.Plugin, file *protogen.File) { if len(file.Services) == 0 { return @@ -135,6 +124,7 @@ func generate(plugin *protogen.Plugin, file *protogen.File) { generatedFile.Import(file.GoImportPath) generatePreamble(generatedFile, file) generateServiceNameConstants(generatedFile, file.Services) + generateServiceNameVariables(generatedFile, file) for _, service := range file.Services { generateService(generatedFile, service) } @@ -180,11 +170,7 @@ func generatePreamble(g *protogen.GeneratedFile, file *protogen.File) { "is not defined, this code was generated with a version of connect newer than the one ", "compiled into your binary. You can fix the problem by either regenerating this code ", "with an older version of connect or updating the connect version compiled into your binary.") - if needsWithIdempotency(file) { - g.P("const _ = ", connectPackage.Ident("IsAtLeastVersion1_7_0")) - } else { - g.P("const _ = ", connectPackage.Ident("IsAtLeastVersion0_1_0")) - } + g.P("const _ = ", connectPackage.Ident("IsAtLeastVersion1_13_0")) g.P() } @@ -225,6 +211,23 @@ func generateServiceNameConstants(g *protogen.GeneratedFile, services []*protoge g.P() } +func generateServiceNameVariables(g *protogen.GeneratedFile, file *protogen.File) { + wrapComments(g, "These variables are the protoreflect.Descriptor objects for the RPCs defined in this package.") + g.P("var (") + for _, service := range file.Services { + serviceDescName := unexport(fmt.Sprintf("%sServiceDescriptor", service.Desc.Name())) + g.P(serviceDescName, ` = `, + g.QualifiedGoIdent(file.GoDescriptorIdent), + `.Services().ByName("`, service.Desc.Name(), `")`) + for _, method := range service.Methods { + g.P(procedureVarMethodDescriptor(method), ` = `, + serviceDescName, + `.Methods().ByName("`, method.Desc.Name(), `")`) + } + } + g.P(")") +} + func generateService(g *protogen.GeneratedFile, service *protogen.Service) { names := newNames(service) generateClientInterface(g, service, names) @@ -273,7 +276,9 @@ func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.S } g.P("func ", names.ClientConstructor, " (httpClient ", connectPackage.Ident("HTTPClient"), ", baseURL string, opts ...", clientOption, ") ", names.Client, " {") - g.P("baseURL = ", stringsPackage.Ident("TrimRight"), `(baseURL, "/")`) + if len(service.Methods) > 0 { + g.P("baseURL = ", stringsPackage.Ident("TrimRight"), `(baseURL, "/")`) + } g.P("return &", names.ClientImpl, "{") for _, method := range service.Methods { g.P(unexport(method.GoName), ": ", @@ -283,17 +288,16 @@ func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.S ) g.P("httpClient,") g.P(`baseURL + `, procedureConstName(method), `,`) + g.P(connectPackage.Ident("WithSchema"), "(", procedureVarMethodDescriptor(method), "),") idempotency := methodIdempotency(method) switch idempotency { case connect.IdempotencyNoSideEffects: g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyNoSideEffects"), "),") - g.P(connectPackage.Ident("WithClientOptions"), "(opts...),") case connect.IdempotencyIdempotent: g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyIdempotent"), "),") - g.P(connectPackage.Ident("WithClientOptions"), "(opts...),") case connect.IdempotencyUnknown: - g.P("opts...,") } + g.P(connectPackage.Ident("WithClientOptions"), "(opts...),") g.P("),") } g.P("}") @@ -419,16 +423,15 @@ func generateServerConstructor(g *protogen.GeneratedFile, service *protogen.Serv } g.P(procedureConstName(method), `,`) g.P("svc.", method.GoName, ",") + g.P(connectPackage.Ident("WithSchema"), "(", procedureVarMethodDescriptor(method), "),") switch idempotency { case connect.IdempotencyNoSideEffects: g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyNoSideEffects"), "),") - g.P(connectPackage.Ident("WithHandlerOptions"), "(opts...),") case connect.IdempotencyIdempotent: g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyIdempotent"), "),") - g.P(connectPackage.Ident("WithHandlerOptions"), "(opts...),") case connect.IdempotencyUnknown: - g.P("opts...,") } + g.P(connectPackage.Ident("WithHandlerOptions"), "(opts...),") g.P(")") } g.P(`return "/`, service.Desc.FullName(), `/", `, httpPackage.Ident("HandlerFunc"), `(func(w `, httpPackage.Ident("ResponseWriter"), `, r *`, httpPackage.Ident("Request"), `){`) @@ -516,6 +519,10 @@ func procedureHandlerName(m *protogen.Method) string { return fmt.Sprintf("%s%sHandler", unexport(m.Parent.GoName), m.GoName) } +func procedureVarMethodDescriptor(m *protogen.Method) string { + return unexport(fmt.Sprintf("%s%sMethodDescriptor", m.Parent.GoName, m.GoName)) +} + func isDeprecatedService(service *protogen.Service) bool { serviceOptions, ok := service.Desc.Options().(*descriptorpb.ServiceOptions) return ok && serviceOptions.GetDeprecated() diff --git a/connect.go b/connect.go index c7c41d38..622852fb 100644 --- a/connect.go +++ b/connect.go @@ -38,9 +38,10 @@ const Version = "1.13.0-dev" // These constants are used in compile-time handshakes with connect's generated // code. const ( - IsAtLeastVersion0_0_1 = true - IsAtLeastVersion0_1_0 = true - IsAtLeastVersion1_7_0 = true + IsAtLeastVersion0_0_1 = true + IsAtLeastVersion0_1_0 = true + IsAtLeastVersion1_7_0 = true + IsAtLeastVersion1_13_0 = true ) // StreamType describes whether the client, server, neither, or both is @@ -314,6 +315,7 @@ type HTTPClient interface { // fully-qualified Procedure corresponding to each RPC in your schema. type Spec struct { StreamType StreamType + Schema any // for protobuf RPCs, a protoreflect.MethodDescriptor Procedure string // for example, "/acme.foo.v1.FooService/Bar" IsClient bool // otherwise we're in a handler IdempotencyLevel IdempotencyLevel diff --git a/handler.go b/handler.go index 43bfe973..b207ac66 100644 --- a/handler.go +++ b/handler.go @@ -246,6 +246,7 @@ type handlerConfig struct { CompressMinBytes int Interceptor Interceptor Procedure string + Schema any HandleGRPC bool HandleGRPCWeb bool RequireConnectProtocolHeader bool @@ -279,6 +280,7 @@ func newHandlerConfig(procedure string, streamType StreamType, options []Handler func (c *handlerConfig) newSpec() Spec { return Spec{ Procedure: c.Procedure, + Schema: c.Schema, StreamType: c.StreamType, IdempotencyLevel: c.IdempotencyLevel, } diff --git a/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go b/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go index 20223410..c686fcfc 100644 --- a/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go +++ b/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go @@ -32,7 +32,7 @@ import ( // generated with a version of connect newer than the one compiled into your binary. You can fix the // problem by either regenerating this code with an older version of connect or updating the connect // version compiled into your binary. -const _ = connect.IsAtLeastVersion0_1_0 +const _ = connect.IsAtLeastVersion1_13_0 const ( // CollideServiceName is the fully-qualified name of the CollideService service. @@ -51,6 +51,12 @@ const ( CollideServiceImportProcedure = "/connect.collide.v1.CollideService/Import" ) +// These variables are the protoreflect.Descriptor objects for the RPCs defined in this package. +var ( + collideServiceServiceDescriptor = v1.File_connect_collide_v1_collide_proto.Services().ByName("CollideService") + collideServiceImportMethodDescriptor = collideServiceServiceDescriptor.Methods().ByName("Import") +) + // CollideServiceClient is a client for the connect.collide.v1.CollideService service. type CollideServiceClient interface { Import(context.Context, *connect.Request[v1.ImportRequest]) (*connect.Response[v1.ImportResponse], error) @@ -69,7 +75,8 @@ func NewCollideServiceClient(httpClient connect.HTTPClient, baseURL string, opts _import: connect.NewClient[v1.ImportRequest, v1.ImportResponse]( httpClient, baseURL+CollideServiceImportProcedure, - opts..., + connect.WithSchema(collideServiceImportMethodDescriptor), + connect.WithClientOptions(opts...), ), } } @@ -98,7 +105,8 @@ func NewCollideServiceHandler(svc CollideServiceHandler, opts ...connect.Handler collideServiceImportHandler := connect.NewUnaryHandler( CollideServiceImportProcedure, svc.Import, - opts..., + connect.WithSchema(collideServiceImportMethodDescriptor), + connect.WithHandlerOptions(opts...), ) return "/connect.collide.v1.CollideService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { diff --git a/internal/gen/connect/import/v1/importv1connect/import.connect.go b/internal/gen/connect/import/v1/importv1connect/import.connect.go index d56d3efa..1f6ba74d 100644 --- a/internal/gen/connect/import/v1/importv1connect/import.connect.go +++ b/internal/gen/connect/import/v1/importv1connect/import.connect.go @@ -20,9 +20,8 @@ package importv1connect import ( connect "connectrpc.com/connect" - _ "connectrpc.com/connect/internal/gen/connect/import/v1" + v1 "connectrpc.com/connect/internal/gen/connect/import/v1" http "net/http" - strings "strings" ) // This is a compile-time assertion to ensure that this generated file and the connect package are @@ -30,13 +29,18 @@ import ( // generated with a version of connect newer than the one compiled into your binary. You can fix the // problem by either regenerating this code with an older version of connect or updating the connect // version compiled into your binary. -const _ = connect.IsAtLeastVersion0_1_0 +const _ = connect.IsAtLeastVersion1_13_0 const ( // ImportServiceName is the fully-qualified name of the ImportService service. ImportServiceName = "connect.import.v1.ImportService" ) +// These variables are the protoreflect.Descriptor objects for the RPCs defined in this package. +var ( + importServiceServiceDescriptor = v1.File_connect_import_v1_import_proto.Services().ByName("ImportService") +) + // ImportServiceClient is a client for the connect.import.v1.ImportService service. type ImportServiceClient interface { } @@ -49,7 +53,6 @@ type ImportServiceClient interface { // The URL supplied here should be the base URL for the Connect or gRPC server (for example, // http://api.acme.com or https://acme.com/grpc). func NewImportServiceClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) ImportServiceClient { - baseURL = strings.TrimRight(baseURL, "/") return &importServiceClient{} } diff --git a/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go b/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go index 036539a8..56bf4b2c 100644 --- a/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go +++ b/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go @@ -37,7 +37,7 @@ import ( // generated with a version of connect newer than the one compiled into your binary. You can fix the // problem by either regenerating this code with an older version of connect or updating the connect // version compiled into your binary. -const _ = connect.IsAtLeastVersion1_7_0 +const _ = connect.IsAtLeastVersion1_13_0 const ( // PingServiceName is the fully-qualified name of the PingService service. @@ -64,6 +64,16 @@ const ( PingServiceCumSumProcedure = "/connect.ping.v1.PingService/CumSum" ) +// These variables are the protoreflect.Descriptor objects for the RPCs defined in this package. +var ( + pingServiceServiceDescriptor = v1.File_connect_ping_v1_ping_proto.Services().ByName("PingService") + pingServicePingMethodDescriptor = pingServiceServiceDescriptor.Methods().ByName("Ping") + pingServiceFailMethodDescriptor = pingServiceServiceDescriptor.Methods().ByName("Fail") + pingServiceSumMethodDescriptor = pingServiceServiceDescriptor.Methods().ByName("Sum") + pingServiceCountUpMethodDescriptor = pingServiceServiceDescriptor.Methods().ByName("CountUp") + pingServiceCumSumMethodDescriptor = pingServiceServiceDescriptor.Methods().ByName("CumSum") +) + // PingServiceClient is a client for the connect.ping.v1.PingService service. type PingServiceClient interface { // Ping sends a ping to the server to determine if it's reachable. @@ -91,28 +101,33 @@ func NewPingServiceClient(httpClient connect.HTTPClient, baseURL string, opts .. ping: connect.NewClient[v1.PingRequest, v1.PingResponse]( httpClient, baseURL+PingServicePingProcedure, + connect.WithSchema(pingServicePingMethodDescriptor), connect.WithIdempotency(connect.IdempotencyNoSideEffects), connect.WithClientOptions(opts...), ), fail: connect.NewClient[v1.FailRequest, v1.FailResponse]( httpClient, baseURL+PingServiceFailProcedure, - opts..., + connect.WithSchema(pingServiceFailMethodDescriptor), + connect.WithClientOptions(opts...), ), sum: connect.NewClient[v1.SumRequest, v1.SumResponse]( httpClient, baseURL+PingServiceSumProcedure, - opts..., + connect.WithSchema(pingServiceSumMethodDescriptor), + connect.WithClientOptions(opts...), ), countUp: connect.NewClient[v1.CountUpRequest, v1.CountUpResponse]( httpClient, baseURL+PingServiceCountUpProcedure, - opts..., + connect.WithSchema(pingServiceCountUpMethodDescriptor), + connect.WithClientOptions(opts...), ), cumSum: connect.NewClient[v1.CumSumRequest, v1.CumSumResponse]( httpClient, baseURL+PingServiceCumSumProcedure, - opts..., + connect.WithSchema(pingServiceCumSumMethodDescriptor), + connect.WithClientOptions(opts...), ), } } @@ -174,28 +189,33 @@ func NewPingServiceHandler(svc PingServiceHandler, opts ...connect.HandlerOption pingServicePingHandler := connect.NewUnaryHandler( PingServicePingProcedure, svc.Ping, + connect.WithSchema(pingServicePingMethodDescriptor), connect.WithIdempotency(connect.IdempotencyNoSideEffects), connect.WithHandlerOptions(opts...), ) pingServiceFailHandler := connect.NewUnaryHandler( PingServiceFailProcedure, svc.Fail, - opts..., + connect.WithSchema(pingServiceFailMethodDescriptor), + connect.WithHandlerOptions(opts...), ) pingServiceSumHandler := connect.NewClientStreamHandler( PingServiceSumProcedure, svc.Sum, - opts..., + connect.WithSchema(pingServiceSumMethodDescriptor), + connect.WithHandlerOptions(opts...), ) pingServiceCountUpHandler := connect.NewServerStreamHandler( PingServiceCountUpProcedure, svc.CountUp, - opts..., + connect.WithSchema(pingServiceCountUpMethodDescriptor), + connect.WithHandlerOptions(opts...), ) pingServiceCumSumHandler := connect.NewBidiStreamHandler( PingServiceCumSumProcedure, svc.CumSum, - opts..., + connect.WithSchema(pingServiceCumSumMethodDescriptor), + connect.WithHandlerOptions(opts...), ) return "/connect.ping.v1.PingService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { diff --git a/option.go b/option.go index 1eb8d7c3..c5d15f05 100644 --- a/option.go +++ b/option.go @@ -184,6 +184,16 @@ type Option interface { HandlerOption } +// WithSchema provides a parsed representation of the schema for an RPC to a +// client or handler. The supplied schema is exposed as [Spec.Schema]. This +// option is typically added by generated code. +// +// For services using protobuf schemas, the supplied schema should be a +// [protoreflect.MethodDescriptor]. +func WithSchema(schema any) Option { + return &schemaOption{Schema: schema} +} + // WithCodec registers a serialization method with a client or handler. // Handlers may have multiple codecs registered, and use whichever the client // chooses. Clients may only have a single codec. @@ -328,6 +338,18 @@ func WithOptions(options ...Option) Option { return &optionsOption{options} } +type schemaOption struct { + Schema any +} + +func (o *schemaOption) applyToClient(config *clientConfig) { + config.Schema = o.Schema +} + +func (o *schemaOption) applyToHandler(config *handlerConfig) { + config.Schema = o.Schema +} + type clientOptionsOption struct { options []ClientOption }