Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for gRPC Server Response Headers and Trailers, adding ClientAfter handler #479

Merged
merged 7 commits into from
Mar 5, 2017
8 changes: 4 additions & 4 deletions auth/jwt/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ func FromHTTPContext() http.RequestFunc {

// ToGRPCContext moves JWT token from grpc metadata to context. Particularly
// userful for servers.
func ToGRPCContext() grpc.RequestFunc {
return func(ctx context.Context, md *metadata.MD) context.Context {
func ToGRPCContext() grpc.ServerRequestFunc {
return func(ctx context.Context, md metadata.MD) context.Context {
// capital "Key" is illegal in HTTP/2.
authHeader, ok := (*md)["authorization"]
authHeader, ok := md["authorization"]
if !ok {
return ctx
}
Expand All @@ -63,7 +63,7 @@ func ToGRPCContext() grpc.RequestFunc {

// FromGRPCContext moves JWT token from context to grpc metadata. Particularly
// useful for clients.
func FromGRPCContext() grpc.RequestFunc {
func FromGRPCContext() grpc.ClientRequestFunc {
return func(ctx context.Context, md *metadata.MD) context.Context {
token, ok := ctx.Value(JWTTokenContextKey).(string)
if ok {
Expand Down
6 changes: 3 additions & 3 deletions auth/jwt/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,23 @@ func TestToGRPCContext(t *testing.T) {
reqFunc := ToGRPCContext()

// No Authorization header is passed
ctx := reqFunc(context.Background(), &md)
ctx := reqFunc(context.Background(), md)
token := ctx.Value(JWTTokenContextKey)
if token != nil {
t.Error("Context should not contain a JWT Token")
}

// Invalid Authorization header is passed
md["authorization"] = []string{fmt.Sprintf("%s", signedKey)}
ctx = reqFunc(context.Background(), &md)
ctx = reqFunc(context.Background(), md)
token = ctx.Value(JWTTokenContextKey)
if token != nil {
t.Error("Context should not contain a JWT Token")
}

// Authorization header is correct
md["authorization"] = []string{fmt.Sprintf("Bearer %s", signedKey)}
ctx = reqFunc(context.Background(), &md)
ctx = reqFunc(context.Background(), md)
token, ok := ctx.Value(JWTTokenContextKey).(string)
if !ok {
t.Fatal("JWT Token not passed to context correctly")
Expand Down
2 changes: 1 addition & 1 deletion examples/addsvc/cmd/addsvc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ func main() {
return
}

srv := addsvc.MakeGRPCServer(ctx, endpoints, tracer, logger)
srv := addsvc.MakeGRPCServer(endpoints, tracer, logger)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

s := grpc.NewServer()
pb.RegisterAddServer(s, srv)

Expand Down
4 changes: 1 addition & 3 deletions examples/addsvc/transport_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,18 @@ import (
)

// MakeGRPCServer makes a set of endpoints available as a gRPC AddServer.
func MakeGRPCServer(ctx context.Context, endpoints Endpoints, tracer stdopentracing.Tracer, logger log.Logger) pb.AddServer {
func MakeGRPCServer(endpoints Endpoints, tracer stdopentracing.Tracer, logger log.Logger) pb.AddServer {
options := []grpctransport.ServerOption{
grpctransport.ServerErrorLogger(logger),
}
return &grpcServer{
sum: grpctransport.NewServer(
ctx,
endpoints.SumEndpoint,
DecodeGRPCSumRequest,
EncodeGRPCSumResponse,
append(options, grpctransport.ServerBefore(opentracing.FromGRPCRequest(tracer, "Sum", logger)))...,
),
concat: grpctransport.NewServer(
ctx,
endpoints.ConcatEndpoint,
DecodeGRPCConcatRequest,
EncodeGRPCConcatResponse,
Expand Down
6 changes: 3 additions & 3 deletions tracing/opentracing/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ func ToGRPCRequest(tracer opentracing.Tracer, logger log.Logger) func(ctx contex
// `operationName` accordingly. If no trace could be found in `req`, the Span
// will be a trace root. The Span is incorporated in the returned Context and
// can be retrieved with opentracing.SpanFromContext(ctx).
func FromGRPCRequest(tracer opentracing.Tracer, operationName string, logger log.Logger) func(ctx context.Context, md *metadata.MD) context.Context {
return func(ctx context.Context, md *metadata.MD) context.Context {
func FromGRPCRequest(tracer opentracing.Tracer, operationName string, logger log.Logger) func(ctx context.Context, md metadata.MD) context.Context {
return func(ctx context.Context, md metadata.MD) context.Context {
var span opentracing.Span
wireContext, err := tracer.Extract(opentracing.TextMap, metadataReaderWriter{md})
wireContext, err := tracer.Extract(opentracing.TextMap, metadataReaderWriter{&md})
if err != nil && err != opentracing.ErrSpanContextNotFound {
logger.Log("err", err)
}
Expand Down
2 changes: 1 addition & 1 deletion tracing/opentracing/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestTraceGRPCRequestRoundtrip(t *testing.T) {

// Use FromGRPCRequest to verify that we can join with the trace given MD.
fromGRPCFunc := kitot.FromGRPCRequest(tracer, "joined", logger)
joinCtx := fromGRPCFunc(afterCtx, &md)
joinCtx := fromGRPCFunc(afterCtx, md)
joinedSpan := opentracing.SpanFromContext(joinCtx).(*mocktracer.MockSpan)

joinedContext := joinedSpan.Context().(mocktracer.MockSpanContext)
Expand Down
50 changes: 50 additions & 0 deletions transport/grpc/_grpc_test/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package test

import (
"context"

"google.golang.org/grpc"

"github.com/go-kit/kit/endpoint"
grpctransport "github.com/go-kit/kit/transport/grpc"
"github.com/go-kit/kit/transport/grpc/_grpc_test/pb"
)

type clientBinding struct {
test endpoint.Endpoint
}

func (c *clientBinding) Test(ctx context.Context, a string, b int64) (context.Context, string, error) {
response, err := c.test(ctx, TestRequest{A: a, B: b})
if err != nil {
return nil, "", err
}
r := response.(*TestResponse)
return r.Ctx, r.V, nil
}

func NewClient(cc *grpc.ClientConn) Service {
return &clientBinding{
test: grpctransport.NewClient(
cc,
"pb.Test",
"Test",
encodeRequest,
decodeResponse,
&pb.TestResponse{},
grpctransport.ClientBefore(
injectCorrelationID,
),
grpctransport.ClientBefore(
displayClientRequestHeaders,
),
grpctransport.ClientAfter(
displayClientResponseHeaders,
displayClientResponseTrailers,
),
grpctransport.ClientAfter(
extractConsumedCorrelationID,
),
).Endpoint(),
}
}
141 changes: 141 additions & 0 deletions transport/grpc/_grpc_test/context_metadata.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package test

import (
"context"
"fmt"

"google.golang.org/grpc/metadata"
)

type metaContext string

const (
correlationID metaContext = "correlation-id"
responseHDR metaContext = "my-response-header"
responseTRLR metaContext = "my-response-trailer"
correlationIDTRLR metaContext = "correlation-id-consumed"
)

/* client before functions */

func injectCorrelationID(ctx context.Context, md *metadata.MD) context.Context {
if hdr, ok := ctx.Value(correlationID).(string); ok {
fmt.Printf("\tClient found correlationID %q in context, set metadata header\n", hdr)
(*md)[string(correlationID)] = append((*md)[string(correlationID)], hdr)
}
return ctx
}

func displayClientRequestHeaders(ctx context.Context, md *metadata.MD) context.Context {
if len(*md) > 0 {
fmt.Println("\tClient >> Request Headers:")
for key, val := range *md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}

/* server before functions */

func extractCorrelationID(ctx context.Context, md metadata.MD) context.Context {
if hdr, ok := md[string(correlationID)]; ok {
cID := hdr[len(hdr)-1]
ctx = context.WithValue(ctx, correlationID, cID)
fmt.Printf("\tServer received correlationID %q in metadata header, set context\n", cID)
}
return ctx
}

func displayServerRequestHeaders(ctx context.Context, md metadata.MD) context.Context {
if len(md) > 0 {
fmt.Println("\tServer << Request Headers:")
for key, val := range md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}

/* server after functions */

func injectResponseHeader(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context {
*md = metadata.Join(*md, metadata.Pairs(string(responseHDR), "has-a-value"))
return ctx
}

func displayServerResponseHeaders(ctx context.Context, md *metadata.MD, _ *metadata.MD) context.Context {
if len(*md) > 0 {
fmt.Println("\tServer >> Response Headers:")
for key, val := range *md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}

func injectResponseTrailer(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context {
*md = metadata.Join(*md, metadata.Pairs(string(responseTRLR), "has-a-value-too"))
return ctx
}

func injectConsumedCorrelationID(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context {
if hdr, ok := ctx.Value(correlationID).(string); ok {
fmt.Printf("\tServer found correlationID %q in context, set consumed trailer\n", hdr)
*md = metadata.Join(*md, metadata.Pairs(string(correlationIDTRLR), hdr))
}
return ctx
}

func displayServerResponseTrailers(ctx context.Context, _ *metadata.MD, md *metadata.MD) context.Context {
if len(*md) > 0 {
fmt.Println("\tServer >> Response Trailers:")
for key, val := range *md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}

/* client after functions */

func displayClientResponseHeaders(ctx context.Context, md metadata.MD, _ metadata.MD) context.Context {
if len(md) > 0 {
fmt.Println("\tClient << Response Headers:")
for key, val := range md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}

func displayClientResponseTrailers(ctx context.Context, _ metadata.MD, md metadata.MD) context.Context {
if len(md) > 0 {
fmt.Println("\tClient << Response Trailers:")
for key, val := range md {
fmt.Printf("\t\t%s: %s\n", key, val[len(val)-1])
}
}
return ctx
}

func extractConsumedCorrelationID(ctx context.Context, _ metadata.MD, md metadata.MD) context.Context {
if hdr, ok := md[string(correlationIDTRLR)]; ok {
fmt.Printf("\tClient received consumed correlationID %q in metadata trailer, set context\n", hdr[len(hdr)-1])
ctx = context.WithValue(ctx, correlationIDTRLR, hdr[len(hdr)-1])
}
return ctx
}

/* CorrelationID context handlers */

func SetCorrelationID(ctx context.Context, v string) context.Context {
return context.WithValue(ctx, correlationID, v)
}

func GetConsumedCorrelationID(ctx context.Context) string {
if trlr, ok := ctx.Value(correlationIDTRLR).(string); ok {
return trlr
}
return ""
}
3 changes: 3 additions & 0 deletions transport/grpc/_grpc_test/pb/generate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package pb

//go:generate protoc test.proto --go_out=plugins=grpc:.
Loading