Skip to content

Commit

Permalink
Separate incoming and outgoing metadata in context
Browse files Browse the repository at this point in the history
This will prevent the incoming RPCs' metadata from appearing in outgoing RPCs
unless it is explicitly copied, e.g.:

incomingMD, ok := metadata.FromContext(ctx)
if ok {
  ctx = metadata.NewContext(ctx, incomingMD)
}

Fixes #1148
  • Loading branch information
dfawley committed Apr 7, 2017
1 parent 087f3d6 commit 0c1d39d
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 43 deletions.
13 changes: 8 additions & 5 deletions Documentation/grpc-metadata.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ md := metadata.Pairs(

## Retrieving metadata from context

Metadata can be retrieved from context using `FromContext`:
Metadata can be retrieved from context using `FromIncomingContext`:

```go
func (s *server) SomeRPC(ctx context.Context, in *pb.SomeRequest) (*pb.SomeResponse, err) {
md, ok := metadata.FromContext(ctx)
md, ok := metadata.FromIncomingContext(ctx)
// do something with metadata
}
```
Expand All @@ -88,14 +88,17 @@ To send metadata to server, the client can wrap the metadata into a context usin
md := metadata.Pairs("key", "val")

// create a new context with this metadata
ctx := metadata.NewContext(context.Background(), md)
ctx := metadata.NewOutgoingContext(context.Background(), md)

// make unary RPC
response, err := client.SomeRPC(ctx, someRequest)

// or make streaming RPC
stream, err := client.SomeStreamingRPC(ctx)
```

To read this back from the context on the client (e.g. in an interceptor) before the RPC is sent, use `FromOutgoingContext`.

### Receiving metadata

Metadata that a client can receive includes header and trailer.
Expand Down Expand Up @@ -152,7 +155,7 @@ For streaming calls, the server needs to get context from the stream.

```go
func (s *server) SomeRPC(ctx context.Context, in *pb.someRequest) (*pb.someResponse, error) {
md, ok := metadata.FromContext(ctx)
md, ok := metadata.FromIncomingContext(ctx)
// do something with metadata
}
```
Expand All @@ -161,7 +164,7 @@ func (s *server) SomeRPC(ctx context.Context, in *pb.someRequest) (*pb.someRespo

```go
func (s *server) SomeStreamingRPC(stream pb.Service_SomeStreamingRPCServer) error {
md, ok := metadata.FromContext(stream.Context()) // get context from stream
md, ok := metadata.FromIncomingContext(stream.Context()) // get context from stream
// do something with metadata
}
```
Expand Down
2 changes: 1 addition & 1 deletion grpclb/grpclb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ type helloServer struct {
}

func (s *helloServer) SayHello(ctx context.Context, in *hwpb.HelloRequest) (*hwpb.HelloReply, error) {
md, ok := metadata.FromContext(ctx)
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, grpc.Errorf(codes.Internal, "failed to receive metadata")
}
Expand Down
10 changes: 5 additions & 5 deletions interop/test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ func DoPerRPCCreds(tc testpb.TestServiceClient, serviceAccountKeyFile, oauthScop
}
token := GetToken(serviceAccountKeyFile, oauthScope)
kv := map[string]string{"authorization": token.TokenType + " " + token.AccessToken}
ctx := metadata.NewContext(context.Background(), metadata.MD{"authorization": []string{kv["authorization"]}})
ctx := metadata.NewOutgoingContext(context.Background(), metadata.MD{"authorization": []string{kv["authorization"]}})
reply, err := tc.UnaryCall(ctx, req)
if err != nil {
grpclog.Fatal("/TestService/UnaryCall RPC failed: ", err)
Expand All @@ -416,7 +416,7 @@ var (

// DoCancelAfterBegin cancels the RPC after metadata has been sent but before payloads are sent.
func DoCancelAfterBegin(tc testpb.TestServiceClient, args ...grpc.CallOption) {
ctx, cancel := context.WithCancel(metadata.NewContext(context.Background(), testMetadata))
ctx, cancel := context.WithCancel(metadata.NewOutgoingContext(context.Background(), testMetadata))
stream, err := tc.StreamingInputCall(ctx, args...)
if err != nil {
grpclog.Fatalf("%v.StreamingInputCall(_) = _, %v", tc, err)
Expand Down Expand Up @@ -491,7 +491,7 @@ func DoCustomMetadata(tc testpb.TestServiceClient, args ...grpc.CallOption) {
ResponseSize: proto.Int32(int32(1)),
Payload: pl,
}
ctx := metadata.NewContext(context.Background(), customMetadata)
ctx := metadata.NewOutgoingContext(context.Background(), customMetadata)
var header, trailer metadata.MD
args = append(args, grpc.Header(&header), grpc.Trailer(&trailer))
reply, err := tc.UnaryCall(
Expand Down Expand Up @@ -627,7 +627,7 @@ func serverNewPayload(t testpb.PayloadType, size int32) (*testpb.Payload, error)

func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
status := in.GetResponseStatus()
if md, ok := metadata.FromContext(ctx); ok {
if md, ok := metadata.FromIncomingContext(ctx); ok {
if initialMetadata, ok := md[initialMetadataKey]; ok {
header := metadata.Pairs(initialMetadataKey, initialMetadata[0])
grpc.SendHeader(ctx, header)
Expand Down Expand Up @@ -686,7 +686,7 @@ func (s *testServer) StreamingInputCall(stream testpb.TestService_StreamingInput
}

func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
if md, ok := metadata.FromContext(stream.Context()); ok {
if md, ok := metadata.FromIncomingContext(stream.Context()); ok {
if initialMetadata, ok := md[initialMetadataKey]; ok {
header := metadata.Pairs(initialMetadataKey, initialMetadata[0])
stream.SendHeader(header)
Expand Down
38 changes: 31 additions & 7 deletions metadata/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,41 @@ func Join(mds ...MD) MD {
return out
}

type mdKey struct{}
type mdIncomingKey struct{}
type mdOutgoingKey struct{}

// NewContext creates a new context with md attached.
// NewContext is a wrapper for NewOutgoingContext(ctx, md). Deprecated.
func NewContext(ctx context.Context, md MD) context.Context {
return context.WithValue(ctx, mdKey{}, md)
return NewOutgoingContext(ctx, md)
}

// FromContext returns the MD in ctx if it exists.
// The returned md should be immutable, writing to it may cause races.
// Modification should be made to the copies of the returned md.
// NewIncomingContext creates a new context with incoming md attached.
func NewIncomingContext(ctx context.Context, md MD) context.Context {
return context.WithValue(ctx, mdIncomingKey{}, md)
}

// NewOutgoingContext creates a new context with outgoing md attached.
func NewOutgoingContext(ctx context.Context, md MD) context.Context {
return context.WithValue(ctx, mdOutgoingKey{}, md)
}

// FromContext is a wrapper for FromIncomingContext(ctx). Deprecated.
func FromContext(ctx context.Context) (md MD, ok bool) {
md, ok = ctx.Value(mdKey{}).(MD)
return FromIncomingContext(ctx)
}

// FromIncomingContext returns the incoming MD in ctx if it exists. The
// returned md should be immutable, writing to it may cause races.
// Modification should be made to the copies of the returned md.
func FromIncomingContext(ctx context.Context) (md MD, ok bool) {
md, ok = ctx.Value(mdIncomingKey{}).(MD)
return
}

// FromOutgoingContext returns the outgoing MD in ctx if it exists. The
// returned md should be immutable, writing to it may cause races.
// Modification should be made to the copies of the returned md.
func FromOutgoingContext(ctx context.Context) (md MD, ok bool) {
md, ok = ctx.Value(mdOutgoingKey{}).(MD)
return
}
8 changes: 4 additions & 4 deletions stats/stats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ var (
type testServer struct{}

func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
md, ok := metadata.FromContext(ctx)
md, ok := metadata.FromIncomingContext(ctx)
if ok {
if err := grpc.SendHeader(ctx, md); err != nil {
return nil, grpc.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", md, err)
Expand All @@ -93,7 +93,7 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
}

func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
md, ok := metadata.FromContext(stream.Context())
md, ok := metadata.FromIncomingContext(stream.Context())
if ok {
if err := stream.SendHeader(md); err != nil {
return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
Expand Down Expand Up @@ -237,7 +237,7 @@ func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.Simple
} else {
req = &testpb.SimpleRequest{Id: errorID}
}
ctx := metadata.NewContext(context.Background(), testMetadata)
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)

resp, err = tc.UnaryCall(ctx, req, grpc.FailFast(c.failfast))
return req, resp, err
Expand All @@ -250,7 +250,7 @@ func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest
err error
)
tc := testpb.NewTestServiceClient(te.clientConn())
stream, err := tc.FullDuplexCall(metadata.NewContext(context.Background(), testMetadata), grpc.FailFast(c.failfast))
stream, err := tc.FullDuplexCall(metadata.NewOutgoingContext(context.Background(), testMetadata), grpc.FailFast(c.failfast))
if err != nil {
return reqs, resps, err
}
Expand Down
Loading

0 comments on commit 0c1d39d

Please sign in to comment.