Skip to content

Commit

Permalink
transport: fix handling of header metadata in serverHandler (#3484)
Browse files Browse the repository at this point in the history
  • Loading branch information
misberner authored Apr 3, 2020
1 parent aedb136 commit 66e9dfe
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 43 deletions.
68 changes: 45 additions & 23 deletions internal/transport/handler_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,10 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats sta
// at this point to be speaking over HTTP/2, so it's able to speak valid
// gRPC.
type serverHandlerTransport struct {
rw http.ResponseWriter
req *http.Request
timeoutSet bool
timeout time.Duration
didCommonHeaders bool
rw http.ResponseWriter
req *http.Request
timeoutSet bool
timeout time.Duration

headerMD metadata.MD

Expand Down Expand Up @@ -186,8 +185,11 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
ht.writeStatusMu.Lock()
defer ht.writeStatusMu.Unlock()

headersWritten := s.updateHeaderSent()
err := ht.do(func() {
ht.writeCommonHeaders(s)
if !headersWritten {
ht.writePendingHeaders(s)
}

// And flush, in case no header or body has been sent yet.
// This forces a separation of headers and trailers if this is the
Expand Down Expand Up @@ -238,14 +240,16 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
return err
}

// writePendingHeaders sets common and custom headers on the first
// write call (Write, WriteHeader, or WriteStatus)
func (ht *serverHandlerTransport) writePendingHeaders(s *Stream) {
ht.writeCommonHeaders(s)
ht.writeCustomHeaders(s)
}

// writeCommonHeaders sets common headers on the first write
// call (Write, WriteHeader, or WriteStatus).
func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
if ht.didCommonHeaders {
return
}
ht.didCommonHeaders = true

h := ht.rw.Header()
h["Date"] = nil // suppress Date to make tests happy; TODO: restore
h.Set("Content-Type", ht.contentType)
Expand All @@ -264,29 +268,47 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
}
}

// writeCustomHeaders sets custom headers set on the stream via SetHeader
// on the first write call (Write, WriteHeader, or WriteStatus).
func (ht *serverHandlerTransport) writeCustomHeaders(s *Stream) {
h := ht.rw.Header()

s.hdrMu.Lock()
for k, vv := range s.header {
if isReservedHeader(k) {
continue
}
for _, v := range vv {
h.Add(k, encodeMetadataHeader(k, v))
}
}

s.hdrMu.Unlock()
}

func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
headersWritten := s.updateHeaderSent()
return ht.do(func() {
ht.writeCommonHeaders(s)
if !headersWritten {
ht.writePendingHeaders(s)
}
ht.rw.Write(hdr)
ht.rw.Write(data)
ht.rw.(http.Flusher).Flush()
})
}

func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
if err := s.SetHeader(md); err != nil {
return err
}

headersWritten := s.updateHeaderSent()
err := ht.do(func() {
ht.writeCommonHeaders(s)
h := ht.rw.Header()
for k, vv := range md {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
if isReservedHeader(k) {
continue
}
for _, v := range vv {
v = encodeMetadataHeader(k, v)
h.Add(k, v)
}
if !headersWritten {
ht.writePendingHeaders(s)
}

ht.rw.WriteHeader(200)
ht.rw.(http.Flusher).Flush()
})
Expand Down
102 changes: 82 additions & 20 deletions internal/transport/handler_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,32 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
if want := "/service/foo.bar"; s.method != want {
t.Errorf("stream method = %q; want %q", s.method, want)
}

err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value"))
if err != nil {
t.Error(err)
}
err = s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value"))
if err != nil {
t.Error(err)
}

md := metadata.Pairs("custom-header", "Another custom header value")
err = s.SendHeader(md)
delete(md, "custom-header")
if err != nil {
t.Error(err)
}

err = s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored"))
if err == nil {
t.Error("expected SetHeader call after SendHeader to fail")
}
err = s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well"))
if err == nil {
t.Error("expected second SendHeader call to fail")
}

st.bodyw.Close() // no body
st.ht.WriteStatus(s, status.New(codes.OK, ""))
}
Expand All @@ -277,14 +303,16 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
func(ctx context.Context, method string) context.Context { return ctx },
)
wantHeader := http.Header{
"Date": nil,
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
"Grpc-Status": {"0"},
"Date": {},
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
"Custom-Header": {"Custom header value", "Another custom header value"},
}
if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) {
t.Errorf("Header+Trailer Map: %#v; want %#v", st.rw.HeaderMap, wantHeader)
wantTrailer := http.Header{
"Grpc-Status": {"0"},
"Custom-Trailer": {"Custom trailer value"},
}
checkHeaderAndTrailer(t, st.rw, wantHeader, wantTrailer)
}

// Tests that codes.Unimplemented will close the body, per comment in handler_server.go.
Expand All @@ -308,16 +336,15 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string)
func(ctx context.Context, method string) context.Context { return ctx },
)
wantHeader := http.Header{
"Date": nil,
"Date": {},
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
}
wantTrailer := http.Header{
"Grpc-Status": {fmt.Sprint(uint32(statusCode))},
"Grpc-Message": {encodeGrpcMessage(msg)},
}

if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) {
t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader)
}
checkHeaderAndTrailer(t, st.rw, wantHeader, wantTrailer)
}

func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
Expand Down Expand Up @@ -360,15 +387,15 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
func(ctx context.Context, method string) context.Context { return ctx },
)
wantHeader := http.Header{
"Date": nil,
"Date": {},
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
}
wantTrailer := http.Header{
"Grpc-Status": {"4"},
"Grpc-Message": {encodeGrpcMessage("too slow")},
}
if !reflect.DeepEqual(rw.HeaderMap, wantHeader) {
t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader)
}
checkHeaderAndTrailer(t, rw, wantHeader, wantTrailer)
}

// TestHandlerTransport_HandleStreams_MultiWriteStatus ensures that
Expand Down Expand Up @@ -447,15 +474,50 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
func(ctx context.Context, method string) context.Context { return ctx },
)
wantHeader := http.Header{
"Date": nil,
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
"Date": {},
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
}
wantTrailer := http.Header{
"Grpc-Status": {fmt.Sprint(uint32(statusCode))},
"Grpc-Message": {encodeGrpcMessage(msg)},
"Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)},
}

if !reflect.DeepEqual(hst.rw.HeaderMap, wantHeader) {
t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", hst.rw.HeaderMap, wantHeader)
checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)
}

// checkHeaderAndTrailer checks that the resulting header and trailer matches the expectation.
func checkHeaderAndTrailer(t *testing.T, rw testHandlerResponseWriter, wantHeader, wantTrailer http.Header) {
// For trailer-only responses, the trailer values might be reported as part of the Header. They will however
// be present in Trailer in either case. Hence, normalize the header by removing all trailer values.
actualHeader := cloneHeader(rw.Result().Header)
for _, trailerKey := range actualHeader["Trailer"] {
actualHeader.Del(trailerKey)
}

if !reflect.DeepEqual(actualHeader, wantHeader) {
t.Errorf("Header mismatch.\n got: %#v\n want: %#v", actualHeader, wantHeader)
}
if actualTrailer := rw.Result().Trailer; !reflect.DeepEqual(actualTrailer, wantTrailer) {
t.Errorf("Trailer mismatch.\n got: %#v\n want: %#v", actualTrailer, wantTrailer)
}
}

// cloneHeader performs a deep clone of an http.Header, since the (http.Header).Clone() method was only added in
// Go 1.13.
func cloneHeader(hdr http.Header) http.Header {
if hdr == nil {
return nil
}

hdrClone := make(http.Header, len(hdr))

for k, vv := range hdr {
vvClone := make([]string, len(vv))
copy(vvClone, vv)
hdrClone[k] = vvClone
}

return hdrClone
}

0 comments on commit 66e9dfe

Please sign in to comment.