Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reconcile trailers-only and misc error behavior with grpc-go #690

Merged
merged 3 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 45 additions & 95 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ func TestGRPCMissingTrailersError(t *testing.T) {
var connectErr *connect.Error
ok := errors.As(err, &connectErr)
assert.True(t, ok)
assert.Equal(t, connectErr.Code(), connect.CodeInternal)
assert.Equal(t, connectErr.Code(), connect.CodeUnknown)
assert.True(
t,
strings.HasSuffix(connectErr.Message(), "protocol error: no Grpc-Status trailer: unexpected EOF"),
Expand Down Expand Up @@ -1838,7 +1838,9 @@ func TestUnflushableResponseWriter(t *testing.T) {
t.Parallel()
assertIsFlusherErr := func(t *testing.T, err error) {
t.Helper()
assert.NotNil(t, err)
if !assert.NotNil(t, err) {
return
}
assert.Equal(t, connect.CodeOf(err), connect.CodeInternal, assert.Sprintf("got %v", err))
assert.True(
t,
Expand Down Expand Up @@ -1875,8 +1877,9 @@ func TestUnflushableResponseWriter(t *testing.T) {
assertIsFlusherErr(t, err)
return
}
assert.False(t, stream.Receive())
assertIsFlusherErr(t, stream.Err())
if assert.False(t, stream.Receive()) {
assertIsFlusherErr(t, stream.Err())
}
})
}
}
Expand Down Expand Up @@ -2146,6 +2149,21 @@ func TestStreamUnexpectedEOF(t *testing.T) {
},
expectCode: connect.CodeInternal,
expectMsg: "internal: protocol error: no Grpc-Status trailer: unexpected EOF",
}, {
name: "grpc_missing_status",
options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPC()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc+json")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload)
assert.Nil(t, err)
// Trailers exist, just no status. So error will be unknown instead of internal.
responseWriter.Header().Set(http.TrailerPrefix+"grpc-message", "foo")
},
expectCode: connect.CodeUnknown,
expectMsg: "unknown: protocol error: no Grpc-Status trailer: unexpected EOF",
}, {
name: "grpc-web_missing_end",
options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()},
Expand All @@ -2159,6 +2177,29 @@ func TestStreamUnexpectedEOF(t *testing.T) {
},
expectCode: connect.CodeInternal,
expectMsg: "internal: protocol error: no Grpc-Status trailer: unexpected EOF",
}, {
name: "grpc-web_missing_status",
options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc-web+json")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload)
assert.Nil(t, err)
// Trailers exist, just no status. So error will be unknown instead of internal.
_, err = responseWriter.Write([]byte{128}) // end-stream flag
assert.Nil(t, err)
endStream := "grpc-message: foo\r\n"
var length [4]byte
binary.BigEndian.PutUint32(length[:], uint32(len(endStream)))
_, err = responseWriter.Write(length[:])
assert.Nil(t, err)
_, err = responseWriter.Write([]byte(endStream))
assert.Nil(t, err)
},
expectCode: connect.CodeUnknown,
expectMsg: "unknown: protocol error: no Grpc-Status trailer: unexpected EOF",
}, {
name: "connect_partial_payload",
options: []connect.ClientOption{connect.WithProtoJSON()},
Expand Down Expand Up @@ -2442,97 +2483,6 @@ func TestClientDisconnect(t *testing.T) {
})
}

func TestTrailersOnlyErrors(t *testing.T) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests were for the old classification and no longer apply.

For example, they'd check that if we classify a response as trailers-only (due to presence of grpc-status key), then the client complains (in the form of an RPC error with "internal" code) if there is a non-empty body or non-empty HTTP trailers.

But since we now define trailers-only to be a request without body or HTTP trailers, they are moot. In these situations (where a response has a grpc-status key in headers and a non-empty body) we now just ignore the grpc-status key in headers (which is also what grpc-go does).

t.Parallel()

head := [3]byte{}
testcases := []struct {
name string
handler http.HandlerFunc
options []connect.ClientOption
expectCode connect.Code
expectMsg string
}{{
name: "grpc_body_after_trailers-only",
options: []connect.ClientOption{connect.WithGRPC()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc")
header.Set("Grpc-Status", "3")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
},
expectCode: connect.CodeInternal,
expectMsg: fmt.Sprintf("internal: corrupt response: %d extra bytes after trailers-only response", len(head)),
}, {
name: "grpc-web_body_after_trailers-only",
options: []connect.ClientOption{connect.WithGRPCWeb()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc-web")
header.Set("Grpc-Status", "3")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
},
expectCode: connect.CodeInternal,
expectMsg: fmt.Sprintf("internal: corrupt response: %d extra bytes after trailers-only response", len(head)),
}, {
name: "grpc_trailers_after_trailers-only",
options: []connect.ClientOption{connect.WithGRPC()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc")
header.Set("Grpc-Status", "3")
responseWriter.WriteHeader(http.StatusOK)
responseWriter.(http.Flusher).Flush() //nolint:forcetypeassert
header.Set(http.TrailerPrefix+"Foo", "abc")
},
expectCode: connect.CodeInternal,
expectMsg: "internal: corrupt response from server: gRPC trailers-only response may not contain HTTP trailers",
}, {
name: "grpc-web_trailers_after_trailers-only",
options: []connect.ClientOption{connect.WithGRPCWeb()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc-web")
header.Set("Grpc-Status", "3")
responseWriter.WriteHeader(http.StatusOK)
responseWriter.(http.Flusher).Flush() //nolint:forcetypeassert
header.Set(http.TrailerPrefix+"Foo", "abc")
},
expectCode: connect.CodeInternal,
expectMsg: "internal: corrupt response from server: gRPC trailers-only response may not contain HTTP trailers",
}}
for _, testcase := range testcases {
testcase := testcase
t.Run(testcase.name, func(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
mux.HandleFunc("/", func(responseWriter http.ResponseWriter, request *http.Request) {
_, _ = io.Copy(io.Discard, request.Body)
testcase.handler(responseWriter, request)
})
server := memhttptest.NewServer(t, mux)
client := pingv1connect.NewPingServiceClient(
server.Client(),
server.URL(),
testcase.options...,
)
const upTo = 2
request := connect.NewRequest(&pingv1.CountUpRequest{Number: upTo})
request.Header().Set("Test-Case", t.Name())
stream, err := client.CountUp(context.Background(), request)
assert.Nil(t, err)
for i := 0; stream.Receive() && i < upTo; i++ {
assert.Equal(t, stream.Msg().GetNumber(), 42)
}
assert.NotNil(t, stream.Err())
assert.Equal(t, connect.CodeOf(stream.Err()), testcase.expectCode)
assert.Equal(t, stream.Err().Error(), testcase.expectMsg)
})
}
}

// TestBlankImportCodeGeneration tests that services.connect.go is generated with
// blank import statements to services.pb.go so that the service's Descriptor is
// available in the global proto registry.
Expand Down
27 changes: 17 additions & 10 deletions envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ func (w *envelopeWriter) write(env *envelope) *Error {
type envelopeReader struct {
ctx context.Context //nolint:containedctx
reader io.Reader
bytesRead int64 // detect trailers-only gRPC responses
codec Codec
last envelope
compressionPool *compressionPool
Expand All @@ -241,6 +242,11 @@ func (r *envelopeReader) Unmarshal(message any) *Error {
env := &envelope{Data: buffer}
err := r.Read(env)
switch {
case err == nil && env.IsSet(flagEnvelopeCompressed) && r.compressionPool == nil:
return errorf(
CodeInternal,
"protocol error: sent compressed message without compression support",
)
case err == nil &&
(env.Flags == 0 || env.Flags == flagEnvelopeCompressed) &&
env.Data.Len() == 0:
Expand All @@ -257,12 +263,6 @@ func (r *envelopeReader) Unmarshal(message any) *Error {

data := env.Data
if data.Len() > 0 && env.IsSet(flagEnvelopeCompressed) {
if r.compressionPool == nil {
return errorf(
CodeInvalidArgument,
"protocol error: sent compressed message without compression support",
)
}
decompressed := r.bufferPool.Get()
defer func() {
if decompressed != dontRelease {
Expand All @@ -277,7 +277,9 @@ func (r *envelopeReader) Unmarshal(message any) *Error {

if env.Flags != 0 && env.Flags != flagEnvelopeCompressed {
// Drain the rest of the stream to ensure there is no extra data.
if numBytes, err := discard(r.reader); err != nil {
numBytes, err := discard(r.reader)
r.bytesRead += numBytes
if err != nil {
err = wrapIfContextError(err)
if connErr, ok := asError(err); ok {
return connErr
Expand Down Expand Up @@ -308,7 +310,9 @@ func (r *envelopeReader) Read(env *envelope) *Error {
prefixes := [5]byte{}
// io.ReadFull reads the number of bytes requested, or returns an error.
// io.EOF will only be returned if no bytes were read.
if _, err := io.ReadFull(r.reader, prefixes[:]); err != nil {
n, err := io.ReadFull(r.reader, prefixes[:])
r.bytesRead += int64(n)
if err != nil {
if errors.Is(err, io.EOF) {
// The stream ended cleanly. That's expected, but we need to propagate an EOF
// to the user so that they know that the stream has ended. We shouldn't
Expand All @@ -328,7 +332,8 @@ func (r *envelopeReader) Read(env *envelope) *Error {
}
size := int64(binary.BigEndian.Uint32(prefixes[1:5]))
if r.readMaxBytes > 0 && size > int64(r.readMaxBytes) {
_, err := io.CopyN(io.Discard, r.reader, size)
n, err := io.CopyN(io.Discard, r.reader, size)
r.bytesRead += n
if err != nil && !errors.Is(err, io.EOF) {
return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", r.readMaxBytes, err)
}
Expand All @@ -337,7 +342,9 @@ func (r *envelopeReader) Read(env *envelope) *Error {
// We've read the prefix, so we know how many bytes to expect.
// CopyN will return an error if it doesn't read the requested
// number of bytes.
if readN, err := io.CopyN(env.Data, r.reader, size); err != nil {
readN, err := io.CopyN(env.Data, r.reader, size)
r.bytesRead += readN
if err != nil {
if errors.Is(err, io.EOF) {
// We've gotten fewer bytes than we expected, so the stream has ended
// unexpectedly.
Expand Down
10 changes: 8 additions & 2 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,14 @@ func DecodeBinaryHeader(data string) ([]byte, error) {
}

func mergeHeaders(into, from http.Header) {
for k, vals := range from {
into[k] = append(into[k], vals...)
for key, vals := range from {
if len(vals) == 0 {
// For response trailers, net/http will pre-populate entries
// with nil values based on the "Trailer" header. But if there
// are no actual values for those keys, we skip them.
continue
}
into[key] = append(into[key], vals...)
}
}

Expand Down
1 change: 0 additions & 1 deletion header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ func TestHeaderMerge(t *testing.T) {
expect := http.Header{
"Foo": []string{"one", "two"},
"Bar": []string{"one"},
"Baz": nil,
emcfarlane marked this conversation as resolved.
Show resolved Hide resolved
}
assert.Equal(t, header, expect)
}
18 changes: 18 additions & 0 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -1386,6 +1386,15 @@ func connectValidateUnaryResponseContentType(
)
}
// Normal responses must have valid content-type that indicates same codec as the request.
if !strings.HasPrefix(responseContentType, connectUnaryContentTypePrefix) {
// Doesn't even look like a Connect response? Use code "unknown".
return errorf(
CodeUnknown,
"invalid content-type: %q; expecting %q",
responseContentType,
connectUnaryContentTypePrefix+requestCodecName,
)
}
responseCodecName := connectCodecFromContentType(
StreamTypeUnary,
responseContentType,
Expand All @@ -1410,6 +1419,15 @@ func connectValidateUnaryResponseContentType(

func connectValidateStreamResponseContentType(requestCodecName string, streamType StreamType, responseContentType string) *Error {
// Responses must have valid content-type that indicates same codec as the request.
if !strings.HasPrefix(responseContentType, connectStreamingContentTypePrefix) {
// Doesn't even look like a Connect response? Use code "unknown".
return errorf(
CodeUnknown,
"invalid content-type: %q; expecting %q",
responseContentType,
connectUnaryContentTypePrefix+requestCodecName,
)
}
responseCodecName := connectCodecFromContentType(
streamType,
responseContentType,
Expand Down
37 changes: 24 additions & 13 deletions protocol_connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ func TestConnectValidateUnaryResponseContentType(t *testing.T) {
codecName: codecNameJSON,
statusCode: http.StatusOK,
responseContentType: "some/garbage",
expectCode: CodeInternal,
expectCode: CodeUnknown, // doesn't even look like it could be connect protocol
expectBadContentType: true,
},
// Error status, invalid content-type, returns code based on HTTP status code
Expand Down Expand Up @@ -296,7 +296,7 @@ func TestConnectValidateStreamResponseContentType(t *testing.T) {
testCases := []struct {
codecName string
responseContentType string
expectErr bool
expectCode Code
}{
// Allowed content-types
{
Expand All @@ -307,31 +307,42 @@ func TestConnectValidateStreamResponseContentType(t *testing.T) {
codecName: codecNameJSON,
responseContentType: "application/connect+json",
},
// Mismatched response codec
{
codecName: codecNameProto,
responseContentType: "application/connect+json",
expectCode: CodeInternal,
},
{
codecName: codecNameJSON,
responseContentType: "application/connect+proto",
expectCode: CodeInternal,
},
// Disallowed content-types
{
codecName: codecNameJSON,
responseContentType: "application/connect+json; charset=utf-8",
expectCode: CodeInternal, // *almost* looks right
},
{
codecName: codecNameProto,
responseContentType: "application/proto",
expectErr: true,
expectCode: CodeUnknown,
},
{
codecName: codecNameJSON,
responseContentType: "application/json",
expectErr: true,
expectCode: CodeUnknown,
},
{
codecName: codecNameJSON,
responseContentType: "application/json; charset=utf-8",
expectErr: true,
},
{
codecName: codecNameJSON,
responseContentType: "application/connect+json; charset=utf-8",
expectErr: true,
expectCode: CodeUnknown,
},
{
codecName: codecNameProto,
responseContentType: "some/garbage",
expectErr: true,
expectCode: CodeUnknown,
},
}
for _, testCase := range testCases {
Expand All @@ -344,10 +355,10 @@ func TestConnectValidateStreamResponseContentType(t *testing.T) {
StreamTypeServer,
testCase.responseContentType,
)
if !testCase.expectErr {
if testCase.expectCode == 0 {
assert.Nil(t, err)
} else if assert.NotNil(t, err) {
assert.Equal(t, CodeOf(err), CodeInternal)
assert.Equal(t, CodeOf(err), testCase.expectCode)
assert.True(t, strings.Contains(err.Message(), fmt.Sprintf("invalid content-type: %q; expecting", testCase.responseContentType)))
}
})
Expand Down
Loading
Loading