From e9965a674721f484773c59ec216724afc60e1d43 Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Fri, 3 Mar 2017 20:23:20 +0100 Subject: [PATCH 1/7] adds gRPC ClientAfter handler and multiple calls to Client Before/After funcs --- transport/grpc/client.go | 21 +++++++++++++++++++-- transport/grpc/request_response_funcs.go | 18 ++++++++++++------ transport/grpc/server.go | 4 ++-- transport/http/client.go | 4 ++-- 4 files changed, 35 insertions(+), 12 deletions(-) diff --git a/transport/grpc/client.go b/transport/grpc/client.go index 7ab43c647..437d931c1 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -22,6 +22,7 @@ type Client struct { dec DecodeResponseFunc grpcReply reflect.Type before []RequestFunc + after []ClientResponseFunc } // NewClient constructs a usable Client for a single remote endpoint. @@ -54,6 +55,7 @@ func NewClient( ).Interface(), ), before: []RequestFunc{}, + after: []ClientResponseFunc{}, } for _, option := range options { option(c) @@ -67,7 +69,14 @@ type ClientOption func(*Client) // ClientBefore sets the RequestFuncs that are applied to the outgoing gRPC // request before it's invoked. func ClientBefore(before ...RequestFunc) ClientOption { - return func(c *Client) { c.before = before } + return func(c *Client) { c.before = append(c.before, before...) } +} + +// ClientAfter sets the ClientResponseFuncs that are applied to the incoming +// gRPC response prior to it being decoded. This is useful for obtaining +// response metadata and adding onto the context prior to decoding. +func ClientAfter(after ...ClientResponseFunc) ClientOption { + return func(c *Client) { c.after = append(c.after, after...) } } // Endpoint returns a usable endpoint that will invoke the gRPC specified by the @@ -88,11 +97,19 @@ func (c Client) Endpoint() endpoint.Endpoint { } ctx = metadata.NewContext(ctx, *md) + var header, trailer metadata.MD grpcReply := reflect.New(c.grpcReply).Interface() - if err = grpc.Invoke(ctx, c.method, req, grpcReply, c.client); err != nil { + if err = grpc.Invoke( + ctx, c.method, req, grpcReply, c.client, + grpc.Header(&header), grpc.Trailer(&trailer), + ); err != nil { return nil, err } + for _, f := range c.after { + ctx = f(ctx, &header, &trailer) + } + response, err := c.dec(ctx, grpcReply) if err != nil { return nil, err diff --git a/transport/grpc/request_response_funcs.go b/transport/grpc/request_response_funcs.go index aa88ca65e..067ef3f4b 100644 --- a/transport/grpc/request_response_funcs.go +++ b/transport/grpc/request_response_funcs.go @@ -12,20 +12,26 @@ const ( binHdrSuffix = "-bin" ) -// RequestFunc may take information from an gRPC request and put it into a -// request context. In Servers, BeforeFuncs are executed prior to invoking the -// endpoint. In Clients, BeforeFuncs are executed after creating the request +// RequestFunc may take information from a gRPC request and put it into a +// request context. In Servers, RequestFuncs are executed prior to invoking the +// endpoint. In Clients, RequestFuncs are executed after creating the request // but prior to invoking the gRPC client. type RequestFunc func(context.Context, *metadata.MD) context.Context -// ResponseFunc may take information from a request context and use it to +// ServerResponseFunc may take information from a request context and use it to // manipulate the gRPC metadata header. ResponseFuncs are only executed in // servers, after invoking the endpoint but prior to writing a response. -type ResponseFunc func(context.Context, *metadata.MD) +type ServerResponseFunc func(context.Context, *metadata.MD) + +// ClientResponseFunc may take information from a gRPC metadata header and/or +// trailer and make the responses available for consumption. ClientResponseFuncs +// are only executed in clients, after a request has been made, but prior to it +// being decoded. +type ClientResponseFunc func(ctx context.Context, header *metadata.MD, trailer *metadata.MD) context.Context // SetResponseHeader returns a ResponseFunc that sets the specified metadata // key-value pair. -func SetResponseHeader(key, val string) ResponseFunc { +func SetResponseHeader(key, val string) ServerResponseFunc { return func(_ context.Context, md *metadata.MD) { key, val := EncodeKeyValue(key, val) (*md)[key] = append((*md)[key], val) diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 742c1a086..9f6a94a12 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -24,7 +24,7 @@ type Server struct { dec DecodeRequestFunc enc EncodeResponseFunc before []RequestFunc - after []ResponseFunc + after []ServerResponseFunc logger log.Logger } @@ -64,7 +64,7 @@ func ServerBefore(before ...RequestFunc) ServerOption { // ServerAfter functions are executed on the HTTP response writer after the // endpoint is invoked, but before anything is written to the client. -func ServerAfter(after ...ResponseFunc) ServerOption { +func ServerAfter(after ...ServerResponseFunc) ServerOption { return func(s *Server) { s.after = append(s.after, after...) } } diff --git a/transport/http/client.go b/transport/http/client.go index 494797583..08f1b886e 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -62,14 +62,14 @@ func SetClient(client *http.Client) ClientOption { // ClientBefore sets the RequestFuncs that are applied to the outgoing HTTP // request before it's invoked. func ClientBefore(before ...RequestFunc) ClientOption { - return func(c *Client) { c.before = before } + return func(c *Client) { c.before = append(c.before, before...) } } // ClientAfter sets the ClientResponseFuncs applied to the incoming HTTP // request prior to it being decoded. This is useful for obtaining anything off // of the response and adding onto the context prior to decoding. func ClientAfter(after ...ClientResponseFunc) ClientOption { - return func(c *Client) { c.after = after } + return func(c *Client) { c.after = append(c.after, after...) } } // BufferedStream sets whether the Response.Body is left open, allowing it From 2677e4963bc6c4ee561718ac5d472010a2898ec6 Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Sat, 4 Mar 2017 23:13:57 +0100 Subject: [PATCH 2/7] added unit test for gRPC header and trailer request/response propagation --- examples/addsvc/cmd/addsvc/main.go | 2 +- examples/addsvc/transport_grpc.go | 4 +- transport/grpc/_grpc_test/client.go | 39 ++++ transport/grpc/_grpc_test/context_metadata.go | 106 +++++++++++ transport/grpc/_grpc_test/request_response.go | 40 +++++ transport/grpc/_grpc_test/server.go | 57 ++++++ transport/grpc/_grpc_test/service.go | 17 ++ transport/grpc/_pb/generate.go | 3 + transport/grpc/_pb/test.pb.go | 167 ++++++++++++++++++ transport/grpc/_pb/test.proto | 16 ++ transport/grpc/client.go | 2 +- transport/grpc/client_test.go | 59 +++++++ transport/grpc/request_response_funcs.go | 2 +- transport/grpc/server.go | 27 ++- 14 files changed, 518 insertions(+), 23 deletions(-) create mode 100644 transport/grpc/_grpc_test/client.go create mode 100644 transport/grpc/_grpc_test/context_metadata.go create mode 100644 transport/grpc/_grpc_test/request_response.go create mode 100644 transport/grpc/_grpc_test/server.go create mode 100644 transport/grpc/_grpc_test/service.go create mode 100644 transport/grpc/_pb/generate.go create mode 100644 transport/grpc/_pb/test.pb.go create mode 100644 transport/grpc/_pb/test.proto create mode 100644 transport/grpc/client_test.go diff --git a/examples/addsvc/cmd/addsvc/main.go b/examples/addsvc/cmd/addsvc/main.go index 842f34bec..3fba0ada3 100644 --- a/examples/addsvc/cmd/addsvc/main.go +++ b/examples/addsvc/cmd/addsvc/main.go @@ -222,7 +222,7 @@ func main() { return } - srv := addsvc.MakeGRPCServer(ctx, endpoints, tracer, logger) + srv := addsvc.MakeGRPCServer(endpoints, tracer, logger) s := grpc.NewServer() pb.RegisterAddServer(s, srv) diff --git a/examples/addsvc/transport_grpc.go b/examples/addsvc/transport_grpc.go index 21e60bc4f..dcfc03a05 100644 --- a/examples/addsvc/transport_grpc.go +++ b/examples/addsvc/transport_grpc.go @@ -16,20 +16,18 @@ import ( ) // MakeGRPCServer makes a set of endpoints available as a gRPC AddServer. -func MakeGRPCServer(ctx context.Context, endpoints Endpoints, tracer stdopentracing.Tracer, logger log.Logger) pb.AddServer { +func MakeGRPCServer(endpoints Endpoints, tracer stdopentracing.Tracer, logger log.Logger) pb.AddServer { options := []grpctransport.ServerOption{ grpctransport.ServerErrorLogger(logger), } return &grpcServer{ sum: grpctransport.NewServer( - ctx, endpoints.SumEndpoint, DecodeGRPCSumRequest, EncodeGRPCSumResponse, append(options, grpctransport.ServerBefore(opentracing.FromGRPCRequest(tracer, "Sum", logger)))..., ), concat: grpctransport.NewServer( - ctx, endpoints.ConcatEndpoint, DecodeGRPCConcatRequest, EncodeGRPCConcatResponse, diff --git a/transport/grpc/_grpc_test/client.go b/transport/grpc/_grpc_test/client.go new file mode 100644 index 000000000..11d78ca7f --- /dev/null +++ b/transport/grpc/_grpc_test/client.go @@ -0,0 +1,39 @@ +package test + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/go-kit/kit/endpoint" + grpctransport "github.com/go-kit/kit/transport/grpc" + pb "github.com/go-kit/kit/transport/grpc/_pb" +) + +type clientBinding struct { + test endpoint.Endpoint +} + +func (c *clientBinding) Test(ctx context.Context, a string, b int64) (context.Context, string, error) { + response, err := c.test(ctx, TestRequest{A: a, B: b}) + if err != nil { + return nil, "", err + } + r := response.(*TestResponse) + return r.Ctx, r.V, nil +} + +func NewClient(cc *grpc.ClientConn) Service { + return &clientBinding{ + test: grpctransport.NewClient( + cc, + "pb.Test", + "Test", + encodeRequest, + decodeResponse, + &pb.TestResponse{}, + grpctransport.ClientBefore(clientBefore), + grpctransport.ClientAfter(clientAfter), + ).Endpoint(), + } +} diff --git a/transport/grpc/_grpc_test/context_metadata.go b/transport/grpc/_grpc_test/context_metadata.go new file mode 100644 index 000000000..f31b50bdf --- /dev/null +++ b/transport/grpc/_grpc_test/context_metadata.go @@ -0,0 +1,106 @@ +package test + +import ( + "context" + "fmt" + "log" + + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +type metaContext string + +const ( + correlationID metaContext = "correlation-id" + responseHDR metaContext = "my-response-header" + responseTRLR metaContext = "correlation-id-consumed" +) + +func clientBefore(ctx context.Context, md *metadata.MD) context.Context { + if hdr, ok := ctx.Value(correlationID).(string); ok { + (*md)[string(correlationID)] = append((*md)[string(correlationID)], hdr) + } + if len(*md) > 0 { + fmt.Println("\tClient >> Request Headers:") + for key, val := range *md { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } + } + return ctx +} + +func serverBefore(ctx context.Context, md *metadata.MD) context.Context { + if len(*md) > 0 { + fmt.Println("\tServer << Request Headers:") + for key, val := range *md { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } + } + if hdr, ok := (*md)[string(correlationID)]; ok { + cID := hdr[len(hdr)-1] + ctx = context.WithValue(ctx, correlationID, cID) + fmt.Printf("\tServer placed correlationID %q in context\n", cID) + } + return ctx +} + +func serverAfter(ctx context.Context, _ *metadata.MD) { + var mdHeader, mdTrailer metadata.MD + + mdHeader = metadata.Pairs(string(responseHDR), "has-a-value") + if err := grpc.SendHeader(ctx, mdHeader); err != nil { + log.Fatalf("unable to send header: %+v\n", err) + } + + if hdr, ok := ctx.Value(correlationID).(string); ok { + mdTrailer = metadata.Pairs(string(responseTRLR), hdr) + if err := grpc.SetTrailer(ctx, mdTrailer); err != nil { + log.Fatalf("unable to set trailer: %+v\n", err) + } + fmt.Printf("\tServer found correlationID %q in context, set consumed trailer\n", hdr) + } + if len(mdHeader) > 0 { + fmt.Println("\tServer >> Response Headers:") + for key, val := range mdHeader { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } + } + if len(mdTrailer) > 0 { + fmt.Println("\tServer >> Response Trailers:") + for key, val := range mdTrailer { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } + } +} + +func clientAfter(ctx context.Context, mdHeader metadata.MD, mdTrailer metadata.MD) context.Context { + if len(mdHeader) > 0 { + fmt.Println("\tClient << Response Headers:") + for key, val := range mdHeader { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } + } + if len(mdTrailer) > 0 { + fmt.Println("\tClient << Response Trailers:") + for key, val := range mdTrailer { + fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) + } + } + + if hdr, ok := mdTrailer[string(responseTRLR)]; ok { + ctx = context.WithValue(ctx, responseTRLR, hdr[len(hdr)-1]) + } + return ctx +} + +func SetCorrelationID(ctx context.Context, v string) context.Context { + return context.WithValue(ctx, correlationID, v) +} + +func GetConsumedCorrelationID(ctx context.Context) string { + if trlr, ok := ctx.Value(responseTRLR).(string); ok { + return trlr + } + return "" +} diff --git a/transport/grpc/_grpc_test/request_response.go b/transport/grpc/_grpc_test/request_response.go new file mode 100644 index 000000000..441bc6501 --- /dev/null +++ b/transport/grpc/_grpc_test/request_response.go @@ -0,0 +1,40 @@ +package test + +import ( + "context" + "errors" + + pb "github.com/go-kit/kit/transport/grpc/_pb" +) + +func encodeRequest(ctx context.Context, req interface{}) (interface{}, error) { + r, ok := req.(TestRequest) + if !ok { + return nil, errors.New("request encode error") + } + return &pb.TestRequest{A: r.A, B: r.B}, nil +} + +func decodeRequest(ctx context.Context, req interface{}) (interface{}, error) { + r, ok := req.(*pb.TestRequest) + if !ok { + return nil, errors.New("request decode error") + } + return TestRequest{A: r.A, B: r.B}, nil +} + +func encodeResponse(ctx context.Context, resp interface{}) (interface{}, error) { + r, ok := resp.(*TestResponse) + if !ok { + return nil, errors.New("response encode error") + } + return &pb.TestResponse{V: r.V}, nil +} + +func decodeResponse(ctx context.Context, resp interface{}) (interface{}, error) { + r, ok := resp.(*pb.TestResponse) + if !ok { + return nil, errors.New("response decode error") + } + return &TestResponse{V: r.V, Ctx: ctx}, nil +} diff --git a/transport/grpc/_grpc_test/server.go b/transport/grpc/_grpc_test/server.go new file mode 100644 index 000000000..6c55b119a --- /dev/null +++ b/transport/grpc/_grpc_test/server.go @@ -0,0 +1,57 @@ +package test + +import ( + "context" + "fmt" + + oldcontext "golang.org/x/net/context" + + "github.com/go-kit/kit/endpoint" + grpctransport "github.com/go-kit/kit/transport/grpc" + pb "github.com/go-kit/kit/transport/grpc/_pb" +) + +type service struct{} + +func (service) Test(ctx context.Context, a string, b int64) (context.Context, string, error) { + return nil, fmt.Sprintf("%s = %d", a, b), nil +} + +func NewService() Service { + return service{} +} + +func makeTestEndpoint(svc Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(TestRequest) + newCtx, v, err := svc.Test(ctx, req.A, req.B) + return &TestResponse{ + V: v, + Ctx: newCtx, + }, err + } +} + +type serverBinding struct { + test grpctransport.Handler +} + +func (b *serverBinding) Test(ctx oldcontext.Context, req *pb.TestRequest) (*pb.TestResponse, error) { + _, response, err := b.test.ServeGRPC(ctx, req) + if err != nil { + return nil, err + } + return response.(*pb.TestResponse), nil +} + +func NewBinding(svc Service) *serverBinding { + return &serverBinding{ + test: grpctransport.NewServer( + makeTestEndpoint(svc), + decodeRequest, + encodeResponse, + grpctransport.ServerBefore(serverBefore), + grpctransport.ServerAfter(serverAfter), + ), + } +} diff --git a/transport/grpc/_grpc_test/service.go b/transport/grpc/_grpc_test/service.go new file mode 100644 index 000000000..536b27c0b --- /dev/null +++ b/transport/grpc/_grpc_test/service.go @@ -0,0 +1,17 @@ +package test + +import "context" + +type Service interface { + Test(ctx context.Context, a string, b int64) (context.Context, string, error) +} + +type TestRequest struct { + A string + B int64 +} + +type TestResponse struct { + Ctx context.Context + V string +} diff --git a/transport/grpc/_pb/generate.go b/transport/grpc/_pb/generate.go new file mode 100644 index 000000000..aa20bb664 --- /dev/null +++ b/transport/grpc/_pb/generate.go @@ -0,0 +1,3 @@ +package pb + +//go:generate protoc test.proto --go_out=plugins=grpc:. diff --git a/transport/grpc/_pb/test.pb.go b/transport/grpc/_pb/test.pb.go new file mode 100644 index 000000000..97d29bb1e --- /dev/null +++ b/transport/grpc/_pb/test.pb.go @@ -0,0 +1,167 @@ +// Code generated by protoc-gen-go. +// source: test.proto +// DO NOT EDIT! + +/* +Package pb is a generated protocol buffer package. + +It is generated from these files: + test.proto + +It has these top-level messages: + TestRequest + TestResponse +*/ +package pb + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type TestRequest struct { + A string `protobuf:"bytes,1,opt,name=a" json:"a,omitempty"` + B int64 `protobuf:"varint,2,opt,name=b" json:"b,omitempty"` +} + +func (m *TestRequest) Reset() { *m = TestRequest{} } +func (m *TestRequest) String() string { return proto.CompactTextString(m) } +func (*TestRequest) ProtoMessage() {} +func (*TestRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func (m *TestRequest) GetA() string { + if m != nil { + return m.A + } + return "" +} + +func (m *TestRequest) GetB() int64 { + if m != nil { + return m.B + } + return 0 +} + +type TestResponse struct { + V string `protobuf:"bytes,1,opt,name=v" json:"v,omitempty"` +} + +func (m *TestResponse) Reset() { *m = TestResponse{} } +func (m *TestResponse) String() string { return proto.CompactTextString(m) } +func (*TestResponse) ProtoMessage() {} +func (*TestResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m *TestResponse) GetV() string { + if m != nil { + return m.V + } + return "" +} + +func init() { + proto.RegisterType((*TestRequest)(nil), "pb.TestRequest") + proto.RegisterType((*TestResponse)(nil), "pb.TestResponse") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// Client API for Test service + +type TestClient interface { + Test(ctx context.Context, in *TestRequest, opts ...grpc.CallOption) (*TestResponse, error) +} + +type testClient struct { + cc *grpc.ClientConn +} + +func NewTestClient(cc *grpc.ClientConn) TestClient { + return &testClient{cc} +} + +func (c *testClient) Test(ctx context.Context, in *TestRequest, opts ...grpc.CallOption) (*TestResponse, error) { + out := new(TestResponse) + err := grpc.Invoke(ctx, "/pb.Test/Test", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// Server API for Test service + +type TestServer interface { + Test(context.Context, *TestRequest) (*TestResponse, error) +} + +func RegisterTestServer(s *grpc.Server, srv TestServer) { + s.RegisterService(&_Test_serviceDesc, srv) +} + +func _Test_Test_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(TestRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(TestServer).Test(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/pb.Test/Test", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(TestServer).Test(ctx, req.(*TestRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _Test_serviceDesc = grpc.ServiceDesc{ + ServiceName: "pb.Test", + HandlerType: (*TestServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Test", + Handler: _Test_Test_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "test.proto", +} + +func init() { proto.RegisterFile("test.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 129 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x2a, 0x49, 0x2d, 0x2e, + 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2a, 0x48, 0x52, 0xd2, 0xe4, 0xe2, 0x0e, 0x49, + 0x2d, 0x2e, 0x09, 0x4a, 0x2d, 0x2c, 0x4d, 0x2d, 0x2e, 0x11, 0xe2, 0xe1, 0x62, 0x4c, 0x94, 0x60, + 0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x62, 0x4c, 0x04, 0xf1, 0x92, 0x24, 0x98, 0x14, 0x18, 0x35, 0x98, + 0x83, 0x18, 0x93, 0x94, 0x64, 0xb8, 0x78, 0x20, 0x4a, 0x8b, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x41, + 0xb2, 0x65, 0x30, 0xb5, 0x65, 0x46, 0xc6, 0x5c, 0x2c, 0x20, 0x59, 0x21, 0x6d, 0x28, 0xcd, 0xaf, + 0x57, 0x90, 0xa4, 0x87, 0x64, 0xb4, 0x94, 0x00, 0x42, 0x00, 0x62, 0x80, 0x12, 0x43, 0x12, 0x1b, + 0xd8, 0x21, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x49, 0xfc, 0xd8, 0xf1, 0x96, 0x00, 0x00, + 0x00, +} diff --git a/transport/grpc/_pb/test.proto b/transport/grpc/_pb/test.proto new file mode 100644 index 000000000..6a3555e3c --- /dev/null +++ b/transport/grpc/_pb/test.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package pb; + +service Test { + rpc Test (TestRequest) returns (TestResponse) {} +} + +message TestRequest { + string a = 1; + int64 b = 2; +} + +message TestResponse { + string v = 1; +} diff --git a/transport/grpc/client.go b/transport/grpc/client.go index 437d931c1..622ca883d 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -107,7 +107,7 @@ func (c Client) Endpoint() endpoint.Endpoint { } for _, f := range c.after { - ctx = f(ctx, &header, &trailer) + ctx = f(ctx, header, trailer) } response, err := c.dec(ctx, grpcReply) diff --git a/transport/grpc/client_test.go b/transport/grpc/client_test.go new file mode 100644 index 000000000..bbd2f5ee6 --- /dev/null +++ b/transport/grpc/client_test.go @@ -0,0 +1,59 @@ +package grpc_test + +import ( + "context" + "fmt" + "net" + "testing" + + "google.golang.org/grpc" + + test "github.com/go-kit/kit/transport/grpc/_grpc_test" + pb "github.com/go-kit/kit/transport/grpc/_pb" +) + +const ( + hostPort string = "localhost:8002" +) + +func TestGRPCClient(t *testing.T) { + var ( + server = grpc.NewServer() + service = test.NewService() + ) + + sc, err := net.Listen("tcp", hostPort) + if err != nil { + t.Fatalf("unable to listen: %+v", err) + } + defer server.GracefulStop() + + go func() { + pb.RegisterTestServer(server, test.NewBinding(service)) + _ = server.Serve(sc) + }() + + cc, err := grpc.Dial(hostPort, grpc.WithInsecure()) + if err != nil { + t.Fatalf("unable to Dial: %+v", err) + } + + client := test.NewClient(cc) + + var ( + a = "the answer to life the universe and everything" + b = int64(42) + cID = "request-1" + ctx = test.SetCorrelationID(context.Background(), cID) + ) + + responseCTX, v, err := client.Test(ctx, a, b) + + if want, have := fmt.Sprintf("%s = %d", a, b), v; want != have { + t.Fatalf("want %q, have %q", want, have) + } + + if want, have := cID, test.GetConsumedCorrelationID(responseCTX); want != have { + t.Fatalf("want %q, have %q", want, have) + } +} diff --git a/transport/grpc/request_response_funcs.go b/transport/grpc/request_response_funcs.go index 067ef3f4b..7192bb507 100644 --- a/transport/grpc/request_response_funcs.go +++ b/transport/grpc/request_response_funcs.go @@ -27,7 +27,7 @@ type ServerResponseFunc func(context.Context, *metadata.MD) // trailer and make the responses available for consumption. ClientResponseFuncs // are only executed in clients, after a request has been made, but prior to it // being decoded. -type ClientResponseFunc func(ctx context.Context, header *metadata.MD, trailer *metadata.MD) context.Context +type ClientResponseFunc func(ctx context.Context, header metadata.MD, trailer metadata.MD) context.Context // SetResponseHeader returns a ResponseFunc that sets the specified metadata // key-value pair. diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 9f6a94a12..9289c4f29 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -1,8 +1,6 @@ package grpc import ( - "context" - oldcontext "golang.org/x/net/context" "google.golang.org/grpc/metadata" @@ -19,7 +17,6 @@ type Handler interface { // Server wraps an endpoint and implements grpc.Handler. type Server struct { - ctx context.Context e endpoint.Endpoint dec DecodeRequestFunc enc EncodeResponseFunc @@ -34,14 +31,12 @@ type Server struct { // definitions to individual handlers. Request and response objects are from the // caller business domain, not gRPC request and reply types. func NewServer( - ctx context.Context, e endpoint.Endpoint, dec DecodeRequestFunc, enc EncodeResponseFunc, options ...ServerOption, ) *Server { s := &Server{ - ctx: ctx, e: e, dec: dec, enc: enc, @@ -75,11 +70,9 @@ func ServerErrorLogger(logger log.Logger) ServerOption { } // ServeGRPC implements the Handler interface. -func (s Server) ServeGRPC(grpcCtx oldcontext.Context, req interface{}) (oldcontext.Context, interface{}, error) { - ctx := s.ctx - +func (s Server) ServeGRPC(ctx oldcontext.Context, req interface{}) (oldcontext.Context, interface{}, error) { // Retrieve gRPC metadata. - md, ok := metadata.FromContext(grpcCtx) + md, ok := metadata.FromContext(ctx) if !ok { md = metadata.MD{} } @@ -89,18 +82,18 @@ func (s Server) ServeGRPC(grpcCtx oldcontext.Context, req interface{}) (oldconte } // Store potentially updated metadata in the gRPC context. - grpcCtx = metadata.NewContext(grpcCtx, md) + ctx = metadata.NewContext(ctx, md) - request, err := s.dec(grpcCtx, req) + request, err := s.dec(ctx, req) if err != nil { s.logger.Log("err", err) - return grpcCtx, nil, err + return ctx, nil, err } response, err := s.e(ctx, request) if err != nil { s.logger.Log("err", err) - return grpcCtx, nil, err + return ctx, nil, err } for _, f := range s.after { @@ -108,13 +101,13 @@ func (s Server) ServeGRPC(grpcCtx oldcontext.Context, req interface{}) (oldconte } // Store potentially updated metadata in the gRPC context. - grpcCtx = metadata.NewContext(grpcCtx, md) + ctx = metadata.NewContext(ctx, md) - grpcResp, err := s.enc(grpcCtx, response) + grpcResp, err := s.enc(ctx, response) if err != nil { s.logger.Log("err", err) - return grpcCtx, nil, err + return ctx, nil, err } - return grpcCtx, grpcResp, nil + return ctx, grpcResp, nil } From 01974e5478ff043c0d708f8c89526428161a754e Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Sun, 5 Mar 2017 10:06:46 +0100 Subject: [PATCH 3/7] panic on codec error is better as it's a coding error --- transport/grpc/_grpc_test/request_response.go | 21 ++++--------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/transport/grpc/_grpc_test/request_response.go b/transport/grpc/_grpc_test/request_response.go index 441bc6501..13c2cb1e4 100644 --- a/transport/grpc/_grpc_test/request_response.go +++ b/transport/grpc/_grpc_test/request_response.go @@ -2,39 +2,26 @@ package test import ( "context" - "errors" pb "github.com/go-kit/kit/transport/grpc/_pb" ) func encodeRequest(ctx context.Context, req interface{}) (interface{}, error) { - r, ok := req.(TestRequest) - if !ok { - return nil, errors.New("request encode error") - } + r := req.(TestRequest) return &pb.TestRequest{A: r.A, B: r.B}, nil } func decodeRequest(ctx context.Context, req interface{}) (interface{}, error) { - r, ok := req.(*pb.TestRequest) - if !ok { - return nil, errors.New("request decode error") - } + r := req.(*pb.TestRequest) return TestRequest{A: r.A, B: r.B}, nil } func encodeResponse(ctx context.Context, resp interface{}) (interface{}, error) { - r, ok := resp.(*TestResponse) - if !ok { - return nil, errors.New("response encode error") - } + r := resp.(*TestResponse) return &pb.TestResponse{V: r.V}, nil } func decodeResponse(ctx context.Context, resp interface{}) (interface{}, error) { - r, ok := resp.(*pb.TestResponse) - if !ok { - return nil, errors.New("response decode error") - } + r := resp.(*pb.TestResponse) return &TestResponse{V: r.V, Ctx: ctx}, nil } From 550ed03f7ca085c60151d9e527001195abcb55b5 Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Sun, 5 Mar 2017 10:59:58 +0100 Subject: [PATCH 4/7] refactored ServerResponseFunc to be more symmetrical with ClientResponseFunc --- transport/grpc/_grpc_test/client.go | 15 ++- transport/grpc/_grpc_test/context_metadata.go | 104 +++++++++++------- transport/grpc/_grpc_test/server.go | 17 ++- transport/grpc/request_response_funcs.go | 26 +++-- transport/grpc/server.go | 21 +++- 5 files changed, 130 insertions(+), 53 deletions(-) diff --git a/transport/grpc/_grpc_test/client.go b/transport/grpc/_grpc_test/client.go index 11d78ca7f..a70ebc295 100644 --- a/transport/grpc/_grpc_test/client.go +++ b/transport/grpc/_grpc_test/client.go @@ -32,8 +32,19 @@ func NewClient(cc *grpc.ClientConn) Service { encodeRequest, decodeResponse, &pb.TestResponse{}, - grpctransport.ClientBefore(clientBefore), - grpctransport.ClientAfter(clientAfter), + grpctransport.ClientBefore( + injectCorrelationID, + ), + grpctransport.ClientBefore( + displayClientRequestHeaders, + ), + grpctransport.ClientAfter( + displayClientResponseHeaders, + displayClientResponseTrailers, + ), + grpctransport.ClientAfter( + extractConsumedCorrelationID, + ), ).Endpoint(), } } diff --git a/transport/grpc/_grpc_test/context_metadata.go b/transport/grpc/_grpc_test/context_metadata.go index f31b50bdf..5bd545dfe 100644 --- a/transport/grpc/_grpc_test/context_metadata.go +++ b/transport/grpc/_grpc_test/context_metadata.go @@ -3,24 +3,30 @@ package test import ( "context" "fmt" - "log" - "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) type metaContext string const ( - correlationID metaContext = "correlation-id" - responseHDR metaContext = "my-response-header" - responseTRLR metaContext = "correlation-id-consumed" + correlationID metaContext = "correlation-id" + responseHDR metaContext = "my-response-header" + responseTRLR metaContext = "my-response-trailer" + correlationIDTRLR metaContext = "correlation-id-consumed" ) -func clientBefore(ctx context.Context, md *metadata.MD) context.Context { +/* client before functions */ + +func injectCorrelationID(ctx context.Context, md *metadata.MD) context.Context { if hdr, ok := ctx.Value(correlationID).(string); ok { + fmt.Printf("\tClient found correlationID %q in context, set metadata header\n", hdr) (*md)[string(correlationID)] = append((*md)[string(correlationID)], hdr) } + return ctx +} + +func displayClientRequestHeaders(ctx context.Context, md *metadata.MD) context.Context { if len(*md) > 0 { fmt.Println("\tClient >> Request Headers:") for key, val := range *md { @@ -30,76 +36,100 @@ func clientBefore(ctx context.Context, md *metadata.MD) context.Context { return ctx } -func serverBefore(ctx context.Context, md *metadata.MD) context.Context { +/* server before functions */ + +func extractCorrelationID(ctx context.Context, md *metadata.MD) context.Context { + if hdr, ok := (*md)[string(correlationID)]; ok { + cID := hdr[len(hdr)-1] + ctx = context.WithValue(ctx, correlationID, cID) + fmt.Printf("\tServer received correlationID %q in metadata header, set context\n", cID) + } + return ctx +} + +func displayServerRequestHeaders(ctx context.Context, md *metadata.MD) context.Context { if len(*md) > 0 { fmt.Println("\tServer << Request Headers:") for key, val := range *md { fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) } } - if hdr, ok := (*md)[string(correlationID)]; ok { - cID := hdr[len(hdr)-1] - ctx = context.WithValue(ctx, correlationID, cID) - fmt.Printf("\tServer placed correlationID %q in context\n", cID) - } return ctx } -func serverAfter(ctx context.Context, _ *metadata.MD) { - var mdHeader, mdTrailer metadata.MD +/* server after functions */ - mdHeader = metadata.Pairs(string(responseHDR), "has-a-value") - if err := grpc.SendHeader(ctx, mdHeader); err != nil { - log.Fatalf("unable to send header: %+v\n", err) - } +func injectResponseHeader(ctx context.Context, md *metadata.MD, _ *metadata.MD) { + *md = metadata.Join(*md, metadata.Pairs(string(responseHDR), "has-a-value")) +} - if hdr, ok := ctx.Value(correlationID).(string); ok { - mdTrailer = metadata.Pairs(string(responseTRLR), hdr) - if err := grpc.SetTrailer(ctx, mdTrailer); err != nil { - log.Fatalf("unable to set trailer: %+v\n", err) - } - fmt.Printf("\tServer found correlationID %q in context, set consumed trailer\n", hdr) - } - if len(mdHeader) > 0 { +func displayServerResponseHeaders(ctx context.Context, md *metadata.MD, _ *metadata.MD) { + if len(*md) > 0 { fmt.Println("\tServer >> Response Headers:") - for key, val := range mdHeader { + for key, val := range *md { fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) } } - if len(mdTrailer) > 0 { +} + +func injectResponseTrailer(ctx context.Context, _ *metadata.MD, md *metadata.MD) { + *md = metadata.Join(*md, metadata.Pairs(string(responseTRLR), "has-a-value-too")) +} + +func injectConsumedCorrelationID(ctx context.Context, _ *metadata.MD, md *metadata.MD) { + if hdr, ok := ctx.Value(correlationID).(string); ok { + fmt.Printf("\tServer found correlationID %q in context, set consumed trailer\n", hdr) + *md = metadata.Join(*md, metadata.Pairs(string(correlationIDTRLR), hdr)) + } +} + +func displayServerResponseTrailers(ctx context.Context, _ *metadata.MD, md *metadata.MD) { + if len(*md) > 0 { fmt.Println("\tServer >> Response Trailers:") - for key, val := range mdTrailer { + for key, val := range *md { fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) } } } -func clientAfter(ctx context.Context, mdHeader metadata.MD, mdTrailer metadata.MD) context.Context { - if len(mdHeader) > 0 { +/* client after functions */ + +func displayClientResponseHeaders(ctx context.Context, md metadata.MD, _ metadata.MD) context.Context { + if len(md) > 0 { fmt.Println("\tClient << Response Headers:") - for key, val := range mdHeader { + for key, val := range md { fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) } } - if len(mdTrailer) > 0 { + return ctx +} + +func displayClientResponseTrailers(ctx context.Context, _ metadata.MD, md metadata.MD) context.Context { + if len(md) > 0 { fmt.Println("\tClient << Response Trailers:") - for key, val := range mdTrailer { + for key, val := range md { fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) } } + return ctx +} - if hdr, ok := mdTrailer[string(responseTRLR)]; ok { - ctx = context.WithValue(ctx, responseTRLR, hdr[len(hdr)-1]) +func extractConsumedCorrelationID(ctx context.Context, _ metadata.MD, md metadata.MD) context.Context { + if hdr, ok := md[string(correlationIDTRLR)]; ok { + fmt.Printf("\tClient received consumed correlationID %q in metadata trailer, set context\n", hdr[len(hdr)-1]) + ctx = context.WithValue(ctx, correlationIDTRLR, hdr[len(hdr)-1]) } return ctx } +/* CorrelationID context handlers */ + func SetCorrelationID(ctx context.Context, v string) context.Context { return context.WithValue(ctx, correlationID, v) } func GetConsumedCorrelationID(ctx context.Context) string { - if trlr, ok := ctx.Value(responseTRLR).(string); ok { + if trlr, ok := ctx.Value(correlationIDTRLR).(string); ok { return trlr } return "" diff --git a/transport/grpc/_grpc_test/server.go b/transport/grpc/_grpc_test/server.go index 6c55b119a..52e904870 100644 --- a/transport/grpc/_grpc_test/server.go +++ b/transport/grpc/_grpc_test/server.go @@ -50,8 +50,21 @@ func NewBinding(svc Service) *serverBinding { makeTestEndpoint(svc), decodeRequest, encodeResponse, - grpctransport.ServerBefore(serverBefore), - grpctransport.ServerAfter(serverAfter), + grpctransport.ServerBefore( + extractCorrelationID, + ), + grpctransport.ServerBefore( + displayServerRequestHeaders, + ), + grpctransport.ServerAfter( + injectResponseHeader, + injectResponseTrailer, + injectConsumedCorrelationID, + ), + grpctransport.ServerAfter( + displayServerResponseHeaders, + displayServerResponseTrailers, + ), ), } } diff --git a/transport/grpc/request_response_funcs.go b/transport/grpc/request_response_funcs.go index 7192bb507..05a9d34a2 100644 --- a/transport/grpc/request_response_funcs.go +++ b/transport/grpc/request_response_funcs.go @@ -19,9 +19,10 @@ const ( type RequestFunc func(context.Context, *metadata.MD) context.Context // ServerResponseFunc may take information from a request context and use it to -// manipulate the gRPC metadata header. ResponseFuncs are only executed in -// servers, after invoking the endpoint but prior to writing a response. -type ServerResponseFunc func(context.Context, *metadata.MD) +// manipulate the gRPC response metadata headers and trailers. ResponseFuncs are +// only executed in servers, after invoking the endpoint but prior to writing a +// response. +type ServerResponseFunc func(ctx context.Context, header *metadata.MD, trailer *metadata.MD) // ClientResponseFunc may take information from a gRPC metadata header and/or // trailer and make the responses available for consumption. ClientResponseFuncs @@ -29,22 +30,31 @@ type ServerResponseFunc func(context.Context, *metadata.MD) // being decoded. type ClientResponseFunc func(ctx context.Context, header metadata.MD, trailer metadata.MD) context.Context +// SetRequestHeader returns a RequestFunc that sets the specified metadata +// key-value pair. +func SetRequestHeader(key, val string) RequestFunc { + return func(ctx context.Context, md *metadata.MD) context.Context { + key, val := EncodeKeyValue(key, val) + (*md)[key] = append((*md)[key], val) + return ctx + } +} + // SetResponseHeader returns a ResponseFunc that sets the specified metadata // key-value pair. func SetResponseHeader(key, val string) ServerResponseFunc { - return func(_ context.Context, md *metadata.MD) { + return func(_ context.Context, md *metadata.MD, _ *metadata.MD) { key, val := EncodeKeyValue(key, val) (*md)[key] = append((*md)[key], val) } } -// SetRequestHeader returns a RequestFunc that sets the specified metadata +// SetResponseTrailer returns a ResponseFunc that sets the specified metadata // key-value pair. -func SetRequestHeader(key, val string) RequestFunc { - return func(ctx context.Context, md *metadata.MD) context.Context { +func SetResponseTrailer(key, val string) ServerResponseFunc { + return func(_ context.Context, _ *metadata.MD, md *metadata.MD) { key, val := EncodeKeyValue(key, val) (*md)[key] = append((*md)[key], val) - return ctx } } diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 9289c4f29..476902eb9 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -2,6 +2,7 @@ package grpc import ( oldcontext "golang.org/x/net/context" + "google.golang.org/grpc" "google.golang.org/grpc/metadata" "github.com/go-kit/kit/endpoint" @@ -96,18 +97,30 @@ func (s Server) ServeGRPC(ctx oldcontext.Context, req interface{}) (oldcontext.C return ctx, nil, err } + var mdHeader, mdTrailer metadata.MD for _, f := range s.after { - f(ctx, &md) + f(ctx, &mdHeader, &mdTrailer) } - // Store potentially updated metadata in the gRPC context. - ctx = metadata.NewContext(ctx, md) - grpcResp, err := s.enc(ctx, response) if err != nil { s.logger.Log("err", err) return ctx, nil, err } + if len(mdHeader) > 0 { + if err = grpc.SendHeader(ctx, mdHeader); err != nil { + s.logger.Log("err", err) + return ctx, nil, err + } + } + + if len(mdTrailer) > 0 { + if err = grpc.SetTrailer(ctx, mdTrailer); err != nil { + s.logger.Log("err", err) + return ctx, nil, err + } + } + return ctx, grpcResp, nil } From 78ef6c860b12b8bf1e8f28f2cecf8d34eb63ac7b Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Sun, 5 Mar 2017 11:32:37 +0100 Subject: [PATCH 5/7] moved protobuf package for test inside the grpc test folder --- transport/grpc/_grpc_test/client.go | 2 +- transport/grpc/{_pb => _grpc_test/pb}/generate.go | 0 transport/grpc/{_pb => _grpc_test/pb}/test.pb.go | 0 transport/grpc/{_pb => _grpc_test/pb}/test.proto | 0 transport/grpc/_grpc_test/request_response.go | 2 +- transport/grpc/_grpc_test/server.go | 2 +- transport/grpc/client_test.go | 2 +- 7 files changed, 4 insertions(+), 4 deletions(-) rename transport/grpc/{_pb => _grpc_test/pb}/generate.go (100%) rename transport/grpc/{_pb => _grpc_test/pb}/test.pb.go (100%) rename transport/grpc/{_pb => _grpc_test/pb}/test.proto (100%) diff --git a/transport/grpc/_grpc_test/client.go b/transport/grpc/_grpc_test/client.go index a70ebc295..1e0c8a78e 100644 --- a/transport/grpc/_grpc_test/client.go +++ b/transport/grpc/_grpc_test/client.go @@ -7,7 +7,7 @@ import ( "github.com/go-kit/kit/endpoint" grpctransport "github.com/go-kit/kit/transport/grpc" - pb "github.com/go-kit/kit/transport/grpc/_pb" + "github.com/go-kit/kit/transport/grpc/_grpc_test/pb" ) type clientBinding struct { diff --git a/transport/grpc/_pb/generate.go b/transport/grpc/_grpc_test/pb/generate.go similarity index 100% rename from transport/grpc/_pb/generate.go rename to transport/grpc/_grpc_test/pb/generate.go diff --git a/transport/grpc/_pb/test.pb.go b/transport/grpc/_grpc_test/pb/test.pb.go similarity index 100% rename from transport/grpc/_pb/test.pb.go rename to transport/grpc/_grpc_test/pb/test.pb.go diff --git a/transport/grpc/_pb/test.proto b/transport/grpc/_grpc_test/pb/test.proto similarity index 100% rename from transport/grpc/_pb/test.proto rename to transport/grpc/_grpc_test/pb/test.proto diff --git a/transport/grpc/_grpc_test/request_response.go b/transport/grpc/_grpc_test/request_response.go index 13c2cb1e4..269703d39 100644 --- a/transport/grpc/_grpc_test/request_response.go +++ b/transport/grpc/_grpc_test/request_response.go @@ -3,7 +3,7 @@ package test import ( "context" - pb "github.com/go-kit/kit/transport/grpc/_pb" + "github.com/go-kit/kit/transport/grpc/_grpc_test/pb" ) func encodeRequest(ctx context.Context, req interface{}) (interface{}, error) { diff --git a/transport/grpc/_grpc_test/server.go b/transport/grpc/_grpc_test/server.go index 52e904870..49e70a91f 100644 --- a/transport/grpc/_grpc_test/server.go +++ b/transport/grpc/_grpc_test/server.go @@ -8,7 +8,7 @@ import ( "github.com/go-kit/kit/endpoint" grpctransport "github.com/go-kit/kit/transport/grpc" - pb "github.com/go-kit/kit/transport/grpc/_pb" + "github.com/go-kit/kit/transport/grpc/_grpc_test/pb" ) type service struct{} diff --git a/transport/grpc/client_test.go b/transport/grpc/client_test.go index bbd2f5ee6..e4cac1d8c 100644 --- a/transport/grpc/client_test.go +++ b/transport/grpc/client_test.go @@ -9,7 +9,7 @@ import ( "google.golang.org/grpc" test "github.com/go-kit/kit/transport/grpc/_grpc_test" - pb "github.com/go-kit/kit/transport/grpc/_pb" + "github.com/go-kit/kit/transport/grpc/_grpc_test/pb" ) const ( From 4079d84025272c275eb015a8e5cd9ad44d526367 Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Sun, 5 Mar 2017 14:58:19 +0100 Subject: [PATCH 6/7] Improved gRPC Request / Response Funcs - Separated ClientRequestFunc and ServerRequestFunc to highlight that request metadata in a ServerRequestFunc is supposed to be immutable. - Return context in ServerResponseFuncs like with the HTTP transport, to allow passing data between chained serverAfter middlewares and finally to the encoding step so it can be used to alter the response payload prior to sending this response to the client. --- transport/grpc/_grpc_test/context_metadata.go | 25 ++++++++++------- transport/grpc/client.go | 6 ++--- transport/grpc/request_response_funcs.go | 27 ++++++++++++------- transport/grpc/server.go | 11 +++----- 4 files changed, 39 insertions(+), 30 deletions(-) diff --git a/transport/grpc/_grpc_test/context_metadata.go b/transport/grpc/_grpc_test/context_metadata.go index 5bd545dfe..0769325e2 100644 --- a/transport/grpc/_grpc_test/context_metadata.go +++ b/transport/grpc/_grpc_test/context_metadata.go @@ -38,8 +38,8 @@ func displayClientRequestHeaders(ctx context.Context, md *metadata.MD) context.C /* server before functions */ -func extractCorrelationID(ctx context.Context, md *metadata.MD) context.Context { - if hdr, ok := (*md)[string(correlationID)]; ok { +func extractCorrelationID(ctx context.Context, md metadata.MD) context.Context { + if hdr, ok := md[string(correlationID)]; ok { cID := hdr[len(hdr)-1] ctx = context.WithValue(ctx, correlationID, cID) fmt.Printf("\tServer received correlationID %q in metadata header, set context\n", cID) @@ -47,10 +47,10 @@ func extractCorrelationID(ctx context.Context, md *metadata.MD) context.Context return ctx } -func displayServerRequestHeaders(ctx context.Context, md *metadata.MD) context.Context { - if len(*md) > 0 { +func displayServerRequestHeaders(ctx context.Context, md metadata.MD) context.Context { + if len(md) > 0 { fmt.Println("\tServer << Request Headers:") - for key, val := range *md { + for key, val := range md { fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) } } @@ -59,37 +59,42 @@ func displayServerRequestHeaders(ctx context.Context, md *metadata.MD) context.C /* server after functions */ -func injectResponseHeader(ctx context.Context, md *metadata.MD, _ *metadata.MD) { +func injectResponseHeader(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context { *md = metadata.Join(*md, metadata.Pairs(string(responseHDR), "has-a-value")) + return ctx } -func displayServerResponseHeaders(ctx context.Context, md *metadata.MD, _ *metadata.MD) { +func displayServerResponseHeaders(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context { if len(*md) > 0 { fmt.Println("\tServer >> Response Headers:") for key, val := range *md { fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) } } + return ctx } -func injectResponseTrailer(ctx context.Context, _ *metadata.MD, md *metadata.MD) { +func injectResponseTrailer(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context { *md = metadata.Join(*md, metadata.Pairs(string(responseTRLR), "has-a-value-too")) + return ctx } -func injectConsumedCorrelationID(ctx context.Context, _ *metadata.MD, md *metadata.MD) { +func injectConsumedCorrelationID(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context { if hdr, ok := ctx.Value(correlationID).(string); ok { fmt.Printf("\tServer found correlationID %q in context, set consumed trailer\n", hdr) *md = metadata.Join(*md, metadata.Pairs(string(correlationIDTRLR), hdr)) } + return ctx } -func displayServerResponseTrailers(ctx context.Context, _ *metadata.MD, md *metadata.MD) { +func displayServerResponseTrailers(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context { if len(*md) > 0 { fmt.Println("\tServer >> Response Trailers:") for key, val := range *md { fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1]) } } + return ctx } /* client after functions */ diff --git a/transport/grpc/client.go b/transport/grpc/client.go index 622ca883d..c0faa2b36 100644 --- a/transport/grpc/client.go +++ b/transport/grpc/client.go @@ -21,7 +21,7 @@ type Client struct { enc EncodeRequestFunc dec DecodeResponseFunc grpcReply reflect.Type - before []RequestFunc + before []ClientRequestFunc after []ClientResponseFunc } @@ -54,7 +54,7 @@ func NewClient( reflect.ValueOf(grpcReply), ).Interface(), ), - before: []RequestFunc{}, + before: []ClientRequestFunc{}, after: []ClientResponseFunc{}, } for _, option := range options { @@ -68,7 +68,7 @@ type ClientOption func(*Client) // ClientBefore sets the RequestFuncs that are applied to the outgoing gRPC // request before it's invoked. -func ClientBefore(before ...RequestFunc) ClientOption { +func ClientBefore(before ...ClientRequestFunc) ClientOption { return func(c *Client) { c.before = append(c.before, before...) } } diff --git a/transport/grpc/request_response_funcs.go b/transport/grpc/request_response_funcs.go index 05a9d34a2..8d072ede7 100644 --- a/transport/grpc/request_response_funcs.go +++ b/transport/grpc/request_response_funcs.go @@ -12,17 +12,22 @@ const ( binHdrSuffix = "-bin" ) -// RequestFunc may take information from a gRPC request and put it into a -// request context. In Servers, RequestFuncs are executed prior to invoking the -// endpoint. In Clients, RequestFuncs are executed after creating the request -// but prior to invoking the gRPC client. -type RequestFunc func(context.Context, *metadata.MD) context.Context +// ClientRequestFunc may take information from context and use it to construct +// metadata headers to be transported to the server. ClientRequestFuncs are +// executed after creating the request but prior to sending the gRPC request to +// the server. +type ClientRequestFunc func(context.Context, *metadata.MD) context.Context + +// ServerRequestFunc may take information from the received metadata header and +// use it to place items in the request scoped context. ServerRequestFuncs are +// executed prior to invoking the endpoint. +type ServerRequestFunc func(context.Context, metadata.MD) context.Context // ServerResponseFunc may take information from a request context and use it to // manipulate the gRPC response metadata headers and trailers. ResponseFuncs are // only executed in servers, after invoking the endpoint but prior to writing a // response. -type ServerResponseFunc func(ctx context.Context, header *metadata.MD, trailer *metadata.MD) +type ServerResponseFunc func(ctx context.Context, header *metadata.MD, trailer *metadata.MD) context.Context // ClientResponseFunc may take information from a gRPC metadata header and/or // trailer and make the responses available for consumption. ClientResponseFuncs @@ -30,9 +35,9 @@ type ServerResponseFunc func(ctx context.Context, header *metadata.MD, trailer * // being decoded. type ClientResponseFunc func(ctx context.Context, header metadata.MD, trailer metadata.MD) context.Context -// SetRequestHeader returns a RequestFunc that sets the specified metadata +// SetRequestHeader returns a ClientRequestFunc that sets the specified metadata // key-value pair. -func SetRequestHeader(key, val string) RequestFunc { +func SetRequestHeader(key, val string) ClientRequestFunc { return func(ctx context.Context, md *metadata.MD) context.Context { key, val := EncodeKeyValue(key, val) (*md)[key] = append((*md)[key], val) @@ -43,18 +48,20 @@ func SetRequestHeader(key, val string) RequestFunc { // SetResponseHeader returns a ResponseFunc that sets the specified metadata // key-value pair. func SetResponseHeader(key, val string) ServerResponseFunc { - return func(_ context.Context, md *metadata.MD, _ *metadata.MD) { + return func(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context { key, val := EncodeKeyValue(key, val) (*md)[key] = append((*md)[key], val) + return ctx } } // SetResponseTrailer returns a ResponseFunc that sets the specified metadata // key-value pair. func SetResponseTrailer(key, val string) ServerResponseFunc { - return func(_ context.Context, _ *metadata.MD, md *metadata.MD) { + return func(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context { key, val := EncodeKeyValue(key, val) (*md)[key] = append((*md)[key], val) + return ctx } } diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 476902eb9..b14d7d8db 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -21,7 +21,7 @@ type Server struct { e endpoint.Endpoint dec DecodeRequestFunc enc EncodeResponseFunc - before []RequestFunc + before []ServerRequestFunc after []ServerResponseFunc logger log.Logger } @@ -54,7 +54,7 @@ type ServerOption func(*Server) // ServerBefore functions are executed on the HTTP request object before the // request is decoded. -func ServerBefore(before ...RequestFunc) ServerOption { +func ServerBefore(before ...ServerRequestFunc) ServerOption { return func(s *Server) { s.before = append(s.before, before...) } } @@ -79,12 +79,9 @@ func (s Server) ServeGRPC(ctx oldcontext.Context, req interface{}) (oldcontext.C } for _, f := range s.before { - ctx = f(ctx, &md) + ctx = f(ctx, md) } - // Store potentially updated metadata in the gRPC context. - ctx = metadata.NewContext(ctx, md) - request, err := s.dec(ctx, req) if err != nil { s.logger.Log("err", err) @@ -99,7 +96,7 @@ func (s Server) ServeGRPC(ctx oldcontext.Context, req interface{}) (oldcontext.C var mdHeader, mdTrailer metadata.MD for _, f := range s.after { - f(ctx, &mdHeader, &mdTrailer) + ctx = f(ctx, &mdHeader, &mdTrailer) } grpcResp, err := s.enc(ctx, response) From 8864ce8767b3268392f405f599e9dca8fd9a8b96 Mon Sep 17 00:00:00 2001 From: Bas van Beek Date: Sun, 5 Mar 2017 15:10:56 +0100 Subject: [PATCH 7/7] updates to reflect change in ServerRequestFunc --- auth/jwt/transport.go | 8 ++++---- auth/jwt/transport_test.go | 6 +++--- tracing/opentracing/grpc.go | 6 +++--- tracing/opentracing/grpc_test.go | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/auth/jwt/transport.go b/auth/jwt/transport.go index 6d02f7dc5..7be7db417 100644 --- a/auth/jwt/transport.go +++ b/auth/jwt/transport.go @@ -44,10 +44,10 @@ func FromHTTPContext() http.RequestFunc { // ToGRPCContext moves JWT token from grpc metadata to context. Particularly // userful for servers. -func ToGRPCContext() grpc.RequestFunc { - return func(ctx context.Context, md *metadata.MD) context.Context { +func ToGRPCContext() grpc.ServerRequestFunc { + return func(ctx context.Context, md metadata.MD) context.Context { // capital "Key" is illegal in HTTP/2. - authHeader, ok := (*md)["authorization"] + authHeader, ok := md["authorization"] if !ok { return ctx } @@ -63,7 +63,7 @@ func ToGRPCContext() grpc.RequestFunc { // FromGRPCContext moves JWT token from context to grpc metadata. Particularly // useful for clients. -func FromGRPCContext() grpc.RequestFunc { +func FromGRPCContext() grpc.ClientRequestFunc { return func(ctx context.Context, md *metadata.MD) context.Context { token, ok := ctx.Value(JWTTokenContextKey).(string) if ok { diff --git a/auth/jwt/transport_test.go b/auth/jwt/transport_test.go index 8b8922a6a..b04d76feb 100644 --- a/auth/jwt/transport_test.go +++ b/auth/jwt/transport_test.go @@ -69,7 +69,7 @@ func TestToGRPCContext(t *testing.T) { reqFunc := ToGRPCContext() // No Authorization header is passed - ctx := reqFunc(context.Background(), &md) + ctx := reqFunc(context.Background(), md) token := ctx.Value(JWTTokenContextKey) if token != nil { t.Error("Context should not contain a JWT Token") @@ -77,7 +77,7 @@ func TestToGRPCContext(t *testing.T) { // Invalid Authorization header is passed md["authorization"] = []string{fmt.Sprintf("%s", signedKey)} - ctx = reqFunc(context.Background(), &md) + ctx = reqFunc(context.Background(), md) token = ctx.Value(JWTTokenContextKey) if token != nil { t.Error("Context should not contain a JWT Token") @@ -85,7 +85,7 @@ func TestToGRPCContext(t *testing.T) { // Authorization header is correct md["authorization"] = []string{fmt.Sprintf("Bearer %s", signedKey)} - ctx = reqFunc(context.Background(), &md) + ctx = reqFunc(context.Background(), md) token, ok := ctx.Value(JWTTokenContextKey).(string) if !ok { t.Fatal("JWT Token not passed to context correctly") diff --git a/tracing/opentracing/grpc.go b/tracing/opentracing/grpc.go index 56eb143f5..fa4544009 100644 --- a/tracing/opentracing/grpc.go +++ b/tracing/opentracing/grpc.go @@ -32,10 +32,10 @@ func ToGRPCRequest(tracer opentracing.Tracer, logger log.Logger) func(ctx contex // `operationName` accordingly. If no trace could be found in `req`, the Span // will be a trace root. The Span is incorporated in the returned Context and // can be retrieved with opentracing.SpanFromContext(ctx). -func FromGRPCRequest(tracer opentracing.Tracer, operationName string, logger log.Logger) func(ctx context.Context, md *metadata.MD) context.Context { - return func(ctx context.Context, md *metadata.MD) context.Context { +func FromGRPCRequest(tracer opentracing.Tracer, operationName string, logger log.Logger) func(ctx context.Context, md metadata.MD) context.Context { + return func(ctx context.Context, md metadata.MD) context.Context { var span opentracing.Span - wireContext, err := tracer.Extract(opentracing.TextMap, metadataReaderWriter{md}) + wireContext, err := tracer.Extract(opentracing.TextMap, metadataReaderWriter{&md}) if err != nil && err != opentracing.ErrSpanContextNotFound { logger.Log("err", err) } diff --git a/tracing/opentracing/grpc_test.go b/tracing/opentracing/grpc_test.go index 96a834fd8..3d07a14aa 100644 --- a/tracing/opentracing/grpc_test.go +++ b/tracing/opentracing/grpc_test.go @@ -41,7 +41,7 @@ func TestTraceGRPCRequestRoundtrip(t *testing.T) { // Use FromGRPCRequest to verify that we can join with the trace given MD. fromGRPCFunc := kitot.FromGRPCRequest(tracer, "joined", logger) - joinCtx := fromGRPCFunc(afterCtx, &md) + joinCtx := fromGRPCFunc(afterCtx, md) joinedSpan := opentracing.SpanFromContext(joinCtx).(*mocktracer.MockSpan) joinedContext := joinedSpan.Context().(mocktracer.MockSpanContext)