Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Schema field to Spec for introspection #629

Merged
merged 6 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ type clientConfig struct {
URL *url.URL
Protocol protocol
Procedure string
Schema any
CompressMinBytes int
Interceptor Interceptor
CompressionPools map[string]*compressionPool
Expand Down Expand Up @@ -251,6 +252,7 @@ func (c *clientConfig) newSpec(t StreamType) Spec {
return Spec{
StreamType: t,
Procedure: c.Procedure,
Schema: c.Schema,
IsClient: true,
IdempotencyLevel: c.IdempotencyLevel,
}
Expand Down
87 changes: 87 additions & 0 deletions client_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package connect_test
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"testing"
Expand All @@ -26,6 +27,7 @@ import (
pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1"
"connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect"
"connectrpc.com/connect/internal/memhttp/memhttptest"
"google.golang.org/protobuf/reflect/protoreflect"
)

func TestNewClient_InitFailure(t *testing.T) {
Expand Down Expand Up @@ -186,6 +188,44 @@ func TestGetNotModified(t *testing.T) {
assert.Equal(t, http.MethodGet, unaryReq.HTTPMethod())
}

func TestSpecSchema(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(
pingServer{},
connect.WithInterceptors(&assertSchemaInterceptor{t}),
))
server := memhttptest.NewServer(t, mux)
ctx := context.Background()
client := pingv1connect.NewPingServiceClient(
server.Client(),
server.URL(),
connect.WithInterceptors(&assertSchemaInterceptor{t}),
)
t.Run("unary", func(t *testing.T) {
t.Parallel()
unaryReq := connect.NewRequest[pingv1.PingRequest](nil)
_, err := client.Ping(ctx, unaryReq)
assert.NotNil(t, unaryReq.Spec().Schema)
assert.Nil(t, err)
text := strings.Repeat(".", 256)
r, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{Text: text}))
assert.Nil(t, err)
assert.Equal(t, r.Msg.Text, text)
})
t.Run("bidi_stream", func(t *testing.T) {
t.Parallel()
bidiStream := client.CumSum(ctx)
t.Cleanup(func() {
assert.Nil(t, bidiStream.CloseRequest())
assert.Nil(t, bidiStream.CloseResponse())
})
assert.NotZero(t, bidiStream.Spec().Schema)
err := bidiStream.Send(&pingv1.CumSumRequest{})
assert.Nil(t, err)
})
}

type notModifiedPingServer struct {
pingv1connect.UnimplementedPingServiceHandler

Expand Down Expand Up @@ -233,3 +273,50 @@ func (a *assertPeerInterceptor) WrapStreamingHandler(next connect.StreamingHandl
return next(ctx, conn)
}
}

type assertSchemaInterceptor struct {
tb testing.TB
}

func (a *assertSchemaInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
if !assert.NotNil(a.tb, req.Spec().Schema) {
return next(ctx, req)
}
methodDesc, ok := req.Spec().Schema.(protoreflect.MethodDescriptor)
if assert.True(a.tb, ok) {
procedure := fmt.Sprintf("/%s/%s", methodDesc.Parent().FullName(), methodDesc.Name())
assert.Equal(a.tb, procedure, req.Spec().Procedure)
}
return next(ctx, req)
}
}

func (a *assertSchemaInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
conn := next(ctx, spec)
if !assert.NotNil(a.tb, spec.Schema) {
return conn
}
methodDescriptor, ok := spec.Schema.(protoreflect.MethodDescriptor)
if assert.True(a.tb, ok) {
procedure := fmt.Sprintf("/%s/%s", methodDescriptor.Parent().FullName(), methodDescriptor.Name())
assert.Equal(a.tb, procedure, spec.Procedure)
}
return conn
}
}

func (a *assertSchemaInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
if !assert.NotNil(a.tb, conn.Spec().Schema) {
return next(ctx, conn)
}
methodDesc, ok := conn.Spec().Schema.(protoreflect.MethodDescriptor)
if assert.True(a.tb, ok) {
procedure := fmt.Sprintf("/%s/%s", methodDesc.Parent().FullName(), methodDesc.Name())
assert.Equal(a.tb, procedure, conn.Spec().Procedure)
}
return next(ctx, conn)
}
}
53 changes: 30 additions & 23 deletions cmd/protoc-gen-connect-go/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,6 @@ func main() {
)
}

func needsWithIdempotency(file *protogen.File) bool {
for _, service := range file.Services {
for _, method := range service.Methods {
if methodIdempotency(method) != connect.IdempotencyUnknown {
return true
}
}
}
return false
}

func generate(plugin *protogen.Plugin, file *protogen.File) {
if len(file.Services) == 0 {
return
Expand All @@ -135,6 +124,7 @@ func generate(plugin *protogen.Plugin, file *protogen.File) {
generatedFile.Import(file.GoImportPath)
generatePreamble(generatedFile, file)
generateServiceNameConstants(generatedFile, file.Services)
generateServiceNameVariables(generatedFile, file)
for _, service := range file.Services {
generateService(generatedFile, service)
}
Expand Down Expand Up @@ -180,11 +170,7 @@ func generatePreamble(g *protogen.GeneratedFile, file *protogen.File) {
"is not defined, this code was generated with a version of connect newer than the one ",
"compiled into your binary. You can fix the problem by either regenerating this code ",
"with an older version of connect or updating the connect version compiled into your binary.")
if needsWithIdempotency(file) {
g.P("const _ = ", connectPackage.Ident("IsAtLeastVersion1_7_0"))
} else {
g.P("const _ = ", connectPackage.Ident("IsAtLeastVersion0_1_0"))
}
g.P("const _ = ", connectPackage.Ident("IsAtLeastVersion1_13_0"))
g.P()
}

Expand Down Expand Up @@ -225,6 +211,23 @@ func generateServiceNameConstants(g *protogen.GeneratedFile, services []*protoge
g.P()
}

func generateServiceNameVariables(g *protogen.GeneratedFile, file *protogen.File) {
wrapComments(g, "These variables are the protoreflect.Descriptor objects for the RPCs defined in this package.")
jhump marked this conversation as resolved.
Show resolved Hide resolved
g.P("var (")
for _, service := range file.Services {
serviceDescName := unexport(fmt.Sprintf("%sServiceDescriptor", service.Desc.Name()))
g.P(serviceDescName, ` = `,
g.QualifiedGoIdent(file.GoDescriptorIdent),
`.Services().ByName("`, service.Desc.Name(), `")`)
for _, method := range service.Methods {
g.P(procedureVarMethodDescriptor(method), ` = `,
serviceDescName,
`.Methods().ByName("`, method.Desc.Name(), `")`)
}
}
g.P(")")
}

func generateService(g *protogen.GeneratedFile, service *protogen.Service) {
names := newNames(service)
generateClientInterface(g, service, names)
Expand Down Expand Up @@ -273,7 +276,9 @@ func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.S
}
g.P("func ", names.ClientConstructor, " (httpClient ", connectPackage.Ident("HTTPClient"),
", baseURL string, opts ...", clientOption, ") ", names.Client, " {")
g.P("baseURL = ", stringsPackage.Ident("TrimRight"), `(baseURL, "/")`)
if len(service.Methods) > 0 {
g.P("baseURL = ", stringsPackage.Ident("TrimRight"), `(baseURL, "/")`)
}
g.P("return &", names.ClientImpl, "{")
for _, method := range service.Methods {
g.P(unexport(method.GoName), ": ",
Expand All @@ -283,17 +288,16 @@ func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.S
)
g.P("httpClient,")
g.P(`baseURL + `, procedureConstName(method), `,`)
g.P(connectPackage.Ident("WithSchema"), "(", procedureVarMethodDescriptor(method), "),")
idempotency := methodIdempotency(method)
switch idempotency {
case connect.IdempotencyNoSideEffects:
g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyNoSideEffects"), "),")
g.P(connectPackage.Ident("WithClientOptions"), "(opts...),")
case connect.IdempotencyIdempotent:
g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyIdempotent"), "),")
g.P(connectPackage.Ident("WithClientOptions"), "(opts...),")
case connect.IdempotencyUnknown:
g.P("opts...,")
}
g.P(connectPackage.Ident("WithClientOptions"), "(opts...),")
g.P("),")
}
g.P("}")
Expand Down Expand Up @@ -419,16 +423,15 @@ func generateServerConstructor(g *protogen.GeneratedFile, service *protogen.Serv
}
g.P(procedureConstName(method), `,`)
g.P("svc.", method.GoName, ",")
g.P(connectPackage.Ident("WithSchema"), "(", procedureVarMethodDescriptor(method), "),")
switch idempotency {
case connect.IdempotencyNoSideEffects:
g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyNoSideEffects"), "),")
g.P(connectPackage.Ident("WithHandlerOptions"), "(opts...),")
case connect.IdempotencyIdempotent:
g.P(connectPackage.Ident("WithIdempotency"), "(", connectPackage.Ident("IdempotencyIdempotent"), "),")
g.P(connectPackage.Ident("WithHandlerOptions"), "(opts...),")
case connect.IdempotencyUnknown:
g.P("opts...,")
}
g.P(connectPackage.Ident("WithHandlerOptions"), "(opts...),")
g.P(")")
}
g.P(`return "/`, service.Desc.FullName(), `/", `, httpPackage.Ident("HandlerFunc"), `(func(w `, httpPackage.Ident("ResponseWriter"), `, r *`, httpPackage.Ident("Request"), `){`)
Expand Down Expand Up @@ -516,6 +519,10 @@ func procedureHandlerName(m *protogen.Method) string {
return fmt.Sprintf("%s%sHandler", unexport(m.Parent.GoName), m.GoName)
}

func procedureVarMethodDescriptor(m *protogen.Method) string {
return unexport(fmt.Sprintf("%s%sMethodDescriptor", m.Parent.GoName, m.GoName))
}

func isDeprecatedService(service *protogen.Service) bool {
serviceOptions, ok := service.Desc.Options().(*descriptorpb.ServiceOptions)
return ok && serviceOptions.GetDeprecated()
Expand Down
8 changes: 5 additions & 3 deletions connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ const Version = "1.13.0-dev"
// These constants are used in compile-time handshakes with connect's generated
// code.
const (
IsAtLeastVersion0_0_1 = true
IsAtLeastVersion0_1_0 = true
IsAtLeastVersion1_7_0 = true
IsAtLeastVersion0_0_1 = true
IsAtLeastVersion0_1_0 = true
IsAtLeastVersion1_7_0 = true
IsAtLeastVersion1_13_0 = true
)

// StreamType describes whether the client, server, neither, or both is
Expand Down Expand Up @@ -314,6 +315,7 @@ type HTTPClient interface {
// fully-qualified Procedure corresponding to each RPC in your schema.
type Spec struct {
StreamType StreamType
Schema any // for protobuf RPCs, a protoreflect.MethodDescriptor
Procedure string // for example, "/acme.foo.v1.FooService/Bar"
IsClient bool // otherwise we're in a handler
IdempotencyLevel IdempotencyLevel
Expand Down
2 changes: 2 additions & 0 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ type handlerConfig struct {
CompressMinBytes int
Interceptor Interceptor
Procedure string
Schema any
HandleGRPC bool
HandleGRPCWeb bool
RequireConnectProtocolHeader bool
Expand Down Expand Up @@ -279,6 +280,7 @@ func newHandlerConfig(procedure string, streamType StreamType, options []Handler
func (c *handlerConfig) newSpec() Spec {
return Spec{
Procedure: c.Procedure,
Schema: c.Schema,
StreamType: c.StreamType,
IdempotencyLevel: c.IdempotencyLevel,
}
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 7 additions & 4 deletions internal/gen/connect/import/v1/importv1connect/import.connect.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading