From 4c43e4b8e19f4caa2c105ed65f50627a2edbbadd Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 18 Sep 2023 14:19:18 +0100 Subject: [PATCH 01/21] Debug flaky testcases with in memory network Changes httptest.Server to use an in memory network. This makes testing more robust to ephemeral port issues under load. Two flaky test cases were now easily reproducible and fixed. Both occured from races between starting the network request in a go routine and checking for a http2 stream and erroring if not. This can error on Send or Receive depending on how fast the request can complete. See: https://github.com/golang/go/issues/14200 --- bench_test.go | 12 +- client_example_test.go | 7 +- client_ext_test.go | 13 +-- client_get_fallback_test.go | 7 +- compression_test.go | 6 +- connect_ext_test.go | 206 +++++++++++------------------------ example_init_test.go | 117 +++----------------- handler_ext_test.go | 10 +- interceptor_example_test.go | 4 +- interceptor_ext_test.go | 13 +-- internal/connecttest/http.go | 156 ++++++++++++++++++++++++++ recover_ext_test.go | 7 +- 12 files changed, 258 insertions(+), 300 deletions(-) create mode 100644 internal/connecttest/http.go diff --git a/bench_test.go b/bench_test.go index 9ecc3914..632299ea 100644 --- a/bench_test.go +++ b/bench_test.go @@ -21,12 +21,12 @@ import ( "encoding/json" "io" "net/http" - "net/http/httptest" "strings" "testing" connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" + "connectrpc.com/connect/internal/connecttest" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" ) @@ -38,10 +38,7 @@ func BenchmarkConnect(b *testing.B) { &ExamplePingServer{}, ), ) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - b.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(b, mux) httpClient := server.Client() httpTransport, ok := httpClient.Transport.(*http.Transport) @@ -113,10 +110,7 @@ func BenchmarkREST(b *testing.B) { assert.Nil(b, err) } - server := httptest.NewUnstartedServer(http.HandlerFunc(handler)) - server.EnableHTTP2 = true - server.StartTLS() - b.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(b, http.HandlerFunc(handler)) twoMiB := strings.Repeat("a", 2*1024*1024) b.ResetTimer() diff --git a/client_example_test.go b/client_example_test.go index 56359017..d4e38f23 100644 --- a/client_example_test.go +++ b/client_example_test.go @@ -27,9 +27,8 @@ import ( func Example_client() { logger := log.New(os.Stdout, "" /* prefix */, 0 /* flags */) - // Unfortunately, pkg.go.dev can't run examples that actually use the - // network. To keep this example runnable, we'll use an HTTP server and - // client that communicate over in-memory pipes. The client is still a plain + // To keep this example runnable, we'll use an HTTP server and client + // that communicate over in-memory pipes. The client is still a plain // *http.Client! var httpClient *http.Client = examplePingServer.Client() @@ -37,7 +36,7 @@ func Example_client() { // connect.WithGRPCWeb() to switch protocols. client := pingv1connect.NewPingServiceClient( httpClient, - examplePingServer.URL(), + examplePingServer.URL, ) response, err := client.Ping( context.Background(), diff --git a/client_ext_test.go b/client_ext_test.go index cb4dede4..70693ecc 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -18,12 +18,12 @@ import ( "context" "errors" "net/http" - "net/http/httptest" "strings" "testing" connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" + "connectrpc.com/connect/internal/connecttest" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" ) @@ -75,10 +75,7 @@ func TestClientPeer(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(t, mux) run := func(t *testing.T, unaryHTTPMethod string, opts ...connect.ClientOption) { t.Helper() @@ -157,11 +154,7 @@ func TestGetNotModified(t *testing.T) { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(¬ModifiedPingServer{etag: etag})) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) - + server := connecttest.StartHTTP2TestServer(t, mux) client := pingv1connect.NewPingServiceClient( server.Client(), server.URL, diff --git a/client_get_fallback_test.go b/client_get_fallback_test.go index a8076ffd..998353a9 100644 --- a/client_get_fallback_test.go +++ b/client_get_fallback_test.go @@ -17,11 +17,11 @@ package connect import ( "context" "net/http" - "net/http/httptest" "strings" "testing" "connectrpc.com/connect/internal/assert" + "connectrpc.com/connect/internal/connecttest" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" ) @@ -38,10 +38,7 @@ func TestClientUnaryGetFallback(t *testing.T) { }, WithIdempotency(IdempotencyNoSideEffects), )) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(t, mux) client := NewClient[pingv1.PingRequest, pingv1.PingResponse]( server.Client(), diff --git a/compression_test.go b/compression_test.go index 5ae53e6b..dd77e2b5 100644 --- a/compression_test.go +++ b/compression_test.go @@ -17,10 +17,10 @@ package connect import ( "context" "net/http" - "net/http/httptest" "testing" "connectrpc.com/connect/internal/assert" + "connectrpc.com/connect/internal/connecttest" "google.golang.org/protobuf/types/known/emptypb" ) @@ -42,9 +42,7 @@ func TestAcceptEncodingOrdering(t *testing.T) { w.WriteHeader(http.StatusOK) called = true }) - server := httptest.NewServer(verify) - t.Cleanup(server.Close) - + server := connecttest.StartHTTPTestServer(t, verify) client := NewClient[emptypb.Empty, emptypb.Empty]( server.Client(), server.URL, diff --git a/connect_ext_test.go b/connect_ext_test.go index 9e75253e..9f38c455 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -34,6 +34,7 @@ import ( connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" + "connectrpc.com/connect/internal/connecttest" "connectrpc.com/connect/internal/gen/connect/import/v1/importv1connect" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" @@ -167,8 +168,10 @@ func TestServer(t *testing.T) { assert.Equal(t, got, expect) }) t.Run("count_up_error", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) stream, err := client.CountUp( - context.Background(), + ctx, connect.NewRequest(&pingv1.CountUpRequest{Number: 1}), ) assert.Nil(t, err) @@ -286,6 +289,11 @@ func TestServer(t *testing.T) { t.Run("cumsum_cancel_before_send", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) stream := client.CumSum(ctx) + if !expectSuccess { // server doesn't support HTTP/2 + failNoHTTP2(t, stream) + cancel() + return + } stream.RequestHeader().Set(clientHeader, headerValue) assert.Nil(t, stream.Send(&pingv1.CumSumRequest{Number: 8})) cancel() @@ -428,16 +436,12 @@ func TestServer(t *testing.T) { t.Run("http1", func(t *testing.T) { t.Parallel() - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := connecttest.StartHTTPTestServer(t, mux) testMatrix(t, server, false /* bidi */) }) t.Run("http2", func(t *testing.T) { t.Parallel() - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(t, mux) testMatrix(t, server, true /* bidi */) }) } @@ -449,10 +453,7 @@ func TestConcurrentStreams(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(t, mux) var done, start sync.WaitGroup start.Add(1) for i := 0; i < 100; i++ { @@ -510,8 +511,7 @@ func TestHeaderBasic(t *testing.T) { } mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := connecttest.StartHTTPTestServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) request := connect.NewRequest(&pingv1.PingRequest{}) @@ -540,10 +540,7 @@ func TestHeaderHost(t *testing.T) { t.Helper() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(t, mux) return server } @@ -594,8 +591,7 @@ func TestTimeoutParsing(t *testing.T) { } mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := connecttest.StartHTTPTestServer(t, mux) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() @@ -607,8 +603,7 @@ func TestTimeoutParsing(t *testing.T) { func TestFailCodec(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - server := httptest.NewServer(handler) - t.Cleanup(server.Close) + server := connecttest.StartHTTPTestServer(t, handler) client := pingv1connect.NewPingServiceClient( server.Client(), server.URL, @@ -625,8 +620,7 @@ func TestFailCodec(t *testing.T) { func TestContextError(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - server := httptest.NewServer(handler) - t.Cleanup(server.Close) + server := connecttest.StartHTTPTestServer(t, handler) client := pingv1connect.NewPingServiceClient( server.Client(), server.URL, @@ -650,10 +644,7 @@ func TestGRPCMarshalStatusError(t *testing.T) { pingServer{}, connect.WithCodec(failCodec{}), )) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(t, mux) assertInternalError := func(tb testing.TB, opts ...connect.ClientOption) { tb.Helper() @@ -692,10 +683,7 @@ func TestGRPCMissingTrailersError(t *testing.T) { mux.Handle(pingv1connect.NewPingServiceHandler( pingServer{checkMetadata: true}, )) - server := httptest.NewUnstartedServer(trimTrailers(mux)) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(t, trimTrailers(mux)) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC()) assertErrorNoTrailers := func(t *testing.T, err error) { @@ -778,16 +766,18 @@ func TestBidiRequiresHTTP2(t *testing.T) { _, err := io.WriteString(w, "hello world") assert.Nil(t, err) }) - server := httptest.NewServer(handler) - t.Cleanup(server.Close) + server := connecttest.StartHTTPTestServer(t, handler) client := pingv1connect.NewPingServiceClient( server.Client(), server.URL, ) stream := client.CumSum(context.Background()) - assert.Nil(t, stream.Send(&pingv1.CumSumRequest{})) - assert.Nil(t, stream.CloseRequest()) - _, err := stream.Receive() + // Stream creates an async request, can error on Send or Receive. + err := stream.Send(&pingv1.CumSumRequest{}) + if err == nil { + assert.Nil(t, stream.CloseRequest()) + _, err = stream.Receive() + } assert.NotNil(t, err) var connectErr *connect.Error assert.True(t, errors.As(err, &connectErr)) @@ -806,8 +796,7 @@ func TestCompressMinBytesClient(t *testing.T) { mux.Handle("/", http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { assert.Equal(tb, request.Header.Get("Content-Encoding"), expect) })) - server := httptest.NewServer(mux) - tb.Cleanup(server.Close) + server := connecttest.StartHTTPTestServer(t, mux) _, err := pingv1connect.NewPingServiceClient( server.Client(), server.URL, @@ -842,10 +831,7 @@ func TestCompressMinBytes(t *testing.T) { pingServer{}, connect.WithCompressMinBytes(8), )) - server := httptest.NewServer(mux) - t.Cleanup(func() { - server.Close() - }) + server := connecttest.StartHTTPTestServer(t, mux) client := server.Client() getPingResponse := func(t *testing.T, pingText string) *http.Response { @@ -899,9 +885,7 @@ func TestCustomCompression(t *testing.T) { pingServer{}, connect.WithCompression(compressionName, decompressor, compressor), )) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - + server := connecttest.StartHTTPTestServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithAcceptCompression(compressionName, decompressor, compressor), @@ -920,9 +904,7 @@ func TestClientWithoutGzipSupport(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - + server := connecttest.StartHTTPTestServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithAcceptCompression("gzip", nil, nil), @@ -939,10 +921,7 @@ func TestInvalidHeaderTimeout(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := httptest.NewServer(mux) - t.Cleanup(func() { - server.Close() - }) + server := connecttest.StartHTTPTestServer(t, mux) getPingResponseWithTimeout := func(t *testing.T, timeout string) *http.Response { t.Helper() request, err := http.NewRequestWithContext( @@ -975,8 +954,7 @@ func TestInterceptorReturnsWrongType(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := connecttest.StartHTTPTestServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithInterceptors(connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) { if _, err := next(ctx, request); err != nil { @@ -1054,10 +1032,7 @@ func TestHandlerWithReadMaxBytes(t *testing.T) { } newHTTP2Server := func(t *testing.T) *httptest.Server { t.Helper() - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(t, mux) return server } t.Run("connect", func(t *testing.T) { @@ -1141,47 +1116,39 @@ func TestHandlerWithHTTPMaxBytes(t *testing.T) { assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted) }) } - newHTTP2Server := func(t *testing.T) *httptest.Server { - t.Helper() - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) - return server - } t.Run("connect", func(t *testing.T) { t.Parallel() - server := newHTTP2Server(t) + server := connecttest.StartHTTP2TestServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) run(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() - server := newHTTP2Server(t) + server := connecttest.StartHTTP2TestServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendGzip()) run(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() - server := newHTTP2Server(t) + server := connecttest.StartHTTP2TestServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC()) run(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() - server := newHTTP2Server(t) + server := connecttest.StartHTTP2TestServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC(), connect.WithSendGzip()) run(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() - server := newHTTP2Server(t) + server := connecttest.StartHTTP2TestServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb()) run(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() - server := newHTTP2Server(t) + server := connecttest.StartHTTP2TestServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb(), connect.WithSendGzip()) run(t, client, true) }) @@ -1199,10 +1166,7 @@ func TestClientWithReadMaxBytes(t *testing.T) { compressionOption = connect.WithCompressMinBytes(math.MaxInt) } mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, compressionOption)) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - tb.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(t, mux) return server } serverUncompressed := createServer(t, false) @@ -1347,10 +1311,7 @@ func TestHandlerWithSendMaxBytes(t *testing.T) { pingServer{}, options..., )) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(t, mux) return server } t.Run("connect", func(t *testing.T) { @@ -1395,10 +1356,7 @@ func TestClientWithSendMaxBytes(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(t, mux) sendMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, sendMaxBytes int, compressed bool) { t.Helper() t.Run("equal_send_max", func(t *testing.T) { @@ -1498,11 +1456,7 @@ func TestBidiStreamServerSendsFirstMessage(t *testing.T) { } mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) - + server := connecttest.StartHTTP2TestServer(t, mux) client := pingv1connect.NewPingServiceClient( server.Client(), server.URL, @@ -1540,9 +1494,7 @@ func TestStreamForServer(t *testing.T) { newPingServer := func(pingServer pingv1connect.PingServiceHandler) (pingv1connect.PingServiceClient, *httptest.Server) { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() + server := connecttest.StartHTTP2TestServer(t, mux) client := pingv1connect.NewPingServiceClient( server.Client(), server.URL, @@ -1551,12 +1503,11 @@ func TestStreamForServer(t *testing.T) { } t.Run("not-proto-message", func(t *testing.T) { t.Parallel() - client, server := newPingServer(&pluggablePingServer{ + client, _ := newPingServer(&pluggablePingServer{ cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error { return stream.Conn().Send("foobar") }, }) - t.Cleanup(server.Close) stream := client.CumSum(context.Background()) assert.Nil(t, stream.Send(nil)) _, err := stream.Receive() @@ -1566,12 +1517,11 @@ func TestStreamForServer(t *testing.T) { }) t.Run("nil-message", func(t *testing.T) { t.Parallel() - client, server := newPingServer(&pluggablePingServer{ + client, _ := newPingServer(&pluggablePingServer{ cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error { return stream.Send(nil) }, }) - t.Cleanup(server.Close) stream := client.CumSum(context.Background()) assert.Nil(t, stream.Send(nil)) _, err := stream.Receive() @@ -1581,7 +1531,7 @@ func TestStreamForServer(t *testing.T) { }) t.Run("get-spec", func(t *testing.T) { t.Parallel() - client, server := newPingServer(&pluggablePingServer{ + client, _ := newPingServer(&pluggablePingServer{ cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error { assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeBidi) assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceCumSumProcedure) @@ -1589,14 +1539,13 @@ func TestStreamForServer(t *testing.T) { return nil }, }) - t.Cleanup(server.Close) stream := client.CumSum(context.Background()) assert.Nil(t, stream.Send(nil)) assert.Nil(t, stream.CloseRequest()) }) t.Run("server-stream", func(t *testing.T) { t.Parallel() - client, server := newPingServer(&pluggablePingServer{ + client, _ := newPingServer(&pluggablePingServer{ countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { assert.Equal(t, stream.Conn().Spec().StreamType, connect.StreamTypeServer) assert.Equal(t, stream.Conn().Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) @@ -1605,7 +1554,6 @@ func TestStreamForServer(t *testing.T) { return nil }, }) - t.Cleanup(server.Close) stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{})) assert.Nil(t, err) assert.NotNil(t, stream) @@ -1613,13 +1561,12 @@ func TestStreamForServer(t *testing.T) { }) t.Run("server-stream-send", func(t *testing.T) { t.Parallel() - client, server := newPingServer(&pluggablePingServer{ + client, _ := newPingServer(&pluggablePingServer{ countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { assert.Nil(t, stream.Send(&pingv1.CountUpResponse{Number: 1})) return nil }, }) - t.Cleanup(server.Close) stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{})) assert.Nil(t, err) assert.True(t, stream.Receive()) @@ -1630,7 +1577,7 @@ func TestStreamForServer(t *testing.T) { }) t.Run("server-stream-send-nil", func(t *testing.T) { t.Parallel() - client, server := newPingServer(&pluggablePingServer{ + client, _ := newPingServer(&pluggablePingServer{ countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { stream.ResponseHeader().Set("foo", "bar") stream.ResponseTrailer().Set("bas", "blah") @@ -1638,7 +1585,6 @@ func TestStreamForServer(t *testing.T) { return nil }, }) - t.Cleanup(server.Close) stream, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{})) assert.Nil(t, err) assert.False(t, stream.Receive()) @@ -1652,7 +1598,7 @@ func TestStreamForServer(t *testing.T) { }) t.Run("client-stream", func(t *testing.T) { t.Parallel() - client, server := newPingServer(&pluggablePingServer{ + client, _ := newPingServer(&pluggablePingServer{ sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeClient) assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceSumProcedure) @@ -1664,7 +1610,6 @@ func TestStreamForServer(t *testing.T) { return connect.NewResponse(&pingv1.SumResponse{Sum: 1}), nil }, }) - t.Cleanup(server.Close) stream := client.Sum(context.Background()) assert.Nil(t, stream.Send(&pingv1.SumRequest{Number: 1})) res, err := stream.CloseAndReceive() @@ -1674,13 +1619,12 @@ func TestStreamForServer(t *testing.T) { }) t.Run("client-stream-conn", func(t *testing.T) { t.Parallel() - client, server := newPingServer(&pluggablePingServer{ + client, _ := newPingServer(&pluggablePingServer{ sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { assert.NotNil(t, stream.Conn().Send("not-proto")) return connect.NewResponse(&pingv1.SumResponse{}), nil }, }) - t.Cleanup(server.Close) stream := client.Sum(context.Background()) assert.Nil(t, stream.Send(&pingv1.SumRequest{Number: 1})) res, err := stream.CloseAndReceive() @@ -1689,13 +1633,12 @@ func TestStreamForServer(t *testing.T) { }) t.Run("client-stream-send-msg", func(t *testing.T) { t.Parallel() - client, server := newPingServer(&pluggablePingServer{ + client, _ := newPingServer(&pluggablePingServer{ sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { assert.Nil(t, stream.Conn().Send(&pingv1.SumResponse{Sum: 2})) return connect.NewResponse(&pingv1.SumResponse{}), nil }, }) - t.Cleanup(server.Close) stream := client.Sum(context.Background()) assert.Nil(t, stream.Send(&pingv1.SumRequest{Number: 1})) res, err := stream.CloseAndReceive() @@ -1716,8 +1659,7 @@ func TestConnectHTTPErrorCodes(t *testing.T) { }, } mux.Handle(pingv1connect.NewPingServiceHandler(pluggableServer)) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := connecttest.StartHTTPTestServer(t, mux) req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, @@ -1821,10 +1763,7 @@ func TestFailCompression(t *testing.T) { connect.WithCompression(compressorName, decompressor, compressor), ), ) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(t, mux) pingclient := pingv1connect.NewPingServiceClient( server.Client(), server.URL, @@ -1859,10 +1798,7 @@ func TestUnflushableResponseWriter(t *testing.T) { handler.ServeHTTP(&unflushableWriter{w}, r) }) mux.Handle(path, wrapped) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(t, mux) tests := []struct { name string @@ -1895,10 +1831,7 @@ func TestGRPCErrorMetadataIsTrailersOnly(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(t, mux) protoBytes, err := proto.Marshal(&pingv1.FailRequest{Code: int32(connect.CodeInternal)}) assert.Nil(t, err) @@ -1935,10 +1868,7 @@ func TestConnectProtocolHeaderSentByDefault(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, connect.WithRequireConnectProtocolHeader())) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) @@ -1959,8 +1889,7 @@ func TestConnectProtocolHeaderRequired(t *testing.T) { pingServer{}, connect.WithRequireConnectProtocolHeader(), )) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := connecttest.StartHTTPTestServer(t, mux) tests := []struct { headers http.Header @@ -1999,8 +1928,7 @@ func TestAllowCustomUserAgent(t *testing.T) { return connect.NewResponse(&pingv1.PingResponse{Number: req.Msg.Number}), nil }, })) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := connecttest.StartHTTPTestServer(t, mux) // If the user has set a User-Agent, we shouldn't clobber it. tests := []struct { @@ -2036,8 +1964,7 @@ func TestWebXUserAgent(t *testing.T) { return connect.NewResponse(&pingv1.PingResponse{Number: req.Msg.Number}), nil }, })) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := connecttest.StartHTTPTestServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb()) req := connect.NewRequest(&pingv1.PingRequest{Number: 42}) @@ -2049,8 +1976,7 @@ func TestBidiOverHTTP1(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := connecttest.StartHTTPTestServer(t, mux) // Clients expecting a full-duplex connection that end up with a simplex // HTTP/1.1 connection shouldn't hang. Instead, the server should close the @@ -2095,8 +2021,7 @@ func TestHandlerReturnsNilResponse(t *testing.T) { return nil, nil //nolint: nilnil }, }, connect.WithRecover(recoverPanic))) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := connecttest.StartHTTPTestServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) @@ -2126,10 +2051,7 @@ func TestStreamUnexpectedEOF(t *testing.T) { _, _ = io.Copy(io.Discard, request.Body) testcase(responseWriter, request) }) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(t, mux) head := [5]byte{} payload := []byte(`{"number": 42}`) diff --git a/example_init_test.go b/example_init_test.go index ab7e1563..3747ff06 100644 --- a/example_init_test.go +++ b/example_init_test.go @@ -15,130 +15,41 @@ package connect_test import ( - "context" - "errors" - "net" "net/http" "net/http/httptest" - "sync" + "connectrpc.com/connect/internal/connecttest" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" ) -var examplePingServer *inMemoryServer +var examplePingServer *httptest.Server func init() { - // Generally, init functions are bad. + // Generally, init functions are bad. However, we need to set up the server + // before the examples run. // // To write testable examples that users can grok *and* can execute in the - // playground, where networking is disabled, we need an HTTP server that uses - // in-memory pipes instead of TCP. We don't want to pollute every example - // with this setup code. - // - // The least-awful option is to set up the server in init(). + // playground we use an in memory pipe as network based playgrounds can + // deadlock, see: + // (https://github.com/golang/go/issues/48394) mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) examplePingServer = newInMemoryServer(mux) } -// inMemoryServer is an HTTP server that uses in-memory pipes instead of TCP. -// It supports HTTP/2 and has TLS enabled. -// -// The Go Playground panics if we try to start a TCP-backed server. If you're -// not familiar with the Playground's behavior, it looks like our examples are -// broken. This server lets us write examples that work in the playground -// without abstracting over HTTP. -type inMemoryServer struct { - server *httptest.Server - listener *memoryListener -} - // newInMemoryServer constructs and starts an inMemoryServer. -func newInMemoryServer(handler http.Handler) *inMemoryServer { - lis := &memoryListener{ - conns: make(chan net.Conn), - closed: make(chan struct{}), - } +func newInMemoryServer(handler http.Handler) *httptest.Server { + lis := connecttest.NewMemoryListener() server := httptest.NewUnstartedServer(handler) server.Listener = lis server.EnableHTTP2 = true server.StartTLS() - return &inMemoryServer{ - server: server, - listener: lis, - } -} - -// Client returns an HTTP client configured to trust the server's TLS -// certificate and use HTTP/2 over an in-memory pipe. Automatic HTTP-level gzip -// compression is disabled. It closes its idle connections when the server is -// closed. -func (s *inMemoryServer) Client() *http.Client { - client := s.server.Client() + client := server.Client() + // Configure the httptest.Server client to use the in-memory listener. + // Automatic HTTP-level gzip compression is disabled. if transport, ok := client.Transport.(*http.Transport); ok { - transport.DialContext = s.listener.DialContext + transport.DialContext = lis.DialContext transport.DisableCompression = true } - return client -} - -// URL is the server's URL. -func (s *inMemoryServer) URL() string { - return s.server.URL -} - -// Close shuts down the server, blocking until all outstanding requests have -// completed. -func (s *inMemoryServer) Close() { - s.server.Close() -} - -type memoryListener struct { - conns chan net.Conn - once sync.Once - closed chan struct{} + return server } - -// Accept implements net.Listener. -func (l *memoryListener) Accept() (net.Conn, error) { - select { - case conn := <-l.conns: - return conn, nil - case <-l.closed: - return nil, errors.New("listener closed") - } -} - -// Close implements net.Listener. -func (l *memoryListener) Close() error { - l.once.Do(func() { - close(l.closed) - }) - return nil -} - -// Addr implements net.Listener. -func (l *memoryListener) Addr() net.Addr { - return &memoryAddr{} -} - -// DialContext is the type expected by http.Transport.DialContext. -func (l *memoryListener) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { - select { - case <-l.closed: - return nil, errors.New("listener closed") - default: - } - server, client := net.Pipe() - l.conns <- server - return client, nil -} - -type memoryAddr struct{} - -// Network implements net.Addr. -func (*memoryAddr) Network() string { return "memory" } - -// String implements io.Stringer, returning a value that matches the -// certificates used by net/http/httptest. -func (*memoryAddr) String() string { return "example.com" } diff --git a/handler_ext_test.go b/handler_ext_test.go index 4aeb78f0..2afed6be 100644 --- a/handler_ext_test.go +++ b/handler_ext_test.go @@ -21,13 +21,13 @@ import ( "encoding/json" "io" "net/http" - "net/http/httptest" "strings" "sync" "testing" connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" + "connectrpc.com/connect/internal/connecttest" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" ) @@ -42,11 +42,8 @@ func TestHandler_ServeHTTP(t *testing.T) { mux.Handle("/prefixed/", http.StripPrefix("/prefixed", prefixed)) const pingProcedure = pingv1connect.PingServicePingProcedure const sumProcedure = pingv1connect.PingServiceSumProcedure - server := httptest.NewServer(mux) + server := connecttest.StartHTTPTestServer(t, mux) client := server.Client() - t.Cleanup(func() { - server.Close() - }) t.Run("get_method_no_encoding", func(t *testing.T) { t.Parallel() @@ -217,8 +214,7 @@ func TestHandlerMaliciousPrefix(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(successPingServer{})) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := connecttest.StartHTTPTestServer(t, mux) const ( concurrency = 256 diff --git a/interceptor_example_test.go b/interceptor_example_test.go index 7d7a2542..4c4742b0 100644 --- a/interceptor_example_test.go +++ b/interceptor_example_test.go @@ -43,7 +43,7 @@ func ExampleUnaryInterceptorFunc() { ) client := pingv1connect.NewPingServiceClient( examplePingServer.Client(), - examplePingServer.URL(), + examplePingServer.URL, connect.WithInterceptors(loggingInterceptor), ) if _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 42})); err != nil { @@ -81,7 +81,7 @@ func ExampleWithInterceptors() { ) client := pingv1connect.NewPingServiceClient( examplePingServer.Client(), - examplePingServer.URL(), + examplePingServer.URL, connect.WithInterceptors(outer, inner), ) if _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})); err != nil { diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index 1e87ac82..ff4ea961 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -18,12 +18,12 @@ import ( "context" "fmt" "net/http" - "net/http/httptest" "sync/atomic" "testing" connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" + "connectrpc.com/connect/internal/connecttest" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" ) @@ -127,9 +127,7 @@ func TestOnionOrderingEndToEnd(t *testing.T) { handlerOnion, ), ) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - + server := connecttest.StartHTTPTestServer(t, mux) client := pingv1connect.NewPingServiceClient( server.Client(), server.URL, @@ -174,8 +172,7 @@ func TestEmptyUnaryInterceptorFunc(t *testing.T) { } }) mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, connect.WithInterceptors(interceptor))) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + server := connecttest.StartHTTPTestServer(t, mux) connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithInterceptors(interceptor)) _, err := connectClient.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.Nil(t, err) @@ -204,9 +201,7 @@ func TestInterceptorFuncAccessingHTTPMethod(t *testing.T) { connect.WithInterceptors(handlerChecker), ), ) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) - + server := connecttest.StartHTTPTestServer(t, mux) client := pingv1connect.NewPingServiceClient( server.Client(), server.URL, diff --git a/internal/connecttest/http.go b/internal/connecttest/http.go new file mode 100644 index 00000000..c9168c93 --- /dev/null +++ b/internal/connecttest/http.go @@ -0,0 +1,156 @@ +// Copyright 2021-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connecttest + +import ( + "context" + "errors" + "io" + "log" + "net" + "net/http" + "net/http/httptest" + "sync" + "testing" +) + +// StartHTTPTestServer starts an HTTP server that listens on a in memory +// network. The returned server is configured to use the in memory network. +func StartHTTPTestServer(t testing.TB, handler http.Handler) *httptest.Server { + lis := NewMemoryListener() + svr := httptest.NewUnstartedServer(handler) + svr.Config.ErrorLog = log.New(NewTestWriter(t), "", 0) + svr.Listener = lis + svr.Start() + t.Cleanup(svr.Close) + client := svr.Client() + if transport, ok := client.Transport.(*http.Transport); ok { + transport.DialContext = lis.DialContext + } else { + t.Fatalf("unexpected transport type: %T", client.Transport) + } + return svr +} + +// StartHTTP2TestServer starts an HTTP/2 server that listens on a in memory +// network. The returned server is configured to use the in memory network and +// TLS. +func StartHTTP2TestServer(t testing.TB, handler http.Handler) *httptest.Server { + lis := NewMemoryListener() + svr := httptest.NewUnstartedServer(handler) + svr.Config.ErrorLog = log.New(NewTestWriter(t), "", 0) + svr.Listener = lis + svr.EnableHTTP2 = true + svr.StartTLS() + t.Cleanup(svr.Close) + client := svr.Client() + if transport, ok := client.Transport.(*http.Transport); ok { + transport.DialContext = lis.DialContext + } else { + t.Fatalf("unexpected transport type: %T", client.Transport) + } + return svr +} + +type testWriter struct { + tb testing.TB +} + +func (l *testWriter) Write(p []byte) (n int, err error) { + l.tb.Log(string(p)) + return +} + +// NewTestWriter returns a writer that logs to the given testing.TB. +func NewTestWriter(tb testing.TB) io.Writer { + return &testWriter{tb} +} + +// MemoryListener is a net.Listener that listens on an in memory network. +type MemoryListener struct { + conns chan chan net.Conn + once sync.Once + closed chan struct{} +} + +func NewMemoryListener() *MemoryListener { + return &MemoryListener{ + conns: make(chan chan net.Conn), + closed: make(chan struct{}), + } +} + +// Accept implements net.Listener. +func (l *MemoryListener) Accept() (net.Conn, error) { + aerr := func(err error) error { + return &net.OpError{ + Op: "accept", + Net: memoryAddr{}.Network(), + Addr: memoryAddr{}, + Err: err, + } + } + select { + case <-l.closed: + return nil, aerr(errors.New("listener closed")) + case accept := <-l.conns: + server, client := net.Pipe() + accept <- client + return server, nil + } +} + +// Close implements net.Listener. +func (l *MemoryListener) Close() error { + l.once.Do(func() { + close(l.closed) + }) + return nil +} + +// Addr implements net.Listener. +func (l *MemoryListener) Addr() net.Addr { + return &memoryAddr{} +} + +// DialContext is the type expected by http.Transport.DialContext. +func (l *MemoryListener) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + derr := func(err error) error { + return &net.OpError{ + Op: "dial", + Net: memoryAddr{}.Network(), + Err: err, + } + } + + accepted := make(chan net.Conn) + select { + case <-ctx.Done(): + return nil, derr(ctx.Err()) + case l.conns <- accepted: + return <-accepted, nil + case <-l.closed: + return nil, derr(errors.New("listener closed")) + } +} + +type memoryAddr struct{} + +// Network implements net.Addr. +func (memoryAddr) Network() string { return "memory" } + +// String implements io.Stringer, returning a value that matches the +// certificates used by net/http/httptest. +func (memoryAddr) String() string { return "example.com" } diff --git a/recover_ext_test.go b/recover_ext_test.go index e8cb991b..95ec36a2 100644 --- a/recover_ext_test.go +++ b/recover_ext_test.go @@ -18,11 +18,11 @@ import ( "context" "fmt" "net/http" - "net/http/httptest" "testing" connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" + "connectrpc.com/connect/internal/connecttest" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" ) @@ -77,10 +77,7 @@ func TestWithRecover(t *testing.T) { pinger := &panicPingServer{} mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pinger, connect.WithRecover(handle))) - server := httptest.NewUnstartedServer(mux) - server.EnableHTTP2 = true - server.StartTLS() - t.Cleanup(server.Close) + server := connecttest.StartHTTP2TestServer(t, mux) client := pingv1connect.NewPingServiceClient( server.Client(), server.URL, From 1f1df5df3756894bd011084842a46fad638a9433 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 18 Sep 2023 14:32:33 +0100 Subject: [PATCH 02/21] Fix lint --- internal/connecttest/http.go | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/internal/connecttest/http.go b/internal/connecttest/http.go index c9168c93..6614756b 100644 --- a/internal/connecttest/http.go +++ b/internal/connecttest/http.go @@ -28,18 +28,19 @@ import ( // StartHTTPTestServer starts an HTTP server that listens on a in memory // network. The returned server is configured to use the in memory network. -func StartHTTPTestServer(t testing.TB, handler http.Handler) *httptest.Server { +func StartHTTPTestServer(tb testing.TB, handler http.Handler) *httptest.Server { + tb.Helper() lis := NewMemoryListener() svr := httptest.NewUnstartedServer(handler) - svr.Config.ErrorLog = log.New(NewTestWriter(t), "", 0) + svr.Config.ErrorLog = log.New(NewTestWriter(tb), "", 0) //nolint:forbidigo svr.Listener = lis svr.Start() - t.Cleanup(svr.Close) + tb.Cleanup(svr.Close) client := svr.Client() if transport, ok := client.Transport.(*http.Transport); ok { transport.DialContext = lis.DialContext } else { - t.Fatalf("unexpected transport type: %T", client.Transport) + tb.Fatalf("unexpected transport type: %T", client.Transport) } return svr } @@ -47,19 +48,20 @@ func StartHTTPTestServer(t testing.TB, handler http.Handler) *httptest.Server { // StartHTTP2TestServer starts an HTTP/2 server that listens on a in memory // network. The returned server is configured to use the in memory network and // TLS. -func StartHTTP2TestServer(t testing.TB, handler http.Handler) *httptest.Server { +func StartHTTP2TestServer(tb testing.TB, handler http.Handler) *httptest.Server { + tb.Helper() lis := NewMemoryListener() svr := httptest.NewUnstartedServer(handler) - svr.Config.ErrorLog = log.New(NewTestWriter(t), "", 0) + svr.Config.ErrorLog = log.New(NewTestWriter(tb), "", 0) //nolint:forbidigo svr.Listener = lis svr.EnableHTTP2 = true svr.StartTLS() - t.Cleanup(svr.Close) + tb.Cleanup(svr.Close) client := svr.Client() if transport, ok := client.Transport.(*http.Transport); ok { transport.DialContext = lis.DialContext } else { - t.Fatalf("unexpected transport type: %T", client.Transport) + tb.Fatalf("unexpected transport type: %T", client.Transport) } return svr } @@ -68,13 +70,14 @@ type testWriter struct { tb testing.TB } -func (l *testWriter) Write(p []byte) (n int, err error) { +func (l *testWriter) Write(p []byte) (int, error) { l.tb.Log(string(p)) - return + return len(p), nil } // NewTestWriter returns a writer that logs to the given testing.TB. func NewTestWriter(tb testing.TB) io.Writer { + tb.Helper() return &testWriter{tb} } From 225aa0ddd0f0dfb53e7ee6efa8a8e7ee9b3ab4c2 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 18 Sep 2023 17:03:46 +0100 Subject: [PATCH 03/21] Add TODO for flaky test --- client_ext_test.go | 2 ++ connect_ext_test.go | 22 +++++++++++----------- internal/connecttest/http.go | 1 - 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/client_ext_test.go b/client_ext_test.go index 70693ecc..a8566cd7 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -108,6 +108,8 @@ func TestClientPeer(t *testing.T) { // server streaming serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{})) t.Cleanup(func() { + // TODO(emcfarlane): debug flaky test close with error: + // "unknown: io: read/write on closed pipe" assert.Nil(t, serverStream.Close()) }) assert.Nil(t, err) diff --git a/connect_ext_test.go b/connect_ext_test.go index 9f38c455..e36b78c5 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -1491,7 +1491,7 @@ func TestBidiStreamServerSendsFirstMessage(t *testing.T) { func TestStreamForServer(t *testing.T) { t.Parallel() - newPingServer := func(pingServer pingv1connect.PingServiceHandler) (pingv1connect.PingServiceClient, *httptest.Server) { + newPingClient := func(pingServer pingv1connect.PingServiceHandler) pingv1connect.PingServiceClient { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) server := connecttest.StartHTTP2TestServer(t, mux) @@ -1499,11 +1499,11 @@ func TestStreamForServer(t *testing.T) { server.Client(), server.URL, ) - return client, server + return client } t.Run("not-proto-message", func(t *testing.T) { t.Parallel() - client, _ := newPingServer(&pluggablePingServer{ + client := newPingClient(&pluggablePingServer{ cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error { return stream.Conn().Send("foobar") }, @@ -1517,7 +1517,7 @@ func TestStreamForServer(t *testing.T) { }) t.Run("nil-message", func(t *testing.T) { t.Parallel() - client, _ := newPingServer(&pluggablePingServer{ + client := newPingClient(&pluggablePingServer{ cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error { return stream.Send(nil) }, @@ -1531,7 +1531,7 @@ func TestStreamForServer(t *testing.T) { }) t.Run("get-spec", func(t *testing.T) { t.Parallel() - client, _ := newPingServer(&pluggablePingServer{ + client := newPingClient(&pluggablePingServer{ cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error { assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeBidi) assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceCumSumProcedure) @@ -1545,7 +1545,7 @@ func TestStreamForServer(t *testing.T) { }) t.Run("server-stream", func(t *testing.T) { t.Parallel() - client, _ := newPingServer(&pluggablePingServer{ + client := newPingClient(&pluggablePingServer{ countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { assert.Equal(t, stream.Conn().Spec().StreamType, connect.StreamTypeServer) assert.Equal(t, stream.Conn().Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) @@ -1561,7 +1561,7 @@ func TestStreamForServer(t *testing.T) { }) t.Run("server-stream-send", func(t *testing.T) { t.Parallel() - client, _ := newPingServer(&pluggablePingServer{ + client := newPingClient(&pluggablePingServer{ countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { assert.Nil(t, stream.Send(&pingv1.CountUpResponse{Number: 1})) return nil @@ -1577,7 +1577,7 @@ func TestStreamForServer(t *testing.T) { }) t.Run("server-stream-send-nil", func(t *testing.T) { t.Parallel() - client, _ := newPingServer(&pluggablePingServer{ + client := newPingClient(&pluggablePingServer{ countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { stream.ResponseHeader().Set("foo", "bar") stream.ResponseTrailer().Set("bas", "blah") @@ -1598,7 +1598,7 @@ func TestStreamForServer(t *testing.T) { }) t.Run("client-stream", func(t *testing.T) { t.Parallel() - client, _ := newPingServer(&pluggablePingServer{ + client := newPingClient(&pluggablePingServer{ sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeClient) assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceSumProcedure) @@ -1619,7 +1619,7 @@ func TestStreamForServer(t *testing.T) { }) t.Run("client-stream-conn", func(t *testing.T) { t.Parallel() - client, _ := newPingServer(&pluggablePingServer{ + client := newPingClient(&pluggablePingServer{ sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { assert.NotNil(t, stream.Conn().Send("not-proto")) return connect.NewResponse(&pingv1.SumResponse{}), nil @@ -1633,7 +1633,7 @@ func TestStreamForServer(t *testing.T) { }) t.Run("client-stream-send-msg", func(t *testing.T) { t.Parallel() - client, _ := newPingServer(&pluggablePingServer{ + client := newPingClient(&pluggablePingServer{ sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { assert.Nil(t, stream.Conn().Send(&pingv1.SumResponse{Sum: 2})) return connect.NewResponse(&pingv1.SumResponse{}), nil diff --git a/internal/connecttest/http.go b/internal/connecttest/http.go index 6614756b..5df24fd8 100644 --- a/internal/connecttest/http.go +++ b/internal/connecttest/http.go @@ -137,7 +137,6 @@ func (l *MemoryListener) DialContext(ctx context.Context, network, addr string) Err: err, } } - accepted := make(chan net.Conn) select { case <-ctx.Done(): From 27410bdbb5aca07d6bceda32a0283101b7280be6 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 18 Sep 2023 23:28:46 +0100 Subject: [PATCH 04/21] Use local networking for benchmarks --- bench_test.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/bench_test.go b/bench_test.go index 632299ea..9ecc3914 100644 --- a/bench_test.go +++ b/bench_test.go @@ -21,12 +21,12 @@ import ( "encoding/json" "io" "net/http" + "net/http/httptest" "strings" "testing" connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" - "connectrpc.com/connect/internal/connecttest" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" ) @@ -38,7 +38,10 @@ func BenchmarkConnect(b *testing.B) { &ExamplePingServer{}, ), ) - server := connecttest.StartHTTP2TestServer(b, mux) + server := httptest.NewUnstartedServer(mux) + server.EnableHTTP2 = true + server.StartTLS() + b.Cleanup(server.Close) httpClient := server.Client() httpTransport, ok := httpClient.Transport.(*http.Transport) @@ -110,7 +113,10 @@ func BenchmarkREST(b *testing.B) { assert.Nil(b, err) } - server := connecttest.StartHTTP2TestServer(b, http.HandlerFunc(handler)) + server := httptest.NewUnstartedServer(http.HandlerFunc(handler)) + server.EnableHTTP2 = true + server.StartTLS() + b.Cleanup(server.Close) twoMiB := strings.Repeat("a", 2*1024*1024) b.ResetTimer() From 877b4df64da84b12cf0dee710364387dd41dfc41 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Fri, 6 Oct 2023 16:40:49 +0100 Subject: [PATCH 05/21] Create memhttp and memhttp test packages --- .golangci.yml | 9 + client_example_test.go | 2 +- client_ext_test.go | 10 +- client_get_fallback_test.go | 6 +- compression_test.go | 7 +- connect_ext_test.go | 218 +++++++++--------- example_init_test.go | 24 +- go.mod | 5 + go.sum | 4 + handler_ext_test.go | 27 +-- interceptor_example_test.go | 4 +- interceptor_ext_test.go | 15 +- .../http.go => memhttp/listener.go} | 77 +------ internal/memhttp/memhttp.go | 144 ++++++++++++ internal/memhttp/memhttp_test.go | 160 +++++++++++++ internal/memhttp/memhttptest/http.go | 56 +++++ internal/memhttp/option.go | 67 ++++++ recover_ext_test.go | 6 +- 18 files changed, 605 insertions(+), 236 deletions(-) rename internal/{connecttest/http.go => memhttp/listener.go} (52%) create mode 100644 internal/memhttp/memhttp.go create mode 100644 internal/memhttp/memhttp_test.go create mode 100644 internal/memhttp/memhttptest/http.go create mode 100644 internal/memhttp/option.go diff --git a/.golangci.yml b/.golangci.yml index 2598fb2b..50ff54e5 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -119,3 +119,12 @@ issues: - linters: [revive] text: "^if-return: " path: error_writer.go + # We want to set http.Server's logger + - linters: [forbidigo] + path: internal/memhttp + # We want to set http.Server's logger + - linters: [forbidigo] + path: internal/memhttp/memhttptest + # We want to show examples with http.Get + - linters: [noctx, bodyclose] + path: internal/memhttp/memhttp_test.go diff --git a/client_example_test.go b/client_example_test.go index d4e38f23..c85e8d44 100644 --- a/client_example_test.go +++ b/client_example_test.go @@ -36,7 +36,7 @@ func Example_client() { // connect.WithGRPCWeb() to switch protocols. client := pingv1connect.NewPingServiceClient( httpClient, - examplePingServer.URL, + examplePingServer.URL(), ) response, err := client.Ping( context.Background(), diff --git a/client_ext_test.go b/client_ext_test.go index a8566cd7..36908a1d 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -23,9 +23,9 @@ import ( connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" - "connectrpc.com/connect/internal/connecttest" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/memhttp/memhttptest" ) func TestNewClient_InitFailure(t *testing.T) { @@ -75,13 +75,13 @@ func TestClientPeer(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) run := func(t *testing.T, unaryHTTPMethod string, opts ...connect.ClientOption) { t.Helper() client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), connect.WithClientOptions(opts...), connect.WithInterceptors(&assertPeerInterceptor{t}), ) @@ -156,10 +156,10 @@ func TestGetNotModified(t *testing.T) { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(¬ModifiedPingServer{etag: etag})) - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), connect.WithHTTPGet(), ) ctx := context.Background() diff --git a/client_get_fallback_test.go b/client_get_fallback_test.go index 998353a9..c9444ef6 100644 --- a/client_get_fallback_test.go +++ b/client_get_fallback_test.go @@ -21,8 +21,8 @@ import ( "testing" "connectrpc.com/connect/internal/assert" - "connectrpc.com/connect/internal/connecttest" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" + "connectrpc.com/connect/internal/memhttp/memhttptest" ) func TestClientUnaryGetFallback(t *testing.T) { @@ -38,11 +38,11 @@ func TestClientUnaryGetFallback(t *testing.T) { }, WithIdempotency(IdempotencyNoSideEffects), )) - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) client := NewClient[pingv1.PingRequest, pingv1.PingResponse]( server.Client(), - server.URL+"/connect.ping.v1.PingService/Ping", + server.URL()+"/connect.ping.v1.PingService/Ping", WithHTTPGet(), WithHTTPGetMaxURLSize(1, true), WithSendGzip(), diff --git a/compression_test.go b/compression_test.go index dd77e2b5..0dc73466 100644 --- a/compression_test.go +++ b/compression_test.go @@ -20,7 +20,8 @@ import ( "testing" "connectrpc.com/connect/internal/assert" - "connectrpc.com/connect/internal/connecttest" + "connectrpc.com/connect/internal/memhttp" + "connectrpc.com/connect/internal/memhttp/memhttptest" "google.golang.org/protobuf/types/known/emptypb" ) @@ -42,10 +43,10 @@ func TestAcceptEncodingOrdering(t *testing.T) { w.WriteHeader(http.StatusOK) called = true }) - server := connecttest.StartHTTPTestServer(t, verify) + server := memhttptest.NewServer(t, verify, memhttp.WithoutHTTP2()) client := NewClient[emptypb.Empty, emptypb.Empty]( server.Client(), - server.URL, + server.URL(), withFakeBrotli, withGzip(), ) diff --git a/connect_ext_test.go b/connect_ext_test.go index e36b78c5..dcf9ebcd 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -26,7 +26,6 @@ import ( "math" "math/rand" "net/http" - "net/http/httptest" "strings" "sync" "testing" @@ -34,10 +33,11 @@ import ( connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" - "connectrpc.com/connect/internal/connecttest" "connectrpc.com/connect/internal/gen/connect/import/v1/importv1connect" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/memhttp" + "connectrpc.com/connect/internal/memhttp/memhttptest" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoregistry" ) @@ -352,10 +352,10 @@ func TestServer(t *testing.T) { assertIsHTTPMiddlewareError(t, stream.Err()) }) } - testMatrix := func(t *testing.T, server *httptest.Server, bidi bool) { //nolint:thelper + testMatrix := func(t *testing.T, server *memhttp.Server, bidi bool) { //nolint:thelper run := func(t *testing.T, opts ...connect.ClientOption) { t.Helper() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, opts...) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), opts...) testPing(t, client) testSum(t, client) testCountUp(t, client) @@ -436,12 +436,12 @@ func TestServer(t *testing.T) { t.Run("http1", func(t *testing.T) { t.Parallel() - server := connecttest.StartHTTPTestServer(t, mux) + server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) testMatrix(t, server, false /* bidi */) }) t.Run("http2", func(t *testing.T) { t.Parallel() - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) testMatrix(t, server, true /* bidi */) }) } @@ -453,14 +453,14 @@ func TestConcurrentStreams(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) var done, start sync.WaitGroup start.Add(1) for i := 0; i < 100; i++ { done.Add(1) go func() { defer done.Done() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) var total int64 sum := client.CumSum(context.Background()) start.Wait() @@ -511,9 +511,9 @@ func TestHeaderBasic(t *testing.T) { } mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) - server := connecttest.StartHTTPTestServer(t, mux) + server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) request := connect.NewRequest(&pingv1.PingRequest{}) request.Header().Set(key, cval) response, err := client.Ping(context.Background(), request) @@ -536,11 +536,11 @@ func TestHeaderHost(t *testing.T) { }, } - newHTTP2Server := func(t *testing.T) *httptest.Server { + newHTTP2Server := func(t *testing.T) *memhttp.Server { t.Helper() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) return server } @@ -557,21 +557,21 @@ func TestHeaderHost(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) callWithHost(t, client) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) callWithHost(t, client) }) t.Run("grpc-web", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) callWithHost(t, client) }) } @@ -591,11 +591,11 @@ func TestTimeoutParsing(t *testing.T) { } mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) - server := connecttest.StartHTTPTestServer(t, mux) + server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{})) assert.Nil(t, err) } @@ -603,10 +603,10 @@ func TestTimeoutParsing(t *testing.T) { func TestFailCodec(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - server := connecttest.StartHTTPTestServer(t, handler) + server := memhttptest.NewServer(t, handler, memhttp.WithoutHTTP2()) client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), connect.WithCodec(failCodec{}), ) stream := client.CumSum(context.Background()) @@ -620,10 +620,10 @@ func TestFailCodec(t *testing.T) { func TestContextError(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - server := connecttest.StartHTTPTestServer(t, handler) + server := memhttptest.NewServer(t, handler) client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), ) ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -644,11 +644,11 @@ func TestGRPCMarshalStatusError(t *testing.T) { pingServer{}, connect.WithCodec(failCodec{}), )) - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) assertInternalError := func(tb testing.TB, opts ...connect.ClientOption) { tb.Helper() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, opts...) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), opts...) request := connect.NewRequest(&pingv1.FailRequest{Code: int32(connect.CodeResourceExhausted)}) _, err := client.Fail(context.Background(), request) tb.Log(err) @@ -683,8 +683,8 @@ func TestGRPCMissingTrailersError(t *testing.T) { mux.Handle(pingv1connect.NewPingServiceHandler( pingServer{checkMetadata: true}, )) - server := connecttest.StartHTTP2TestServer(t, trimTrailers(mux)) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC()) + server := memhttptest.NewServer(t, trimTrailers(mux)) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) assertErrorNoTrailers := func(t *testing.T, err error) { t.Helper() @@ -766,10 +766,10 @@ func TestBidiRequiresHTTP2(t *testing.T) { _, err := io.WriteString(w, "hello world") assert.Nil(t, err) }) - server := connecttest.StartHTTPTestServer(t, handler) + server := memhttptest.NewServer(t, handler, memhttp.WithoutHTTP2()) client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), ) stream := client.CumSum(context.Background()) // Stream creates an async request, can error on Send or Receive. @@ -796,10 +796,10 @@ func TestCompressMinBytesClient(t *testing.T) { mux.Handle("/", http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { assert.Equal(tb, request.Header.Get("Content-Encoding"), expect) })) - server := connecttest.StartHTTPTestServer(t, mux) + server := memhttptest.NewServer(t, mux) _, err := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), connect.WithSendGzip(), connect.WithCompressMinBytes(8), ).Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Text: text})) @@ -831,7 +831,7 @@ func TestCompressMinBytes(t *testing.T) { pingServer{}, connect.WithCompressMinBytes(8), )) - server := connecttest.StartHTTPTestServer(t, mux) + server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) client := server.Client() getPingResponse := func(t *testing.T, pingText string) *http.Response { @@ -842,7 +842,7 @@ func TestCompressMinBytes(t *testing.T) { req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+"/"+pingv1connect.PingServiceName+"/Ping", + server.URL()+"/"+pingv1connect.PingServiceName+"/Ping", bytes.NewReader(requestBytes), ) assert.Nil(t, err) @@ -885,9 +885,9 @@ func TestCustomCompression(t *testing.T) { pingServer{}, connect.WithCompression(compressionName, decompressor, compressor), )) - server := connecttest.StartHTTPTestServer(t, mux) + server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) client := pingv1connect.NewPingServiceClient(server.Client(), - server.URL, + server.URL(), connect.WithAcceptCompression(compressionName, decompressor, compressor), connect.WithSendCompression(compressionName), ) @@ -904,9 +904,9 @@ func TestClientWithoutGzipSupport(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := connecttest.StartHTTPTestServer(t, mux) + server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) client := pingv1connect.NewPingServiceClient(server.Client(), - server.URL, + server.URL(), connect.WithAcceptCompression("gzip", nil, nil), connect.WithSendGzip(), ) @@ -921,13 +921,13 @@ func TestInvalidHeaderTimeout(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := connecttest.StartHTTPTestServer(t, mux) + server := memhttptest.NewServer(t, mux) getPingResponseWithTimeout := func(t *testing.T, timeout string) *http.Response { t.Helper() request, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+"/"+pingv1connect.PingServiceName+"/Ping", + server.URL()+"/"+pingv1connect.PingServiceName+"/Ping", strings.NewReader("{}"), ) assert.Nil(t, err) @@ -954,8 +954,8 @@ func TestInterceptorReturnsWrongType(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := connecttest.StartHTTPTestServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithInterceptors(connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { + server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithInterceptors(connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) { if _, err := next(ctx, request); err != nil { return nil, err @@ -1030,45 +1030,45 @@ func TestHandlerWithReadMaxBytes(t *testing.T) { assert.Equal(t, err.Error(), fmt.Sprintf("resource_exhausted: message size %d is larger than configured max %d", expectedSize, readMaxBytes)) }) } - newHTTP2Server := func(t *testing.T) *httptest.Server { + newHTTP2Server := func(t *testing.T) *memhttp.Server { t.Helper() - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) return server } t.Run("connect", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) readMaxBytesMatrix(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) readMaxBytesMatrix(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) readMaxBytesMatrix(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) readMaxBytesMatrix(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) readMaxBytesMatrix(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) readMaxBytesMatrix(t, client, true) }) } @@ -1118,45 +1118,45 @@ func TestHandlerWithHTTPMaxBytes(t *testing.T) { } t.Run("connect", func(t *testing.T) { t.Parallel() - server := connecttest.StartHTTP2TestServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) run(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() - server := connecttest.StartHTTP2TestServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendGzip()) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendGzip()) run(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() - server := connecttest.StartHTTP2TestServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC()) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) run(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() - server := connecttest.StartHTTP2TestServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC(), connect.WithSendGzip()) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC(), connect.WithSendGzip()) run(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() - server := connecttest.StartHTTP2TestServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb()) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) run(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() - server := connecttest.StartHTTP2TestServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb(), connect.WithSendGzip()) + server := memhttptest.NewServer(t, mux) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb(), connect.WithSendGzip()) run(t, client, true) }) } func TestClientWithReadMaxBytes(t *testing.T) { t.Parallel() - createServer := func(tb testing.TB, enableCompression bool) *httptest.Server { + createServer := func(tb testing.TB, enableCompression bool) *memhttp.Server { tb.Helper() mux := http.NewServeMux() var compressionOption connect.HandlerOption @@ -1166,7 +1166,7 @@ func TestClientWithReadMaxBytes(t *testing.T) { compressionOption = connect.WithCompressMinBytes(math.MaxInt) } mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, compressionOption)) - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) return server } serverUncompressed := createServer(t, false) @@ -1214,32 +1214,32 @@ func TestClientWithReadMaxBytes(t *testing.T) { } t.Run("connect", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL, connect.WithReadMaxBytes(readMaxBytes)) + client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) readMaxBytesMatrix(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL, connect.WithReadMaxBytes(readMaxBytes)) + client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes)) readMaxBytesMatrix(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL, connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) readMaxBytesMatrix(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL, connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPC()) readMaxBytesMatrix(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL, connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(serverUncompressed.Client(), serverUncompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) readMaxBytesMatrix(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() - client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL, connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(serverCompressed.Client(), serverCompressed.URL(), connect.WithReadMaxBytes(readMaxBytes), connect.WithGRPCWeb()) readMaxBytesMatrix(t, client, true) }) } @@ -1298,7 +1298,7 @@ func TestHandlerWithSendMaxBytes(t *testing.T) { } }) } - newHTTP2Server := func(t *testing.T, compressed bool, sendMaxBytes int) *httptest.Server { + newHTTP2Server := func(t *testing.T, compressed bool, sendMaxBytes int) *memhttp.Server { t.Helper() mux := http.NewServeMux() options := []connect.HandlerOption{connect.WithSendMaxBytes(sendMaxBytes)} @@ -1311,43 +1311,43 @@ func TestHandlerWithSendMaxBytes(t *testing.T) { pingServer{}, options..., )) - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) return server } t.Run("connect", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, false, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) sendMaxBytesMatrix(t, client, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, true, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) sendMaxBytesMatrix(t, client, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, false, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) sendMaxBytesMatrix(t, client, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, true, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPC()) sendMaxBytesMatrix(t, client, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, false, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) sendMaxBytesMatrix(t, client, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() server := newHTTP2Server(t, true, sendMaxBytes) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) sendMaxBytesMatrix(t, client, true) }) } @@ -1356,7 +1356,7 @@ func TestClientWithSendMaxBytes(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) sendMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, sendMaxBytes int, compressed bool) { t.Helper() t.Run("equal_send_max", func(t *testing.T) { @@ -1408,37 +1408,37 @@ func TestClientWithSendMaxBytes(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendMaxBytes(sendMaxBytes)) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes)) sendMaxBytesMatrix(t, client, sendMaxBytes, false) }) t.Run("connect_gzip", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendMaxBytes(sendMaxBytes), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithSendGzip()) sendMaxBytesMatrix(t, client, sendMaxBytes, true) }) t.Run("grpc", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC()) sendMaxBytesMatrix(t, client, sendMaxBytes, false) }) t.Run("grpc_gzip", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPC(), connect.WithSendGzip()) sendMaxBytesMatrix(t, client, sendMaxBytes, true) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb()) sendMaxBytesMatrix(t, client, sendMaxBytes, false) }) t.Run("grpcweb_gzip", func(t *testing.T) { t.Parallel() sendMaxBytes := 1024 - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb(), connect.WithSendGzip()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithSendMaxBytes(sendMaxBytes), connect.WithGRPCWeb(), connect.WithSendGzip()) sendMaxBytesMatrix(t, client, sendMaxBytes, true) }) } @@ -1456,10 +1456,10 @@ func TestBidiStreamServerSendsFirstMessage(t *testing.T) { } mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), connect.WithClientOptions(opts...), connect.WithInterceptors(&assertPeerInterceptor{t}), ) @@ -1494,10 +1494,10 @@ func TestStreamForServer(t *testing.T) { newPingClient := func(pingServer pingv1connect.PingServiceHandler) pingv1connect.PingServiceClient { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), ) return client } @@ -1659,11 +1659,11 @@ func TestConnectHTTPErrorCodes(t *testing.T) { }, } mux.Handle(pingv1connect.NewPingServiceHandler(pluggableServer)) - server := connecttest.StartHTTPTestServer(t, mux) + server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+"/"+pingv1connect.PingServiceName+"/Ping", + server.URL()+"/"+pingv1connect.PingServiceName+"/Ping", strings.NewReader("{}"), ) assert.Nil(t, err) @@ -1672,7 +1672,7 @@ func TestConnectHTTPErrorCodes(t *testing.T) { assert.Nil(t, err) defer resp.Body.Close() assert.Equal(t, wantHttpStatus, resp.StatusCode) - connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) connectResp, err := connectClient.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.NotNil(t, err) assert.Nil(t, connectResp) @@ -1763,10 +1763,10 @@ func TestFailCompression(t *testing.T) { connect.WithCompression(compressorName, decompressor, compressor), ), ) - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) pingclient := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), connect.WithAcceptCompression(compressorName, decompressor, compressor), connect.WithSendCompression(compressorName), ) @@ -1798,7 +1798,7 @@ func TestUnflushableResponseWriter(t *testing.T) { handler.ServeHTTP(&unflushableWriter{w}, r) }) mux.Handle(path, wrapped) - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) tests := []struct { name string @@ -1812,7 +1812,7 @@ func TestUnflushableResponseWriter(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - pingclient := pingv1connect.NewPingServiceClient(server.Client(), server.URL, tt.options...) + pingclient := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), tt.options...) stream, err := pingclient.CountUp( context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 5}), @@ -1831,7 +1831,7 @@ func TestGRPCErrorMetadataIsTrailersOnly(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) protoBytes, err := proto.Marshal(&pingv1.FailRequest{Code: int32(connect.CodeInternal)}) assert.Nil(t, err) @@ -1844,7 +1844,7 @@ func TestGRPCErrorMetadataIsTrailersOnly(t *testing.T) { req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+pingv1connect.PingServiceFailProcedure, + server.URL()+pingv1connect.PingServiceFailProcedure, bytes.NewReader(body), ) assert.Nil(t, err) @@ -1868,9 +1868,9 @@ func TestConnectProtocolHeaderSentByDefault(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, connect.WithRequireConnectProtocolHeader())) - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.Nil(t, err) @@ -1889,7 +1889,7 @@ func TestConnectProtocolHeaderRequired(t *testing.T) { pingServer{}, connect.WithRequireConnectProtocolHeader(), )) - server := connecttest.StartHTTPTestServer(t, mux) + server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) tests := []struct { headers http.Header @@ -1901,7 +1901,7 @@ func TestConnectProtocolHeaderRequired(t *testing.T) { req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+"/"+pingv1connect.PingServiceName+"/Ping", + server.URL()+"/"+pingv1connect.PingServiceName+"/Ping", strings.NewReader("{}"), ) assert.Nil(t, err) @@ -1928,7 +1928,7 @@ func TestAllowCustomUserAgent(t *testing.T) { return connect.NewResponse(&pingv1.PingResponse{Number: req.Msg.Number}), nil }, })) - server := connecttest.StartHTTPTestServer(t, mux) + server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) // If the user has set a User-Agent, we shouldn't clobber it. tests := []struct { @@ -1940,7 +1940,7 @@ func TestAllowCustomUserAgent(t *testing.T) { {"grpcweb", []connect.ClientOption{connect.WithGRPCWeb()}}, } for _, testCase := range tests { - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, testCase.opts...) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), testCase.opts...) req := connect.NewRequest(&pingv1.PingRequest{Number: 42}) req.Header().Set("User-Agent", customAgent) _, err := client.Ping(context.Background(), req) @@ -1964,9 +1964,9 @@ func TestWebXUserAgent(t *testing.T) { return connect.NewResponse(&pingv1.PingResponse{Number: req.Msg.Number}), nil }, })) - server := connecttest.StartHTTPTestServer(t, mux) + server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithGRPCWeb()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) req := connect.NewRequest(&pingv1.PingRequest{Number: 42}) _, err := client.Ping(context.Background(), req) assert.Nil(t, err) @@ -1976,12 +1976,12 @@ func TestBidiOverHTTP1(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := connecttest.StartHTTPTestServer(t, mux) + server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) // Clients expecting a full-duplex connection that end up with a simplex // HTTP/1.1 connection shouldn't hang. Instead, the server should close the // TCP connection. - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) stream := client.CumSum(context.Background()) if err := stream.Send(&pingv1.CumSumRequest{Number: 2}); err != nil { assert.ErrorIs(t, err, io.EOF) @@ -2021,8 +2021,8 @@ func TestHandlerReturnsNilResponse(t *testing.T) { return nil, nil //nolint: nilnil }, }, connect.WithRecover(recoverPanic))) - server := connecttest.StartHTTPTestServer(t, mux) - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL) + server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.NotNil(t, err) @@ -2051,7 +2051,7 @@ func TestStreamUnexpectedEOF(t *testing.T) { _, _ = io.Copy(io.Discard, request.Body) testcase(responseWriter, request) }) - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) head := [5]byte{} payload := []byte(`{"number": 42}`) @@ -2231,7 +2231,7 @@ func TestStreamUnexpectedEOF(t *testing.T) { t.Parallel() client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), testcase.options..., ) const upTo = 2 diff --git a/example_init_test.go b/example_init_test.go index 3747ff06..9edf5114 100644 --- a/example_init_test.go +++ b/example_init_test.go @@ -16,13 +16,12 @@ package connect_test import ( "net/http" - "net/http/httptest" - "connectrpc.com/connect/internal/connecttest" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/memhttp" ) -var examplePingServer *httptest.Server +var examplePingServer *memhttp.Server func init() { // Generally, init functions are bad. However, we need to set up the server @@ -34,22 +33,5 @@ func init() { // (https://github.com/golang/go/issues/48394) mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - examplePingServer = newInMemoryServer(mux) -} - -// newInMemoryServer constructs and starts an inMemoryServer. -func newInMemoryServer(handler http.Handler) *httptest.Server { - lis := connecttest.NewMemoryListener() - server := httptest.NewUnstartedServer(handler) - server.Listener = lis - server.EnableHTTP2 = true - server.StartTLS() - client := server.Client() - // Configure the httptest.Server client to use the in-memory listener. - // Automatic HTTP-level gzip compression is disabled. - if transport, ok := client.Transport.(*http.Transport); ok { - transport.DialContext = lis.DialContext - transport.DisableCompression = true - } - return server + examplePingServer = memhttp.NewServer(mux) } diff --git a/go.mod b/go.mod index 0bb3ca93..a1da35ba 100644 --- a/go.mod +++ b/go.mod @@ -11,3 +11,8 @@ require ( github.com/google/go-cmp v0.5.9 google.golang.org/protobuf v1.31.0 ) + +require ( + golang.org/x/net v0.16.0 // indirect + golang.org/x/text v0.13.0 // indirect +) diff --git a/go.sum b/go.sum index 4d0bc04e..8d2bee48 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,10 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +golang.org/x/net v0.16.0 h1:7eBu7KsSvFDtSXUIDbh3aqlK4DPsZ1rByC8PFfBThos= +golang.org/x/net v0.16.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= diff --git a/handler_ext_test.go b/handler_ext_test.go index 2afed6be..d7cf7517 100644 --- a/handler_ext_test.go +++ b/handler_ext_test.go @@ -27,9 +27,10 @@ import ( connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" - "connectrpc.com/connect/internal/connecttest" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/memhttp" + "connectrpc.com/connect/internal/memhttp/memhttptest" ) func TestHandler_ServeHTTP(t *testing.T) { @@ -42,7 +43,7 @@ func TestHandler_ServeHTTP(t *testing.T) { mux.Handle("/prefixed/", http.StripPrefix("/prefixed", prefixed)) const pingProcedure = pingv1connect.PingServicePingProcedure const sumProcedure = pingv1connect.PingServiceSumProcedure - server := connecttest.StartHTTPTestServer(t, mux) + server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) client := server.Client() t.Run("get_method_no_encoding", func(t *testing.T) { @@ -50,7 +51,7 @@ func TestHandler_ServeHTTP(t *testing.T) { request, err := http.NewRequestWithContext( context.Background(), http.MethodGet, - server.URL+pingProcedure, + server.URL()+pingProcedure, strings.NewReader(""), ) assert.Nil(t, err) @@ -65,7 +66,7 @@ func TestHandler_ServeHTTP(t *testing.T) { request, err := http.NewRequestWithContext( context.Background(), http.MethodGet, - server.URL+pingProcedure+`?encoding=unk&message={}`, + server.URL()+pingProcedure+`?encoding=unk&message={}`, strings.NewReader(""), ) assert.Nil(t, err) @@ -80,7 +81,7 @@ func TestHandler_ServeHTTP(t *testing.T) { request, err := http.NewRequestWithContext( context.Background(), http.MethodGet, - server.URL+pingProcedure+`?encoding=json&message={}`, + server.URL()+pingProcedure+`?encoding=json&message={}`, strings.NewReader(""), ) assert.Nil(t, err) @@ -95,7 +96,7 @@ func TestHandler_ServeHTTP(t *testing.T) { request, err := http.NewRequestWithContext( context.Background(), http.MethodGet, - server.URL+"/prefixed"+pingProcedure+`?encoding=json&message={}`, + server.URL()+"/prefixed"+pingProcedure+`?encoding=json&message={}`, strings.NewReader(""), ) assert.Nil(t, err) @@ -110,7 +111,7 @@ func TestHandler_ServeHTTP(t *testing.T) { request, err := http.NewRequestWithContext( context.Background(), http.MethodGet, - server.URL+sumProcedure, + server.URL()+sumProcedure, strings.NewReader(""), ) assert.Nil(t, err) @@ -126,7 +127,7 @@ func TestHandler_ServeHTTP(t *testing.T) { request, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+pingProcedure, + server.URL()+pingProcedure, strings.NewReader("{}"), ) assert.Nil(t, err) @@ -155,7 +156,7 @@ func TestHandler_ServeHTTP(t *testing.T) { req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+pingProcedure, + server.URL()+pingProcedure, strings.NewReader("{}"), ) assert.Nil(t, err) @@ -171,7 +172,7 @@ func TestHandler_ServeHTTP(t *testing.T) { req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+pingProcedure, + server.URL()+pingProcedure, strings.NewReader("{}"), ) assert.Nil(t, err) @@ -187,7 +188,7 @@ func TestHandler_ServeHTTP(t *testing.T) { req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+pingProcedure, + server.URL()+pingProcedure, strings.NewReader("{}"), ) assert.Nil(t, err) @@ -214,7 +215,7 @@ func TestHandlerMaliciousPrefix(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(successPingServer{})) - server := connecttest.StartHTTPTestServer(t, mux) + server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) const ( concurrency = 256 @@ -230,7 +231,7 @@ func TestHandlerMaliciousPrefix(t *testing.T) { req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, - server.URL+pingv1connect.PingServicePingProcedure, + server.URL()+pingv1connect.PingServicePingProcedure, bytes.NewReader(body), ) assert.Nil(t, err) diff --git a/interceptor_example_test.go b/interceptor_example_test.go index 4c4742b0..7d7a2542 100644 --- a/interceptor_example_test.go +++ b/interceptor_example_test.go @@ -43,7 +43,7 @@ func ExampleUnaryInterceptorFunc() { ) client := pingv1connect.NewPingServiceClient( examplePingServer.Client(), - examplePingServer.URL, + examplePingServer.URL(), connect.WithInterceptors(loggingInterceptor), ) if _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 42})); err != nil { @@ -81,7 +81,7 @@ func ExampleWithInterceptors() { ) client := pingv1connect.NewPingServiceClient( examplePingServer.Client(), - examplePingServer.URL, + examplePingServer.URL(), connect.WithInterceptors(outer, inner), ) if _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})); err != nil { diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index ff4ea961..8025ac45 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -23,9 +23,10 @@ import ( connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" - "connectrpc.com/connect/internal/connecttest" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/memhttp" + "connectrpc.com/connect/internal/memhttp/memhttptest" ) func TestOnionOrderingEndToEnd(t *testing.T) { @@ -127,10 +128,10 @@ func TestOnionOrderingEndToEnd(t *testing.T) { handlerOnion, ), ) - server := connecttest.StartHTTPTestServer(t, mux) + server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), clientOnion, ) @@ -172,8 +173,8 @@ func TestEmptyUnaryInterceptorFunc(t *testing.T) { } }) mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, connect.WithInterceptors(interceptor))) - server := connecttest.StartHTTPTestServer(t, mux) - connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL, connect.WithInterceptors(interceptor)) + server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithInterceptors(interceptor)) _, err := connectClient.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.Nil(t, err) sumStream := connectClient.Sum(context.Background()) @@ -201,10 +202,10 @@ func TestInterceptorFuncAccessingHTTPMethod(t *testing.T) { connect.WithInterceptors(handlerChecker), ), ) - server := connecttest.StartHTTPTestServer(t, mux) + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), connect.WithInterceptors(clientChecker), ) diff --git a/internal/connecttest/http.go b/internal/memhttp/listener.go similarity index 52% rename from internal/connecttest/http.go rename to internal/memhttp/listener.go index 5df24fd8..773dd545 100644 --- a/internal/connecttest/http.go +++ b/internal/memhttp/listener.go @@ -12,85 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -package connecttest +package memhttp import ( "context" "errors" - "io" - "log" "net" - "net/http" - "net/http/httptest" "sync" - "testing" ) -// StartHTTPTestServer starts an HTTP server that listens on a in memory -// network. The returned server is configured to use the in memory network. -func StartHTTPTestServer(tb testing.TB, handler http.Handler) *httptest.Server { - tb.Helper() - lis := NewMemoryListener() - svr := httptest.NewUnstartedServer(handler) - svr.Config.ErrorLog = log.New(NewTestWriter(tb), "", 0) //nolint:forbidigo - svr.Listener = lis - svr.Start() - tb.Cleanup(svr.Close) - client := svr.Client() - if transport, ok := client.Transport.(*http.Transport); ok { - transport.DialContext = lis.DialContext - } else { - tb.Fatalf("unexpected transport type: %T", client.Transport) - } - return svr -} - -// StartHTTP2TestServer starts an HTTP/2 server that listens on a in memory -// network. The returned server is configured to use the in memory network and -// TLS. -func StartHTTP2TestServer(tb testing.TB, handler http.Handler) *httptest.Server { - tb.Helper() - lis := NewMemoryListener() - svr := httptest.NewUnstartedServer(handler) - svr.Config.ErrorLog = log.New(NewTestWriter(tb), "", 0) //nolint:forbidigo - svr.Listener = lis - svr.EnableHTTP2 = true - svr.StartTLS() - tb.Cleanup(svr.Close) - client := svr.Client() - if transport, ok := client.Transport.(*http.Transport); ok { - transport.DialContext = lis.DialContext - } else { - tb.Fatalf("unexpected transport type: %T", client.Transport) - } - return svr -} - -type testWriter struct { - tb testing.TB -} - -func (l *testWriter) Write(p []byte) (int, error) { - l.tb.Log(string(p)) - return len(p), nil -} - -// NewTestWriter returns a writer that logs to the given testing.TB. -func NewTestWriter(tb testing.TB) io.Writer { - tb.Helper() - return &testWriter{tb} -} - // MemoryListener is a net.Listener that listens on an in memory network. type MemoryListener struct { - conns chan chan net.Conn + conns chan net.Conn once sync.Once closed chan struct{} } +// NewMemoryListener returns a new in-memory listener. func NewMemoryListener() *MemoryListener { return &MemoryListener{ - conns: make(chan chan net.Conn), + conns: make(chan net.Conn), closed: make(chan struct{}), } } @@ -108,9 +49,7 @@ func (l *MemoryListener) Accept() (net.Conn, error) { select { case <-l.closed: return nil, aerr(errors.New("listener closed")) - case accept := <-l.conns: - server, client := net.Pipe() - accept <- client + case server := <-l.conns: return server, nil } } @@ -137,12 +76,12 @@ func (l *MemoryListener) DialContext(ctx context.Context, network, addr string) Err: err, } } - accepted := make(chan net.Conn) + server, client := net.Pipe() select { case <-ctx.Done(): return nil, derr(ctx.Err()) - case l.conns <- accepted: - return <-accepted, nil + case l.conns <- server: + return client, nil case <-l.closed: return nil, derr(errors.New("listener closed")) } diff --git a/internal/memhttp/memhttp.go b/internal/memhttp/memhttp.go new file mode 100644 index 00000000..2dd6b570 --- /dev/null +++ b/internal/memhttp/memhttp.go @@ -0,0 +1,144 @@ +// Copyright 2021-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memhttp + +import ( + "context" + "crypto/tls" + "errors" + "net" + "net/http" + "sync" + "time" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +// Server is a net/http server that uses in-memory pipes instead of TCP. By +// default, it supports http/2 via h2c. It otherwise uses the same configuration +// as the zero value of [http.Server]. +type Server struct { + server http.Server + listener *MemoryListener + url string + cleanupTimeout time.Duration + disableHTTP2 bool + + serverWG sync.WaitGroup + serverErr error +} + +// NewServer creates a new Server that uses the given handler. Configuration +// options may be provided via [Option]s. +func NewServer(handler http.Handler, opts ...Option) *Server { + var cfg config + WithCleanupTimeout(5 * time.Second).apply(&cfg) + for _, opt := range opts { + opt.apply(&cfg) + } + + if !cfg.DisableHTTP2 { + h2s := &http2.Server{} + handler = h2c.NewHandler(handler, h2s) + } + listener := NewMemoryListener() + server := &Server{ + server: http.Server{ + Handler: handler, + ReadHeaderTimeout: 5 * time.Second, + }, + listener: listener, + cleanupTimeout: cfg.CleanupTimeout, + url: "http://" + listener.Addr().String(), + disableHTTP2: cfg.DisableHTTP2, + } + server.goServe() + return server +} + +// Transport returns an [http.Transport] configured to use in-memory pipes +// rather than TCP, disable automatic compression, trust the server's TLS +// certificate (if any), and use HTTP/2 (if the server supports it). +// +// Callers may reconfigure the returned Transport without affecting other +// transports or clients. +func (s *Server) Transport() http.RoundTripper { + if s.disableHTTP2 { + return &http.Transport{ + DialContext: s.listener.DialContext, + } + } + return &http2.Transport{ + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + return s.listener.DialContext(ctx, network, addr) + }, + AllowHTTP: true, + } +} + +// Client returns an [http.Client] configured to use in-memory pipes rather +// than TCP, disable automatic compression, trust the server's TLS certificate +// (if any), and use HTTP/2 (if the server supports it). +// +// Callers may reconfigure the returned client without affecting other clients. +func (s *Server) Client() *http.Client { + return &http.Client{Transport: s.Transport()} +} + +// URL returns the server's URL. +func (s *Server) URL() string { + return s.url +} + +// Shutdown gracefully shuts down the server, without interrupting any active +// connections. See [http.Server.Shutdown] for details. +func (s *Server) Shutdown(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, s.cleanupTimeout) + defer cancel() + if err := s.server.Shutdown(ctx); err != nil { + return err + } + return s.Wait() +} + +// Close closes the server's listener. It does not wait for connections to +// finish. +func (s *Server) Close() error { + return s.server.Close() +} + +// RegisterOnShutdown registers a function to call on Shutdown. See +// [http.Server.RegisterOnShutdown] for details. +func (s *Server) RegisterOnShutdown(f func()) { + s.server.RegisterOnShutdown(f) +} + +// Wait blocks until the server exits, then returns its error. +func (s *Server) Wait() error { + s.serverWG.Wait() + if !errors.Is(s.serverErr, http.ErrServerClosed) { + return s.serverErr + } + return nil +} + +func (s *Server) goServe() { + s.serverWG.Add(1) + go func() { + defer s.serverWG.Done() + s.serverErr = s.server.Serve(s.listener) + }() +} diff --git a/internal/memhttp/memhttp_test.go b/internal/memhttp/memhttp_test.go new file mode 100644 index 00000000..07e2b675 --- /dev/null +++ b/internal/memhttp/memhttp_test.go @@ -0,0 +1,160 @@ +// Copyright 2021-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memhttp_test + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "sync" + "testing" + "time" + + "connectrpc.com/connect/internal/assert" + "connectrpc.com/connect/internal/memhttp" + "connectrpc.com/connect/internal/memhttp/memhttptest" +) + +func TestServer(t *testing.T) { + t.Parallel() + tests := []struct { + name string + opts []memhttp.Option + handler func(t *testing.T, w http.ResponseWriter, r *http.Request) + }{{ + name: "http2", + opts: nil, + handler: func(t *testing.T, _ http.ResponseWriter, r *http.Request) { + t.Helper() + assert.Equal(t, r.ProtoMajor, 2) + assert.Equal(t, r.ProtoMinor, 0) + }, + }, { + name: "http1", + opts: []memhttp.Option{memhttp.WithoutHTTP2()}, + handler: func(t *testing.T, _ http.ResponseWriter, r *http.Request) { + t.Helper() + assert.Equal(t, r.ProtoMajor, 1) + assert.Equal(t, r.ProtoMinor, 1) + }, + }} + for _, testcase := range tests { + testcase := testcase + t.Run(testcase.name, func(t *testing.T) { + t.Parallel() + const concurrency = 100 + const greeting = "Hello, world!" + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + testcase.handler(t, w, r) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(greeting)) + }) + server := memhttptest.NewServer(t, handler, testcase.opts...) + var wg sync.WaitGroup + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + client := server.Client() + req, err := http.NewRequestWithContext( + context.Background(), + http.MethodGet, + server.URL(), + strings.NewReader(""), + ) + assert.Nil(t, err) + res, err := client.Do(req) + assert.Nil(t, err) + assert.Equal(t, res.StatusCode, http.StatusOK) + body, err := io.ReadAll(res.Body) + assert.Nil(t, err) + assert.Nil(t, res.Body.Close()) + assert.Equal(t, string(body), greeting) + }() + } + wg.Wait() + }) + } +} + +func TestRegisterOnShutdown(t *testing.T) { + t.Parallel() + okay := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + server := memhttp.NewServer(okay) + done := make(chan struct{}) + server.RegisterOnShutdown(func() { + close(done) + }) + assert.Nil(t, server.Shutdown(context.Background())) + select { + case <-done: + case <-time.After(5 * time.Second): + t.Error("OnShutdown hook didn't fire") + } +} + +func Example() { + hello := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "Hello, world!") + }) + srv := memhttp.NewServer(hello) + defer srv.Close() + res, err := srv.Client().Get(srv.URL()) + if err != nil { + panic(err) + } + fmt.Println(res.Status) + // Output: + // 200 OK +} + +func ExampleServer_Client() { + hello := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "Hello, world!") + }) + srv := memhttp.NewServer(hello) + defer srv.Close() + client := srv.Client() + client.Timeout = 10 * time.Second + res, err := client.Get(srv.URL()) + if err != nil { + panic(err) + } + fmt.Println(res.Status) + // Output: + // 200 OK +} + +func ExampleServer_Shutdown() { + hello := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "Hello, world!") + }) + srv := memhttp.NewServer(hello) + srv.RegisterOnShutdown(func() { + fmt.Println("Server is shutting down") + }) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := srv.Shutdown(ctx); err != nil { + panic(err) + } + // Output: + // Server is shutting down +} diff --git a/internal/memhttp/memhttptest/http.go b/internal/memhttp/memhttptest/http.go new file mode 100644 index 00000000..c0c68966 --- /dev/null +++ b/internal/memhttp/memhttptest/http.go @@ -0,0 +1,56 @@ +// Copyright 2021-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memhttptest + +import ( + "context" + "log" + "net/http" + "testing" + + "connectrpc.com/connect/internal/memhttp" +) + +// NewServer constructs a [memhttp.Server] with defaults suitable for tests: +// it logs runtime errors to the provided testing.TB, and it automatically shuts +// down the server when the test completes. Startup and shutdown errors fail the +// test. +// +// To customize the server, use any [memhttp.Option]. In particular, it may be +// necessary to customize the shutdown timeout with +// [memhttp.WithCleanupTimeout]. +func NewServer(tb testing.TB, handler http.Handler, opts ...memhttp.Option) *memhttp.Server { + tb.Helper() + logger := log.New(&testWriter{tb}, "" /* prefix */, log.Lshortfile) + opts = append(opts, memhttp.WithErrorLog(logger)) + server := memhttp.NewServer(handler, opts...) + tb.Cleanup(func() { + tb.Logf("shutting down server") + if err := server.Shutdown(context.Background()); err != nil { + tb.Error(err) + } + }) + return server +} + +// testWriter is an io.Writer that logs to the testing.TB. +type testWriter struct { + tb testing.TB +} + +func (l *testWriter) Write(p []byte) (int, error) { + l.tb.Log(string(p)) + return len(p), nil +} diff --git a/internal/memhttp/option.go b/internal/memhttp/option.go new file mode 100644 index 00000000..1469943a --- /dev/null +++ b/internal/memhttp/option.go @@ -0,0 +1,67 @@ +// Copyright 2021-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memhttp + +import ( + "log" + "time" +) + +type config struct { + DisableTLS bool + DisableHTTP2 bool + CleanupTimeout time.Duration + ErrorLog *log.Logger +} + +// An Option configures a Server. +type Option interface { + apply(*config) +} + +type optionFunc func(*config) + +func (f optionFunc) apply(cfg *config) { f(cfg) } + +// WithoutHTTP2 disables HTTP/2 on the server and client. +func WithoutHTTP2() Option { + return optionFunc(func(cfg *config) { + cfg.DisableHTTP2 = true + }) +} + +// WithOptions composes multiple Options into one. +func WithOptions(opts ...Option) Option { + return optionFunc(func(cfg *config) { + for _, opt := range opts { + opt.apply(cfg) + } + }) +} + +// WithCleanupTimeout customizes the default five-second timeout for the +// server's Cleanup method. It's most useful with the memhttptest subpackage. +func WithCleanupTimeout(d time.Duration) Option { + return optionFunc(func(cfg *config) { + cfg.CleanupTimeout = d + }) +} + +// WithErrorLog sets [http.Server.ErrorLog]. +func WithErrorLog(l *log.Logger) Option { + return optionFunc(func(cfg *config) { + cfg.ErrorLog = l + }) +} diff --git a/recover_ext_test.go b/recover_ext_test.go index 95ec36a2..a385b1d1 100644 --- a/recover_ext_test.go +++ b/recover_ext_test.go @@ -22,9 +22,9 @@ import ( connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" - "connectrpc.com/connect/internal/connecttest" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" + "connectrpc.com/connect/internal/memhttp/memhttptest" ) type panicPingServer struct { @@ -77,10 +77,10 @@ func TestWithRecover(t *testing.T) { pinger := &panicPingServer{} mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pinger, connect.WithRecover(handle))) - server := connecttest.StartHTTP2TestServer(t, mux) + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient( server.Client(), - server.URL, + server.URL(), ) for _, panicWith := range []any{42, nil} { From 932533fe9cd478f45d0cb800d02cf0aa44f05342 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 10 Oct 2023 11:53:47 +0100 Subject: [PATCH 06/21] Fix race on RegisterShutdown --- internal/memhttp/memhttp_test.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/internal/memhttp/memhttp_test.go b/internal/memhttp/memhttp_test.go index 07e2b675..e46df3de 100644 --- a/internal/memhttp/memhttp_test.go +++ b/internal/memhttp/memhttp_test.go @@ -147,14 +147,12 @@ func ExampleServer_Shutdown() { _, _ = io.WriteString(w, "Hello, world!") }) srv := memhttp.NewServer(hello) - srv.RegisterOnShutdown(func() { - fmt.Println("Server is shutting down") - }) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := srv.Shutdown(ctx); err != nil { panic(err) } + fmt.Println("Server has shut down") // Output: - // Server is shutting down + // Server has shut down } From c057a629907a18de4aa97fff5c6c4b4b4d429b8a Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 10 Oct 2023 11:59:12 +0100 Subject: [PATCH 07/21] Fix transport desc --- internal/memhttp/memhttp.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/internal/memhttp/memhttp.go b/internal/memhttp/memhttp.go index 2dd6b570..0de60ee6 100644 --- a/internal/memhttp/memhttp.go +++ b/internal/memhttp/memhttp.go @@ -69,12 +69,8 @@ func NewServer(handler http.Handler, opts ...Option) *Server { return server } -// Transport returns an [http.Transport] configured to use in-memory pipes -// rather than TCP, disable automatic compression, trust the server's TLS -// certificate (if any), and use HTTP/2 (if the server supports it). -// -// Callers may reconfigure the returned Transport without affecting other -// transports or clients. +// Transport returns a [http.RoundTripper] configured to use in-memory pipes +// rather than TCP and talk HTTP/2 (if the server supports it). func (s *Server) Transport() http.RoundTripper { if s.disableHTTP2 { return &http.Transport{ From f6a2bd7edbd157f53ad44bb2bd18ee8271cc3e95 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Wed, 11 Oct 2023 13:09:38 +0100 Subject: [PATCH 08/21] Fix feedback --- .golangci.yml | 5 +- compression_test.go | 3 +- connect_ext_test.go | 47 +++++++++-------- handler_ext_test.go | 5 +- interceptor_ext_test.go | 5 +- internal/memhttp/listener.go | 57 ++++++++++----------- internal/memhttp/memhttp.go | 75 +++++++++++++++------------- internal/memhttp/memhttp_test.go | 56 +++++++-------------- internal/memhttp/memhttptest/http.go | 2 +- internal/memhttp/option.go | 21 +++----- 10 files changed, 124 insertions(+), 152 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 50ff54e5..cf5c0763 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -122,9 +122,6 @@ issues: # We want to set http.Server's logger - linters: [forbidigo] path: internal/memhttp - # We want to set http.Server's logger - - linters: [forbidigo] - path: internal/memhttp/memhttptest # We want to show examples with http.Get - - linters: [noctx, bodyclose] + - linters: [noctx] path: internal/memhttp/memhttp_test.go diff --git a/compression_test.go b/compression_test.go index 0dc73466..7db457ad 100644 --- a/compression_test.go +++ b/compression_test.go @@ -20,7 +20,6 @@ import ( "testing" "connectrpc.com/connect/internal/assert" - "connectrpc.com/connect/internal/memhttp" "connectrpc.com/connect/internal/memhttp/memhttptest" "google.golang.org/protobuf/types/known/emptypb" ) @@ -43,7 +42,7 @@ func TestAcceptEncodingOrdering(t *testing.T) { w.WriteHeader(http.StatusOK) called = true }) - server := memhttptest.NewServer(t, verify, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, verify) client := NewClient[emptypb.Empty, emptypb.Empty]( server.Client(), server.URL(), diff --git a/connect_ext_test.go b/connect_ext_test.go index dcf9ebcd..e8437da9 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -352,10 +352,10 @@ func TestServer(t *testing.T) { assertIsHTTPMiddlewareError(t, stream.Err()) }) } - testMatrix := func(t *testing.T, server *memhttp.Server, bidi bool) { //nolint:thelper + testMatrix := func(t *testing.T, client *http.Client, url string, bidi bool) { //nolint:thelper run := func(t *testing.T, opts ...connect.ClientOption) { t.Helper() - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), opts...) + client := pingv1connect.NewPingServiceClient(client, url, opts...) testPing(t, client) testSum(t, client) testCountUp(t, client) @@ -436,13 +436,15 @@ func TestServer(t *testing.T) { t.Run("http1", func(t *testing.T) { t.Parallel() - server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) - testMatrix(t, server, false /* bidi */) + server := memhttptest.NewServer(t, mux) + client := &http.Client{Transport: server.TransportHTTP1()} + testMatrix(t, client, server.URL(), false /* bidi */) }) t.Run("http2", func(t *testing.T) { t.Parallel() server := memhttptest.NewServer(t, mux) - testMatrix(t, server, true /* bidi */) + client := server.Client() + testMatrix(t, client, server.URL(), true /* bidi */) }) } @@ -511,7 +513,7 @@ func TestHeaderBasic(t *testing.T) { } mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) - server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) request := connect.NewRequest(&pingv1.PingRequest{}) @@ -591,7 +593,7 @@ func TestTimeoutParsing(t *testing.T) { } mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) - server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, mux) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() @@ -603,7 +605,7 @@ func TestTimeoutParsing(t *testing.T) { func TestFailCodec(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - server := memhttptest.NewServer(t, handler, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, handler) client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), @@ -766,9 +768,9 @@ func TestBidiRequiresHTTP2(t *testing.T) { _, err := io.WriteString(w, "hello world") assert.Nil(t, err) }) - server := memhttptest.NewServer(t, handler, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, handler) client := pingv1connect.NewPingServiceClient( - server.Client(), + &http.Client{Transport: server.TransportHTTP1()}, server.URL(), ) stream := client.CumSum(context.Background()) @@ -831,7 +833,7 @@ func TestCompressMinBytes(t *testing.T) { pingServer{}, connect.WithCompressMinBytes(8), )) - server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, mux) client := server.Client() getPingResponse := func(t *testing.T, pingText string) *http.Response { @@ -885,7 +887,7 @@ func TestCustomCompression(t *testing.T) { pingServer{}, connect.WithCompression(compressionName, decompressor, compressor), )) - server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithAcceptCompression(compressionName, decompressor, compressor), @@ -904,7 +906,7 @@ func TestClientWithoutGzipSupport(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithAcceptCompression("gzip", nil, nil), @@ -954,7 +956,7 @@ func TestInterceptorReturnsWrongType(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithInterceptors(connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) { if _, err := next(ctx, request); err != nil { @@ -1659,7 +1661,7 @@ func TestConnectHTTPErrorCodes(t *testing.T) { }, } mux.Handle(pingv1connect.NewPingServiceHandler(pluggableServer)) - server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, mux) req, err := http.NewRequestWithContext( context.Background(), http.MethodPost, @@ -1889,7 +1891,7 @@ func TestConnectProtocolHeaderRequired(t *testing.T) { pingServer{}, connect.WithRequireConnectProtocolHeader(), )) - server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, mux) tests := []struct { headers http.Header @@ -1928,7 +1930,7 @@ func TestAllowCustomUserAgent(t *testing.T) { return connect.NewResponse(&pingv1.PingResponse{Number: req.Msg.Number}), nil }, })) - server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, mux) // If the user has set a User-Agent, we shouldn't clobber it. tests := []struct { @@ -1964,7 +1966,7 @@ func TestWebXUserAgent(t *testing.T) { return connect.NewResponse(&pingv1.PingResponse{Number: req.Msg.Number}), nil }, })) - server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithGRPCWeb()) req := connect.NewRequest(&pingv1.PingRequest{Number: 42}) @@ -1976,12 +1978,15 @@ func TestBidiOverHTTP1(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{})) - server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, mux) // Clients expecting a full-duplex connection that end up with a simplex // HTTP/1.1 connection shouldn't hang. Instead, the server should close the // TCP connection. - client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) + client := pingv1connect.NewPingServiceClient( + &http.Client{Transport: server.TransportHTTP1()}, + server.URL(), + ) stream := client.CumSum(context.Background()) if err := stream.Send(&pingv1.CumSumRequest{Number: 2}); err != nil { assert.ErrorIs(t, err, io.EOF) @@ -2021,7 +2026,7 @@ func TestHandlerReturnsNilResponse(t *testing.T) { return nil, nil //nolint: nilnil }, }, connect.WithRecover(recoverPanic))) - server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient(server.Client(), server.URL()) _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) diff --git a/handler_ext_test.go b/handler_ext_test.go index d7cf7517..25cde595 100644 --- a/handler_ext_test.go +++ b/handler_ext_test.go @@ -29,7 +29,6 @@ import ( "connectrpc.com/connect/internal/assert" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" - "connectrpc.com/connect/internal/memhttp" "connectrpc.com/connect/internal/memhttp/memhttptest" ) @@ -43,7 +42,7 @@ func TestHandler_ServeHTTP(t *testing.T) { mux.Handle("/prefixed/", http.StripPrefix("/prefixed", prefixed)) const pingProcedure = pingv1connect.PingServicePingProcedure const sumProcedure = pingv1connect.PingServiceSumProcedure - server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, mux) client := server.Client() t.Run("get_method_no_encoding", func(t *testing.T) { @@ -215,7 +214,7 @@ func TestHandlerMaliciousPrefix(t *testing.T) { t.Parallel() mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(successPingServer{})) - server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, mux) const ( concurrency = 256 diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index 8025ac45..a671904b 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -25,7 +25,6 @@ import ( "connectrpc.com/connect/internal/assert" pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" - "connectrpc.com/connect/internal/memhttp" "connectrpc.com/connect/internal/memhttp/memhttptest" ) @@ -128,7 +127,7 @@ func TestOnionOrderingEndToEnd(t *testing.T) { handlerOnion, ), ) - server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, mux) client := pingv1connect.NewPingServiceClient( server.Client(), server.URL(), @@ -173,7 +172,7 @@ func TestEmptyUnaryInterceptorFunc(t *testing.T) { } }) mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, connect.WithInterceptors(interceptor))) - server := memhttptest.NewServer(t, mux, memhttp.WithoutHTTP2()) + server := memhttptest.NewServer(t, mux) connectClient := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), connect.WithInterceptors(interceptor)) _, err := connectClient.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{})) assert.Nil(t, err) diff --git a/internal/memhttp/listener.go b/internal/memhttp/listener.go index 773dd545..adec7519 100644 --- a/internal/memhttp/listener.go +++ b/internal/memhttp/listener.go @@ -21,41 +21,45 @@ import ( "sync" ) -// MemoryListener is a net.Listener that listens on an in memory network. -type MemoryListener struct { +var ( + errListenerClosed = errors.New("listener closed") +) + +// memoryListener is a net.Listener that listens on an in memory network. +type memoryListener struct { + addr memoryAddr + conns chan net.Conn once sync.Once closed chan struct{} } -// NewMemoryListener returns a new in-memory listener. -func NewMemoryListener() *MemoryListener { - return &MemoryListener{ +// newMemoryListener returns a new in-memory listener. +func newMemoryListener(addr string) *memoryListener { + return &memoryListener{ + addr: memoryAddr(addr), conns: make(chan net.Conn), closed: make(chan struct{}), } } // Accept implements net.Listener. -func (l *MemoryListener) Accept() (net.Conn, error) { - aerr := func(err error) error { - return &net.OpError{ - Op: "accept", - Net: memoryAddr{}.Network(), - Addr: memoryAddr{}, - Err: err, - } - } +func (l *memoryListener) Accept() (net.Conn, error) { select { case <-l.closed: - return nil, aerr(errors.New("listener closed")) + return nil, &net.OpError{ + Op: "accept", + Net: l.addr.Network(), + Addr: l.addr, + Err: errListenerClosed, + } case server := <-l.conns: return server, nil } } // Close implements net.Listener. -func (l *MemoryListener) Close() error { +func (l *memoryListener) Close() error { l.once.Do(func() { close(l.closed) }) @@ -63,35 +67,28 @@ func (l *MemoryListener) Close() error { } // Addr implements net.Listener. -func (l *MemoryListener) Addr() net.Addr { - return &memoryAddr{} +func (l *memoryListener) Addr() net.Addr { + return l.addr } // DialContext is the type expected by http.Transport.DialContext. -func (l *MemoryListener) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { - derr := func(err error) error { - return &net.OpError{ - Op: "dial", - Net: memoryAddr{}.Network(), - Err: err, - } - } +func (l *memoryListener) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { server, client := net.Pipe() select { case <-ctx.Done(): - return nil, derr(ctx.Err()) + return nil, &net.OpError{Op: "dial", Net: l.addr.Network(), Err: ctx.Err()} case l.conns <- server: return client, nil case <-l.closed: - return nil, derr(errors.New("listener closed")) + return nil, &net.OpError{Op: "dial", Net: l.addr.Network(), Err: errListenerClosed} } } -type memoryAddr struct{} +type memoryAddr string // Network implements net.Addr. func (memoryAddr) Network() string { return "memory" } // String implements io.Stringer, returning a value that matches the // certificates used by net/http/httptest. -func (memoryAddr) String() string { return "example.com" } +func (a memoryAddr) String() string { return string(a) } diff --git a/internal/memhttp/memhttp.go b/internal/memhttp/memhttp.go index 0de60ee6..89c2d4e9 100644 --- a/internal/memhttp/memhttp.go +++ b/internal/memhttp/memhttp.go @@ -31,11 +31,10 @@ import ( // default, it supports http/2 via h2c. It otherwise uses the same configuration // as the zero value of [http.Server]. type Server struct { - server http.Server - listener *MemoryListener - url string - cleanupTimeout time.Duration - disableHTTP2 bool + server http.Server + listener *memoryListener + url string + shutdownTimeout time.Duration serverWG sync.WaitGroup serverErr error @@ -45,38 +44,36 @@ type Server struct { // options may be provided via [Option]s. func NewServer(handler http.Handler, opts ...Option) *Server { var cfg config - WithCleanupTimeout(5 * time.Second).apply(&cfg) + WithShutdownTimeout(5 * time.Second).apply(&cfg) for _, opt := range opts { opt.apply(&cfg) } - if !cfg.DisableHTTP2 { - h2s := &http2.Server{} - handler = h2c.NewHandler(handler, h2s) - } - listener := NewMemoryListener() + h2s := &http2.Server{} + handler = h2c.NewHandler(handler, h2s) + listener := newMemoryListener("1.2.3.4") // httptest.DefaultRemoteAddr server := &Server{ server: http.Server{ Handler: handler, ReadHeaderTimeout: 5 * time.Second, }, - listener: listener, - cleanupTimeout: cfg.CleanupTimeout, - url: "http://" + listener.Addr().String(), - disableHTTP2: cfg.DisableHTTP2, + listener: listener, + shutdownTimeout: cfg.ShutdownTimeout, + url: "http://" + listener.Addr().String(), } - server.goServe() + server.serverWG.Add(1) + go func() { + defer server.serverWG.Done() + server.serverErr = server.server.Serve(server.listener) + }() return server } -// Transport returns a [http.RoundTripper] configured to use in-memory pipes -// rather than TCP and talk HTTP/2 (if the server supports it). -func (s *Server) Transport() http.RoundTripper { - if s.disableHTTP2 { - return &http.Transport{ - DialContext: s.listener.DialContext, - } - } +// Transport returns a [http2.Transport] configured to use in-memory pipes +// rather than TCP and speak both HTTP/1.1 and HTTP/2. +// +// Callers may reconfigure the returned transport without affecting other transports. +func (s *Server) Transport() *http2.Transport { return &http2.Transport{ DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { return s.listener.DialContext(ctx, network, addr) @@ -85,9 +82,22 @@ func (s *Server) Transport() http.RoundTripper { } } +// TransportHTTP1 returns a [http.Transport] configured to use in-memory pipes +// rather than TCP and speak HTTP/1.1. +// +// Callers may reconfigure the returned transport without affecting other transports. +func (s *Server) TransportHTTP1() *http.Transport { + return &http.Transport{ + DialContext: s.listener.DialContext, + // TODO(emcfarlane): DisableKeepAlives false can causes tests + // to hang on shutdown. + DisableKeepAlives: true, + } +} + // Client returns an [http.Client] configured to use in-memory pipes rather -// than TCP, disable automatic compression, trust the server's TLS certificate -// (if any), and use HTTP/2 (if the server supports it). +// than TCP, and speak HTTP/2. It is configured to use the same +// [http2.Transport] as [Transport]. // // Callers may reconfigure the returned client without affecting other clients. func (s *Server) Client() *http.Client { @@ -102,7 +112,7 @@ func (s *Server) URL() string { // Shutdown gracefully shuts down the server, without interrupting any active // connections. See [http.Server.Shutdown] for details. func (s *Server) Shutdown(ctx context.Context) error { - ctx, cancel := context.WithTimeout(ctx, s.cleanupTimeout) + ctx, cancel := context.WithTimeout(ctx, s.shutdownTimeout) defer cancel() if err := s.server.Shutdown(ctx); err != nil { return err @@ -122,7 +132,8 @@ func (s *Server) RegisterOnShutdown(f func()) { s.server.RegisterOnShutdown(f) } -// Wait blocks until the server exits, then returns its error. +// Wait blocks until the server exits, then returns an error if not +// a [http.ErrServerClosed] error. func (s *Server) Wait() error { s.serverWG.Wait() if !errors.Is(s.serverErr, http.ErrServerClosed) { @@ -130,11 +141,3 @@ func (s *Server) Wait() error { } return nil } - -func (s *Server) goServe() { - s.serverWG.Add(1) - go func() { - defer s.serverWG.Done() - s.serverErr = s.server.Serve(s.listener) - }() -} diff --git a/internal/memhttp/memhttp_test.go b/internal/memhttp/memhttp_test.go index e46df3de..06e11ee4 100644 --- a/internal/memhttp/memhttp_test.go +++ b/internal/memhttp/memhttp_test.go @@ -19,7 +19,6 @@ import ( "fmt" "io" "net/http" - "strings" "sync" "testing" "time" @@ -29,53 +28,34 @@ import ( "connectrpc.com/connect/internal/memhttp/memhttptest" ) -func TestServer(t *testing.T) { +func TestServerTransport(t *testing.T) { t.Parallel() - tests := []struct { - name string - opts []memhttp.Option - handler func(t *testing.T, w http.ResponseWriter, r *http.Request) - }{{ - name: "http2", - opts: nil, - handler: func(t *testing.T, _ http.ResponseWriter, r *http.Request) { - t.Helper() - assert.Equal(t, r.ProtoMajor, 2) - assert.Equal(t, r.ProtoMinor, 0) - }, - }, { - name: "http1", - opts: []memhttp.Option{memhttp.WithoutHTTP2()}, - handler: func(t *testing.T, _ http.ResponseWriter, r *http.Request) { - t.Helper() - assert.Equal(t, r.ProtoMajor, 1) - assert.Equal(t, r.ProtoMinor, 1) - }, - }} - for _, testcase := range tests { - testcase := testcase - t.Run(testcase.name, func(t *testing.T) { - t.Parallel() - const concurrency = 100 - const greeting = "Hello, world!" + const concurrency = 100 + const greeting = "Hello, world!" - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - testcase.handler(t, w, r) - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(greeting)) - }) - server := memhttptest.NewServer(t, handler, testcase.opts...) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(greeting)) + }) + server := memhttptest.NewServer(t, handler) + + for _, transport := range []http.RoundTripper{ + server.Transport(), + server.TransportHTTP1(), + } { + client := &http.Client{Transport: transport} + t.Run(fmt.Sprintf("%T", transport), func(t *testing.T) { + t.Parallel() var wg sync.WaitGroup for i := 0; i < concurrency; i++ { wg.Add(1) go func() { defer wg.Done() - client := server.Client() req, err := http.NewRequestWithContext( context.Background(), http.MethodGet, server.URL(), - strings.NewReader(""), + nil, ) assert.Nil(t, err) res, err := client.Do(req) @@ -120,6 +100,7 @@ func Example() { if err != nil { panic(err) } + defer res.Body.Close() fmt.Println(res.Status) // Output: // 200 OK @@ -137,6 +118,7 @@ func ExampleServer_Client() { if err != nil { panic(err) } + defer res.Body.Close() fmt.Println(res.Status) // Output: // 200 OK diff --git a/internal/memhttp/memhttptest/http.go b/internal/memhttp/memhttptest/http.go index c0c68966..5308c2c1 100644 --- a/internal/memhttp/memhttptest/http.go +++ b/internal/memhttp/memhttptest/http.go @@ -34,7 +34,7 @@ import ( func NewServer(tb testing.TB, handler http.Handler, opts ...memhttp.Option) *memhttp.Server { tb.Helper() logger := log.New(&testWriter{tb}, "" /* prefix */, log.Lshortfile) - opts = append(opts, memhttp.WithErrorLog(logger)) + opts = append([]memhttp.Option{memhttp.WithErrorLog(logger)}, opts...) server := memhttp.NewServer(handler, opts...) tb.Cleanup(func() { tb.Logf("shutting down server") diff --git a/internal/memhttp/option.go b/internal/memhttp/option.go index 1469943a..ae9439ba 100644 --- a/internal/memhttp/option.go +++ b/internal/memhttp/option.go @@ -20,10 +20,8 @@ import ( ) type config struct { - DisableTLS bool - DisableHTTP2 bool - CleanupTimeout time.Duration - ErrorLog *log.Logger + ShutdownTimeout time.Duration + ErrorLog *log.Logger } // An Option configures a Server. @@ -35,13 +33,6 @@ type optionFunc func(*config) func (f optionFunc) apply(cfg *config) { f(cfg) } -// WithoutHTTP2 disables HTTP/2 on the server and client. -func WithoutHTTP2() Option { - return optionFunc(func(cfg *config) { - cfg.DisableHTTP2 = true - }) -} - // WithOptions composes multiple Options into one. func WithOptions(opts ...Option) Option { return optionFunc(func(cfg *config) { @@ -51,11 +42,11 @@ func WithOptions(opts ...Option) Option { }) } -// WithCleanupTimeout customizes the default five-second timeout for the -// server's Cleanup method. It's most useful with the memhttptest subpackage. -func WithCleanupTimeout(d time.Duration) Option { +// WithShutdownTimeout customizes the default five-second timeout for the +// server's Shutdown method. +func WithShutdownTimeout(d time.Duration) Option { return optionFunc(func(cfg *config) { - cfg.CleanupTimeout = d + cfg.ShutdownTimeout = d }) } From 40594b78a04b42ac5abeaba291f5cc58d00c0718 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Wed, 11 Oct 2023 13:13:43 +0100 Subject: [PATCH 09/21] Fix description --- internal/memhttp/memhttp.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/memhttp/memhttp.go b/internal/memhttp/memhttp.go index 89c2d4e9..ae9a6490 100644 --- a/internal/memhttp/memhttp.go +++ b/internal/memhttp/memhttp.go @@ -96,7 +96,7 @@ func (s *Server) TransportHTTP1() *http.Transport { } // Client returns an [http.Client] configured to use in-memory pipes rather -// than TCP, and speak HTTP/2. It is configured to use the same +// than TCP and speak HTTP/2. It is configured to use the same // [http2.Transport] as [Transport]. // // Callers may reconfigure the returned client without affecting other clients. From b9fccc6d440aacc03a6de89a62d4c9887825c05b Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Fri, 13 Oct 2023 14:12:22 +0100 Subject: [PATCH 10/21] Add Cleanup method test servers --- internal/memhttp/internal/config.go | 35 +++++++++++++++++++++++++ internal/memhttp/memhttp.go | 32 ++++++++++++++--------- internal/memhttp/memhttptest/http.go | 4 +-- internal/memhttp/option.go | 38 ++++++++++------------------ 4 files changed, 70 insertions(+), 39 deletions(-) create mode 100644 internal/memhttp/internal/config.go diff --git a/internal/memhttp/internal/config.go b/internal/memhttp/internal/config.go new file mode 100644 index 00000000..e0ed81e4 --- /dev/null +++ b/internal/memhttp/internal/config.go @@ -0,0 +1,35 @@ +// Copyright 2021-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "log" + "time" +) + +// Config is the configuration for a Server. +type Config struct { + CleanupTimeout time.Duration + ErrorLog *log.Logger +} + +// An Option configures a Server. +type Option interface { + Apply(*Config) +} + +type OptionFunc func(*Config) + +func (f OptionFunc) Apply(cfg *Config) { f(cfg) } diff --git a/internal/memhttp/memhttp.go b/internal/memhttp/memhttp.go index ae9a6490..89ad93ce 100644 --- a/internal/memhttp/memhttp.go +++ b/internal/memhttp/memhttp.go @@ -23,6 +23,7 @@ import ( "sync" "time" + "connectrpc.com/connect/internal/memhttp/internal" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" ) @@ -31,10 +32,10 @@ import ( // default, it supports http/2 via h2c. It otherwise uses the same configuration // as the zero value of [http.Server]. type Server struct { - server http.Server - listener *memoryListener - url string - shutdownTimeout time.Duration + server http.Server + listener *memoryListener + url string + cleanupTimeout time.Duration serverWG sync.WaitGroup serverErr error @@ -43,10 +44,10 @@ type Server struct { // NewServer creates a new Server that uses the given handler. Configuration // options may be provided via [Option]s. func NewServer(handler http.Handler, opts ...Option) *Server { - var cfg config - WithShutdownTimeout(5 * time.Second).apply(&cfg) + var cfg internal.Config + WithCleanupTimeout(5 * time.Second).Apply(&cfg) for _, opt := range opts { - opt.apply(&cfg) + opt.Apply(&cfg) } h2s := &http2.Server{} @@ -57,9 +58,9 @@ func NewServer(handler http.Handler, opts ...Option) *Server { Handler: handler, ReadHeaderTimeout: 5 * time.Second, }, - listener: listener, - shutdownTimeout: cfg.ShutdownTimeout, - url: "http://" + listener.Addr().String(), + listener: listener, + url: "http://" + listener.Addr().String(), + cleanupTimeout: cfg.CleanupTimeout, } server.serverWG.Add(1) go func() { @@ -112,14 +113,21 @@ func (s *Server) URL() string { // Shutdown gracefully shuts down the server, without interrupting any active // connections. See [http.Server.Shutdown] for details. func (s *Server) Shutdown(ctx context.Context) error { - ctx, cancel := context.WithTimeout(ctx, s.shutdownTimeout) - defer cancel() if err := s.server.Shutdown(ctx); err != nil { return err } return s.Wait() } +// Cleanup calls shutdown with a background context set with the cleanup timeout. +// The default timeout duration is 5 seconds. +func (s *Server) Cleanup() error { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, s.cleanupTimeout) + defer cancel() + return s.Shutdown(ctx) +} + // Close closes the server's listener. It does not wait for connections to // finish. func (s *Server) Close() error { diff --git a/internal/memhttp/memhttptest/http.go b/internal/memhttp/memhttptest/http.go index 5308c2c1..f6d90c55 100644 --- a/internal/memhttp/memhttptest/http.go +++ b/internal/memhttp/memhttptest/http.go @@ -15,7 +15,6 @@ package memhttptest import ( - "context" "log" "net/http" "testing" @@ -37,8 +36,7 @@ func NewServer(tb testing.TB, handler http.Handler, opts ...memhttp.Option) *mem opts = append([]memhttp.Option{memhttp.WithErrorLog(logger)}, opts...) server := memhttp.NewServer(handler, opts...) tb.Cleanup(func() { - tb.Logf("shutting down server") - if err := server.Shutdown(context.Background()); err != nil { + if err := server.Cleanup(); err != nil { tb.Error(err) } }) diff --git a/internal/memhttp/option.go b/internal/memhttp/option.go index ae9439ba..135bd888 100644 --- a/internal/memhttp/option.go +++ b/internal/memhttp/option.go @@ -17,42 +17,32 @@ package memhttp import ( "log" "time" -) - -type config struct { - ShutdownTimeout time.Duration - ErrorLog *log.Logger -} -// An Option configures a Server. -type Option interface { - apply(*config) -} - -type optionFunc func(*config) + "connectrpc.com/connect/internal/memhttp/internal" +) -func (f optionFunc) apply(cfg *config) { f(cfg) } +type Option = internal.Option // WithOptions composes multiple Options into one. func WithOptions(opts ...Option) Option { - return optionFunc(func(cfg *config) { + return internal.OptionFunc(func(cfg *internal.Config) { for _, opt := range opts { - opt.apply(cfg) + opt.Apply(cfg) } }) } -// WithShutdownTimeout customizes the default five-second timeout for the -// server's Shutdown method. -func WithShutdownTimeout(d time.Duration) Option { - return optionFunc(func(cfg *config) { - cfg.ShutdownTimeout = d - }) -} - // WithErrorLog sets [http.Server.ErrorLog]. func WithErrorLog(l *log.Logger) Option { - return optionFunc(func(cfg *config) { + return internal.OptionFunc(func(cfg *internal.Config) { cfg.ErrorLog = l }) } + +// WithCleanupTimeout customizes the default five-second timeout for the +// server's Cleanup method. +func WithCleanupTimeout(d time.Duration) Option { + return internal.OptionFunc(func(cfg *internal.Config) { + cfg.CleanupTimeout = d + }) +} From dd98dd482e8bcb62f289c97db037394a64d7171d Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Fri, 13 Oct 2023 17:53:54 +0100 Subject: [PATCH 11/21] Revert moving options --- internal/memhttp/internal/config.go | 35 ----------------------------- internal/memhttp/memhttp.go | 7 +++--- internal/memhttp/option.go | 25 +++++++++++++++------ 3 files changed, 21 insertions(+), 46 deletions(-) delete mode 100644 internal/memhttp/internal/config.go diff --git a/internal/memhttp/internal/config.go b/internal/memhttp/internal/config.go deleted file mode 100644 index e0ed81e4..00000000 --- a/internal/memhttp/internal/config.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2021-2023 The Connect Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package internal - -import ( - "log" - "time" -) - -// Config is the configuration for a Server. -type Config struct { - CleanupTimeout time.Duration - ErrorLog *log.Logger -} - -// An Option configures a Server. -type Option interface { - Apply(*Config) -} - -type OptionFunc func(*Config) - -func (f OptionFunc) Apply(cfg *Config) { f(cfg) } diff --git a/internal/memhttp/memhttp.go b/internal/memhttp/memhttp.go index 89ad93ce..67de9935 100644 --- a/internal/memhttp/memhttp.go +++ b/internal/memhttp/memhttp.go @@ -23,7 +23,6 @@ import ( "sync" "time" - "connectrpc.com/connect/internal/memhttp/internal" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" ) @@ -44,10 +43,10 @@ type Server struct { // NewServer creates a new Server that uses the given handler. Configuration // options may be provided via [Option]s. func NewServer(handler http.Handler, opts ...Option) *Server { - var cfg internal.Config - WithCleanupTimeout(5 * time.Second).Apply(&cfg) + var cfg config + WithCleanupTimeout(5 * time.Second).apply(&cfg) for _, opt := range opts { - opt.Apply(&cfg) + opt.apply(&cfg) } h2s := &http2.Server{} diff --git a/internal/memhttp/option.go b/internal/memhttp/option.go index 135bd888..b3e972d5 100644 --- a/internal/memhttp/option.go +++ b/internal/memhttp/option.go @@ -17,24 +17,35 @@ package memhttp import ( "log" "time" - - "connectrpc.com/connect/internal/memhttp/internal" ) -type Option = internal.Option +// config is the configuration for a Server. +type config struct { + CleanupTimeout time.Duration + ErrorLog *log.Logger +} + +// An Option configures a Server. +type Option interface { + apply(*config) +} + +type optionFunc func(*config) + +func (f optionFunc) apply(cfg *config) { f(cfg) } // WithOptions composes multiple Options into one. func WithOptions(opts ...Option) Option { - return internal.OptionFunc(func(cfg *internal.Config) { + return optionFunc(func(cfg *config) { for _, opt := range opts { - opt.Apply(cfg) + opt.apply(cfg) } }) } // WithErrorLog sets [http.Server.ErrorLog]. func WithErrorLog(l *log.Logger) Option { - return internal.OptionFunc(func(cfg *internal.Config) { + return optionFunc(func(cfg *config) { cfg.ErrorLog = l }) } @@ -42,7 +53,7 @@ func WithErrorLog(l *log.Logger) Option { // WithCleanupTimeout customizes the default five-second timeout for the // server's Cleanup method. func WithCleanupTimeout(d time.Duration) Option { - return internal.OptionFunc(func(cfg *internal.Config) { + return optionFunc(func(cfg *config) { cfg.CleanupTimeout = d }) } From 716794bb28bca01d77ca4416dcd75d2b8eb31bd4 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Sun, 22 Oct 2023 15:43:58 -0400 Subject: [PATCH 12/21] Ensure response errors are reported consistently Removes SetError in favour of reporting errors on BlockUntilResponseReady. ensureRequestMade removes sync.Once as the contract of ClientStream state calling Send with CloseRequest is not safe to call concurrently. --- client_ext_test.go | 4 +- connect_ext_test.go | 14 ++++--- duplex_http_call.go | 95 ++++++++++++++++++--------------------------- protocol_connect.go | 9 +++-- protocol_grpc.go | 13 ++++--- 5 files changed, 58 insertions(+), 77 deletions(-) diff --git a/client_ext_test.go b/client_ext_test.go index 36908a1d..110edb1b 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -106,10 +106,8 @@ func TestClientPeer(t *testing.T) { err = clientStream.Send(&pingv1.SumRequest{}) assert.Nil(t, err) // server streaming - serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{})) + serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{Number: 1})) t.Cleanup(func() { - // TODO(emcfarlane): debug flaky test close with error: - // "unknown: io: read/write on closed pipe" assert.Nil(t, serverStream.Close()) }) assert.Nil(t, err) diff --git a/connect_ext_test.go b/connect_ext_test.go index e8437da9..24d85e26 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -774,15 +774,15 @@ func TestBidiRequiresHTTP2(t *testing.T) { server.URL(), ) stream := client.CumSum(context.Background()) - // Stream creates an async request, can error on Send or Receive. - err := stream.Send(&pingv1.CumSumRequest{}) - if err == nil { - assert.Nil(t, stream.CloseRequest()) - _, err = stream.Receive() + if err := stream.Send(&pingv1.CumSumRequest{}); err != nil { + assert.ErrorIs(t, err, io.EOF) } + assert.Nil(t, stream.CloseRequest()) + _, err := stream.Receive() assert.NotNil(t, err) var connectErr *connect.Error assert.True(t, errors.As(err, &connectErr)) + t.Log(err) assert.Equal(t, connectErr.Code(), connect.CodeUnimplemented) assert.True( t, @@ -1988,13 +1988,14 @@ func TestBidiOverHTTP1(t *testing.T) { server.URL(), ) stream := client.CumSum(context.Background()) + // Stream creates an async request, can error on Send or Receive. if err := stream.Send(&pingv1.CumSumRequest{Number: 2}); err != nil { assert.ErrorIs(t, err, io.EOF) } _, err := stream.Receive() assert.NotNil(t, err) assert.Equal(t, connect.CodeOf(err), connect.CodeUnknown) - assert.Equal(t, err.Error(), "unknown: HTTP status 505 HTTP Version Not Supported") + assert.True(t, strings.HasSuffix(err.Error(), "HTTP status 505 HTTP Version Not Supported")) assert.Nil(t, stream.CloseRequest()) assert.Nil(t, stream.CloseResponse()) } @@ -2342,6 +2343,7 @@ func (p *pluggablePingServer) CumSum( func failNoHTTP2(tb testing.TB, stream *connect.BidiStreamForClient[pingv1.CumSumRequest, pingv1.CumSumResponse]) { tb.Helper() + if err := stream.Send(&pingv1.CumSumRequest{}); err != nil { assert.ErrorIs(tb, err, io.EOF) assert.Equal(tb, connect.CodeOf(err), connect.CodeUnknown) diff --git a/duplex_http_call.go b/duplex_http_call.go index ab1a6db4..f57c8d82 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -42,13 +42,11 @@ type duplexHTTPCall struct { requestBodyReader *io.PipeReader requestBodyWriter *io.PipeWriter - sendRequestOnce sync.Once - responseReady chan struct{} - request *http.Request - response *http.Response - - errMu sync.Mutex - err error + requestSent bool + responseReady sync.WaitGroup + request *http.Request + response *http.Response + responseErr error } func newDuplexHTTPCall( @@ -80,24 +78,23 @@ func newDuplexHTTPCall( Body: pipeReader, Host: url.Host, }).WithContext(ctx) - return &duplexHTTPCall{ + call := &duplexHTTPCall{ ctx: ctx, httpClient: httpClient, streamType: spec.StreamType, requestBodyReader: pipeReader, requestBodyWriter: pipeWriter, request: request, - responseReady: make(chan struct{}), } + call.responseReady.Add(1) + return call } -// Write to the request body. Returns an error wrapping io.EOF after SetError -// is called. +// Write to the request body. func (d *duplexHTTPCall) Write(data []byte) (int, error) { d.ensureRequestMade() // Before we send any data, check if the context has been canceled. if err := d.ctx.Err(); err != nil { - d.SetError(err) return 0, wrapIfContextError(err) } // It's safe to write to this side of the pipe while net/http concurrently @@ -157,14 +154,12 @@ func (d *duplexHTTPCall) SetMethod(method string) { func (d *duplexHTTPCall) Read(data []byte) (int, error) { // First, we wait until we've gotten the response headers and established the // server-to-client side of the stream. - d.BlockUntilResponseReady() - if err := d.getError(); err != nil { + if err := d.BlockUntilResponseReady(); err != nil { // The stream is already closed or corrupted. return 0, err } // Before we read, check if the context has been canceled. if err := d.ctx.Err(); err != nil { - d.SetError(err) return 0, wrapIfContextError(err) } if d.response == nil { @@ -175,7 +170,7 @@ func (d *duplexHTTPCall) Read(data []byte) (int, error) { } func (d *duplexHTTPCall) CloseRead() error { - d.BlockUntilResponseReady() + d.responseReady.Wait() if d.response == nil { return nil } @@ -188,7 +183,9 @@ func (d *duplexHTTPCall) CloseRead() error { // ResponseStatusCode is the response's HTTP status code. func (d *duplexHTTPCall) ResponseStatusCode() (int, error) { - d.BlockUntilResponseReady() + if err := d.BlockUntilResponseReady(); err != nil { + return 0, err + } if d.response == nil { return 0, fmt.Errorf("nil response from %v", d.request.URL) } @@ -197,7 +194,7 @@ func (d *duplexHTTPCall) ResponseStatusCode() (int, error) { // ResponseHeader returns the response HTTP headers. func (d *duplexHTTPCall) ResponseHeader() http.Header { - d.BlockUntilResponseReady() + _ = d.BlockUntilResponseReady() if d.response != nil { return d.response.Header } @@ -206,56 +203,39 @@ func (d *duplexHTTPCall) ResponseHeader() http.Header { // ResponseTrailer returns the response HTTP trailers. func (d *duplexHTTPCall) ResponseTrailer() http.Header { - d.BlockUntilResponseReady() + _ = d.BlockUntilResponseReady() if d.response != nil { return d.response.Trailer } return make(http.Header) } -// SetError stores any error encountered processing the response. All -// subsequent calls to Read return this error, and all subsequent calls to -// Write return an error wrapping io.EOF. It's safe to call concurrently with -// any other method. -func (d *duplexHTTPCall) SetError(err error) { - d.errMu.Lock() - if d.err == nil { - d.err = wrapIfContextError(err) - } - // Closing the read side of the request body pipe acquires an internal lock, - // so we want to scope errMu's usage narrowly and avoid defer. - d.errMu.Unlock() - - // We've already hit an error, so we should stop writing to the request body. - // It's safe to call Close more than once and/or concurrently (calls after - // the first are no-ops), so it's okay for us to call this even though - // net/http sometimes closes the reader too. - // - // It's safe to ignore the returned error here. Under the hood, Close calls - // CloseWithError, which is documented to always return nil. - _ = d.requestBodyReader.Close() -} - // SetValidateResponse sets the response validation function. The function runs // in a background goroutine. func (d *duplexHTTPCall) SetValidateResponse(validate func(*http.Response) *Error) { d.validateResponse = validate } -func (d *duplexHTTPCall) BlockUntilResponseReady() { - <-d.responseReady +func (d *duplexHTTPCall) BlockUntilResponseReady() error { + d.responseReady.Wait() + return d.responseErr } +// ensureRequestMade sends the request headers and starts the response stream. +// It is not safe to call this concurrently. Write and CloseWrite call this but +// ensure that they're not called concurrently. func (d *duplexHTTPCall) ensureRequestMade() { - d.sendRequestOnce.Do(func() { - go d.makeRequest() - }) + if d.requestSent { + return // already sent + } + d.requestSent = true + go d.makeRequest() } func (d *duplexHTTPCall) makeRequest() { // This runs concurrently with Write and CloseWrite. Read and CloseRead wait // on d.responseReady, so we can't race with them. - defer close(d.responseReady) + defer d.responseReady.Done() // Promote the header Host to the request object. if host := d.request.Header.Get(headerHost); len(host) > 0 { @@ -276,33 +256,32 @@ func (d *duplexHTTPCall) makeRequest() { if _, ok := asError(err); !ok { err = NewError(CodeUnavailable, err) } - d.SetError(err) + d.responseErr = err + d.requestBodyReader.CloseWithError(io.EOF) return } d.response = response if err := d.validateResponse(response); err != nil { - d.SetError(err) + d.responseErr = err + d.response.Body.Close() + d.requestBodyReader.CloseWithError(io.EOF) return } if (d.streamType&StreamTypeBidi) == StreamTypeBidi && response.ProtoMajor < 2 { // If we somehow dialed an HTTP/1.x server, fail with an explicit message // rather than returning a more cryptic error later on. - d.SetError(errorf( + d.responseErr = errorf( CodeUnimplemented, "response from %v is HTTP/%d.%d: bidi streams require at least HTTP/2", d.request.URL, response.ProtoMajor, response.ProtoMinor, - )) + ) + d.response.Body.Close() + d.requestBodyReader.CloseWithError(io.EOF) } } -func (d *duplexHTTPCall) getError() error { - d.errMu.Lock() - defer d.errMu.Unlock() - return d.err -} - // See: https://cs.opensource.google/go/go/+/refs/tags/go1.20.1:src/net/http/clone.go;l=22-33 func cloneURL(oldURL *url.URL) *url.URL { if oldURL == nil { diff --git a/protocol_connect.go b/protocol_connect.go index e3c74923..cab89c08 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -589,7 +589,9 @@ func (cc *connectStreamingClientConn) CloseRequest() error { } func (cc *connectStreamingClientConn) Receive(msg any) error { - cc.duplexCall.BlockUntilResponseReady() + if err := cc.duplexCall.BlockUntilResponseReady(); err != nil { + return err + } err := cc.unmarshaler.Unmarshal(msg) if err == nil { return nil @@ -603,7 +605,6 @@ func (cc *connectStreamingClientConn) Receive(msg any) error { // error. serverErr.meta = cc.responseHeader.Clone() mergeHeaders(serverErr.meta, cc.responseTrailer) - cc.duplexCall.SetError(serverErr) return serverErr } // If the error is EOF but not from a last message, we want to return @@ -614,8 +615,8 @@ func (cc *connectStreamingClientConn) Receive(msg any) error { // There's no error in the trailers, so this was probably an error // converting the bytes to a message, an error reading from the network, or // just an EOF. We're going to return it to the user, but we also want to - // setResponseError so Send errors out. - cc.duplexCall.SetError(err) + // close the writer so Send errors out. + _ = cc.duplexCall.CloseWrite() return err } diff --git a/protocol_grpc.go b/protocol_grpc.go index d3cb0062..7cdc411f 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -381,7 +381,9 @@ func (cc *grpcClientConn) CloseRequest() error { } func (cc *grpcClientConn) Receive(msg any) error { - cc.duplexCall.BlockUntilResponseReady() + if err := cc.duplexCall.BlockUntilResponseReady(); err != nil { + return err + } err := cc.unmarshaler.Unmarshal(msg) if err == nil { return nil @@ -409,23 +411,22 @@ func (cc *grpcClientConn) Receive(msg any) error { // the stream has ended, Receive must return an error. serverErr.meta = cc.responseHeader.Clone() mergeHeaders(serverErr.meta, cc.responseTrailer) - cc.duplexCall.SetError(serverErr) return serverErr } // This was probably an error converting the bytes to a message or an error // reading from the network. We're going to return it to the - // user, but we also want to setResponseError so Send errors out. - cc.duplexCall.SetError(err) + // user, but we also want to close writes so Send errors out. + _ = cc.duplexCall.CloseWrite() return err } func (cc *grpcClientConn) ResponseHeader() http.Header { - cc.duplexCall.BlockUntilResponseReady() + _ = cc.duplexCall.BlockUntilResponseReady() return cc.responseHeader } func (cc *grpcClientConn) ResponseTrailer() http.Header { - cc.duplexCall.BlockUntilResponseReady() + _ = cc.duplexCall.BlockUntilResponseReady() return cc.responseTrailer } From ccd39d456b8536f5c313f9fb7ff72e5a59fb4790 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Sun, 22 Oct 2023 15:53:05 -0400 Subject: [PATCH 13/21] Document BlockUntilResponseReady behaviour --- duplex_http_call.go | 6 ++++-- protocol_connect.go | 12 +++++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/duplex_http_call.go b/duplex_http_call.go index f57c8d82..6112cac4 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -194,7 +194,7 @@ func (d *duplexHTTPCall) ResponseStatusCode() (int, error) { // ResponseHeader returns the response HTTP headers. func (d *duplexHTTPCall) ResponseHeader() http.Header { - _ = d.BlockUntilResponseReady() + d.responseReady.Wait() if d.response != nil { return d.response.Header } @@ -203,7 +203,7 @@ func (d *duplexHTTPCall) ResponseHeader() http.Header { // ResponseTrailer returns the response HTTP trailers. func (d *duplexHTTPCall) ResponseTrailer() http.Header { - _ = d.BlockUntilResponseReady() + d.responseReady.Wait() if d.response != nil { return d.response.Trailer } @@ -216,6 +216,8 @@ func (d *duplexHTTPCall) SetValidateResponse(validate func(*http.Response) *Erro d.validateResponse = validate } +// BlockUntilResponseReady returns when the response is ready or reports an +// error from initializing the request. func (d *duplexHTTPCall) BlockUntilResponseReady() error { d.responseReady.Wait() return d.responseErr diff --git a/protocol_connect.go b/protocol_connect.go index cab89c08..59ba8870 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -478,7 +478,9 @@ func (cc *connectUnaryClientConn) CloseRequest() error { } func (cc *connectUnaryClientConn) Receive(msg any) error { - cc.duplexCall.BlockUntilResponseReady() + if err := cc.duplexCall.BlockUntilResponseReady(); err != nil { + return err + } if err := cc.unmarshaler.Unmarshal(msg); err != nil { return err } @@ -486,12 +488,12 @@ func (cc *connectUnaryClientConn) Receive(msg any) error { } func (cc *connectUnaryClientConn) ResponseHeader() http.Header { - cc.duplexCall.BlockUntilResponseReady() + _ = cc.duplexCall.BlockUntilResponseReady() return cc.responseHeader } func (cc *connectUnaryClientConn) ResponseTrailer() http.Header { - cc.duplexCall.BlockUntilResponseReady() + _ = cc.duplexCall.BlockUntilResponseReady() return cc.responseTrailer } @@ -621,12 +623,12 @@ func (cc *connectStreamingClientConn) Receive(msg any) error { } func (cc *connectStreamingClientConn) ResponseHeader() http.Header { - cc.duplexCall.BlockUntilResponseReady() + _ = cc.duplexCall.BlockUntilResponseReady() return cc.responseHeader } func (cc *connectStreamingClientConn) ResponseTrailer() http.Header { - cc.duplexCall.BlockUntilResponseReady() + _ = cc.duplexCall.BlockUntilResponseReady() return cc.responseTrailer } From 91afa1e58d0aa728accdc82005fe6b5abf611a6d Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 23 Oct 2023 11:18:35 -0400 Subject: [PATCH 14/21] Ensure CloseWrite is called --- duplex_http_call.go | 3 --- protocol_connect.go | 1 + protocol_grpc.go | 1 + 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/duplex_http_call.go b/duplex_http_call.go index 6112cac4..4a46c4ff 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -186,9 +186,6 @@ func (d *duplexHTTPCall) ResponseStatusCode() (int, error) { if err := d.BlockUntilResponseReady(); err != nil { return 0, err } - if d.response == nil { - return 0, fmt.Errorf("nil response from %v", d.request.URL) - } return d.response.StatusCode, nil } diff --git a/protocol_connect.go b/protocol_connect.go index 59ba8870..00f2f66b 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -607,6 +607,7 @@ func (cc *connectStreamingClientConn) Receive(msg any) error { // error. serverErr.meta = cc.responseHeader.Clone() mergeHeaders(serverErr.meta, cc.responseTrailer) + _ = cc.duplexCall.CloseWrite() return serverErr } // If the error is EOF but not from a last message, we want to return diff --git a/protocol_grpc.go b/protocol_grpc.go index 7cdc411f..c4b661e5 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -411,6 +411,7 @@ func (cc *grpcClientConn) Receive(msg any) error { // the stream has ended, Receive must return an error. serverErr.meta = cc.responseHeader.Clone() mergeHeaders(serverErr.meta, cc.responseTrailer) + _ = cc.duplexCall.CloseWrite() return serverErr } // This was probably an error converting the bytes to a message or an error From ea743b32cf167d22770ea2cd1c0fa6121e196892 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 23 Oct 2023 11:42:23 -0400 Subject: [PATCH 15/21] Feedback remove changes to duplexHTTPCall --- connect_ext_test.go | 3 +-- duplex_http_call.go | 35 ++++++++++++++++++----------------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/connect_ext_test.go b/connect_ext_test.go index 24d85e26..1eaca691 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -782,7 +782,6 @@ func TestBidiRequiresHTTP2(t *testing.T) { assert.NotNil(t, err) var connectErr *connect.Error assert.True(t, errors.As(err, &connectErr)) - t.Log(err) assert.Equal(t, connectErr.Code(), connect.CodeUnimplemented) assert.True( t, @@ -1995,7 +1994,7 @@ func TestBidiOverHTTP1(t *testing.T) { _, err := stream.Receive() assert.NotNil(t, err) assert.Equal(t, connect.CodeOf(err), connect.CodeUnknown) - assert.True(t, strings.HasSuffix(err.Error(), "HTTP status 505 HTTP Version Not Supported")) + assert.Equal(t, err.Error(), "unknown: HTTP status 505 HTTP Version Not Supported") assert.Nil(t, stream.CloseRequest()) assert.Nil(t, stream.CloseResponse()) } diff --git a/duplex_http_call.go b/duplex_http_call.go index 4a46c4ff..e215140b 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -42,10 +42,14 @@ type duplexHTTPCall struct { requestBodyReader *io.PipeReader requestBodyWriter *io.PipeWriter - requestSent bool - responseReady sync.WaitGroup - request *http.Request - response *http.Response + sendRequestOnce sync.Once + request *http.Request + response *http.Response + + // responseReady is closed when the response is ready or when the request + // fails. Any error on request initialisation will be set on the + // responseErr. There's always a response if responseErr is nil. + responseReady chan struct{} responseErr error } @@ -78,16 +82,15 @@ func newDuplexHTTPCall( Body: pipeReader, Host: url.Host, }).WithContext(ctx) - call := &duplexHTTPCall{ + return &duplexHTTPCall{ ctx: ctx, httpClient: httpClient, streamType: spec.StreamType, requestBodyReader: pipeReader, requestBodyWriter: pipeWriter, request: request, + responseReady: make(chan struct{}), } - call.responseReady.Add(1) - return call } // Write to the request body. @@ -170,7 +173,7 @@ func (d *duplexHTTPCall) Read(data []byte) (int, error) { } func (d *duplexHTTPCall) CloseRead() error { - d.responseReady.Wait() + _ = d.BlockUntilResponseReady() if d.response == nil { return nil } @@ -191,7 +194,7 @@ func (d *duplexHTTPCall) ResponseStatusCode() (int, error) { // ResponseHeader returns the response HTTP headers. func (d *duplexHTTPCall) ResponseHeader() http.Header { - d.responseReady.Wait() + _ = d.BlockUntilResponseReady() if d.response != nil { return d.response.Header } @@ -200,7 +203,7 @@ func (d *duplexHTTPCall) ResponseHeader() http.Header { // ResponseTrailer returns the response HTTP trailers. func (d *duplexHTTPCall) ResponseTrailer() http.Header { - d.responseReady.Wait() + _ = d.BlockUntilResponseReady() if d.response != nil { return d.response.Trailer } @@ -216,7 +219,7 @@ func (d *duplexHTTPCall) SetValidateResponse(validate func(*http.Response) *Erro // BlockUntilResponseReady returns when the response is ready or reports an // error from initializing the request. func (d *duplexHTTPCall) BlockUntilResponseReady() error { - d.responseReady.Wait() + <-d.responseReady return d.responseErr } @@ -224,17 +227,15 @@ func (d *duplexHTTPCall) BlockUntilResponseReady() error { // It is not safe to call this concurrently. Write and CloseWrite call this but // ensure that they're not called concurrently. func (d *duplexHTTPCall) ensureRequestMade() { - if d.requestSent { - return // already sent - } - d.requestSent = true - go d.makeRequest() + d.sendRequestOnce.Do(func() { + go d.makeRequest() + }) } func (d *duplexHTTPCall) makeRequest() { // This runs concurrently with Write and CloseWrite. Read and CloseRead wait // on d.responseReady, so we can't race with them. - defer d.responseReady.Done() + defer close(d.responseReady) // Promote the header Host to the request object. if host := d.request.Header.Get(headerHost); len(host) > 0 { From 628915a555b998994e79f8a78b9a583e6ebf2756 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 23 Oct 2023 12:47:01 -0400 Subject: [PATCH 16/21] Add comment to clarify behaviour --- connect_ext_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/connect_ext_test.go b/connect_ext_test.go index 1eaca691..4645134f 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -774,6 +774,7 @@ func TestBidiRequiresHTTP2(t *testing.T) { server.URL(), ) stream := client.CumSum(context.Background()) + // Stream creates an async request, can error on Send or Receive. if err := stream.Send(&pingv1.CumSumRequest{}); err != nil { assert.ErrorIs(t, err, io.EOF) } From 23dc2dae9871464ff47521548c5a29f41596ef5b Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 23 Oct 2023 14:47:29 -0400 Subject: [PATCH 17/21] Restrict log errors to New|Logger|Lshortfile --- .golangci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.golangci.yml b/.golangci.yml index cf5c0763..8575067f 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -122,6 +122,7 @@ issues: # We want to set http.Server's logger - linters: [forbidigo] path: internal/memhttp + text: "use of `log.(New|Logger|Lshortfile)` forbidden by pattern .*" # We want to show examples with http.Get - linters: [noctx] path: internal/memhttp/memhttp_test.go From 2ee0deffa28e91f93c5be4855ac754b75bccae04 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Wed, 25 Oct 2023 13:03:32 -0400 Subject: [PATCH 18/21] Document response close error handling --- client_ext_test.go | 68 ++++++++++++++++++++++++--------------------- duplex_http_call.go | 4 +-- 2 files changed, 38 insertions(+), 34 deletions(-) diff --git a/client_ext_test.go b/client_ext_test.go index 110edb1b..ce799958 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -86,41 +86,45 @@ func TestClientPeer(t *testing.T) { connect.WithInterceptors(&assertPeerInterceptor{t}), ) ctx := context.Background() - // unary - unaryReq := connect.NewRequest[pingv1.PingRequest](nil) - _, err := client.Ping(ctx, unaryReq) - assert.Nil(t, err) - assert.Equal(t, unaryHTTPMethod, unaryReq.HTTPMethod()) - text := strings.Repeat(".", 256) - r, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{Text: text})) - assert.Nil(t, err) - assert.Equal(t, r.Msg.Text, text) - // client streaming - clientStream := client.Sum(ctx) - t.Cleanup(func() { - _, closeErr := clientStream.CloseAndReceive() - assert.Nil(t, closeErr) + t.Run("unary", func(t *testing.T) { + unaryReq := connect.NewRequest[pingv1.PingRequest](nil) + _, err := client.Ping(ctx, unaryReq) + assert.Nil(t, err) + assert.Equal(t, unaryHTTPMethod, unaryReq.HTTPMethod()) + text := strings.Repeat(".", 256) + r, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{Text: text})) + assert.Nil(t, err) + assert.Equal(t, r.Msg.Text, text) }) - assert.NotZero(t, clientStream.Peer().Addr) - assert.NotZero(t, clientStream.Peer().Protocol) - err = clientStream.Send(&pingv1.SumRequest{}) - assert.Nil(t, err) - // server streaming - serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{Number: 1})) - t.Cleanup(func() { - assert.Nil(t, serverStream.Close()) + t.Run("client_stream", func(t *testing.T) { + clientStream := client.Sum(ctx) + t.Cleanup(func() { + _, closeErr := clientStream.CloseAndReceive() + assert.Nil(t, closeErr) + }) + assert.NotZero(t, clientStream.Peer().Addr) + assert.NotZero(t, clientStream.Peer().Protocol) + err := clientStream.Send(&pingv1.SumRequest{}) + assert.Nil(t, err) }) - assert.Nil(t, err) - // bidi streaming - bidiStream := client.CumSum(ctx) - t.Cleanup(func() { - assert.Nil(t, bidiStream.CloseRequest()) - assert.Nil(t, bidiStream.CloseResponse()) + t.Run("server_stream", func(t *testing.T) { + serverStream, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{})) + t.Cleanup(func() { + assert.Nil(t, serverStream.Close()) + }) + assert.Nil(t, err) + }) + t.Run("bidi_stream", func(t *testing.T) { + bidiStream := client.CumSum(ctx) + t.Cleanup(func() { + assert.Nil(t, bidiStream.CloseRequest()) + assert.Nil(t, bidiStream.CloseResponse()) + }) + assert.NotZero(t, bidiStream.Peer().Addr) + assert.NotZero(t, bidiStream.Peer().Protocol) + err := bidiStream.Send(&pingv1.CumSumRequest{}) + assert.Nil(t, err) }) - assert.NotZero(t, bidiStream.Peer().Addr) - assert.NotZero(t, bidiStream.Peer().Protocol) - err = bidiStream.Send(&pingv1.CumSumRequest{}) - assert.Nil(t, err) } t.Run("connect", func(t *testing.T) { diff --git a/duplex_http_call.go b/duplex_http_call.go index e215140b..e1f91026 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -260,10 +260,11 @@ func (d *duplexHTTPCall) makeRequest() { d.requestBodyReader.CloseWithError(io.EOF) return } + // We've got a response. We can now read from the response body. + // Closing the response body is delegated to the caller. d.response = response if err := d.validateResponse(response); err != nil { d.responseErr = err - d.response.Body.Close() d.requestBodyReader.CloseWithError(io.EOF) return } @@ -277,7 +278,6 @@ func (d *duplexHTTPCall) makeRequest() { response.ProtoMajor, response.ProtoMinor, ) - d.response.Body.Close() d.requestBodyReader.CloseWithError(io.EOF) } } From 00d7f6a787fdb371c380d010d6b11092eefb051a Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Wed, 25 Oct 2023 14:27:52 -0400 Subject: [PATCH 19/21] Fix and document Close behaviour for pipe --- duplex_http_call.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/duplex_http_call.go b/duplex_http_call.go index e1f91026..86bd375a 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -247,6 +247,9 @@ func (d *duplexHTTPCall) makeRequest() { } // Once we send a message to the server, they send a message back and // establish the receive side of the stream. + // On error, we close the request body using the Write side of the pipe. + // This ensures HTTP2 streams receive an io.EOF from the Read side of the + // pipe. Write's check for io.ErrClosedPipe and will convert this to io.EOF. response, err := d.httpClient.Do(d.request) //nolint:bodyclose if err != nil { err = wrapIfContextError(err) @@ -257,15 +260,15 @@ func (d *duplexHTTPCall) makeRequest() { err = NewError(CodeUnavailable, err) } d.responseErr = err - d.requestBodyReader.CloseWithError(io.EOF) + d.requestBodyWriter.Close() return } // We've got a response. We can now read from the response body. - // Closing the response body is delegated to the caller. + // Closing the response body is delegated to the caller even on error. d.response = response if err := d.validateResponse(response); err != nil { d.responseErr = err - d.requestBodyReader.CloseWithError(io.EOF) + d.requestBodyWriter.Close() return } if (d.streamType&StreamTypeBidi) == StreamTypeBidi && response.ProtoMajor < 2 { @@ -278,7 +281,7 @@ func (d *duplexHTTPCall) makeRequest() { response.ProtoMajor, response.ProtoMinor, ) - d.requestBodyReader.CloseWithError(io.EOF) + d.requestBodyWriter.Close() } } From 758f889c3973461d303252c5c7af7c289d95c430 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Wed, 25 Oct 2023 14:30:19 -0400 Subject: [PATCH 20/21] Move responseBodyReady to group what it protects --- duplex_http_call.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/duplex_http_call.go b/duplex_http_call.go index 86bd375a..4e0407ee 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -42,14 +42,15 @@ type duplexHTTPCall struct { requestBodyReader *io.PipeReader requestBodyWriter *io.PipeWriter + // sendRequestOnce ensures we only send the request once. sendRequestOnce sync.Once request *http.Request - response *http.Response // responseReady is closed when the response is ready or when the request // fails. Any error on request initialisation will be set on the // responseErr. There's always a response if responseErr is nil. responseReady chan struct{} + response *http.Response responseErr error } From fa4a5541d8a315f4e12c82cdb8fd1f991147d07f Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Wed, 25 Oct 2023 16:09:53 -0400 Subject: [PATCH 21/21] Add CloseResponse checks in TestServer --- connect_ext_test.go | 19 ++++++++++++++++++- duplex_http_call.go | 4 +++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/connect_ext_test.go b/connect_ext_test.go index 4645134f..af541259 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -183,14 +183,28 @@ func TestServer(t *testing.T) { connect.CodeOf(stream.Err()), connect.CodeInvalidArgument, ) + assert.Nil(t, stream.Close()) }) t.Run("count_up_timeout", func(t *testing.T) { ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) - defer cancel() + t.Cleanup(cancel) _, err := client.CountUp(ctx, connect.NewRequest(&pingv1.CountUpRequest{Number: 1})) assert.NotNil(t, err) assert.Equal(t, connect.CodeOf(err), connect.CodeDeadlineExceeded) }) + t.Run("count_up_cancel_after_first_response", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + request := connect.NewRequest(&pingv1.CountUpRequest{Number: 5}) + request.Header().Set(clientHeader, headerValue) + stream, err := client.CountUp(ctx, request) + assert.Nil(t, err) + assert.True(t, stream.Receive()) + cancel() + assert.False(t, stream.Receive()) + assert.NotNil(t, stream.Err()) + assert.Equal(t, connect.CodeOf(stream.Err()), connect.CodeCanceled) + assert.Nil(t, stream.Close()) + }) } testCumSum := func(t *testing.T, client pingv1connect.PingServiceClient, expectSuccess bool) { //nolint:thelper t.Run("cumsum", func(t *testing.T) { @@ -285,6 +299,7 @@ func TestServer(t *testing.T) { assert.Equal(t, connect.CodeOf(err), connect.CodeCanceled) assert.Equal(t, got, expect) assert.False(t, connect.IsWireError(err)) + assert.Nil(t, stream.CloseResponse()) }) t.Run("cumsum_cancel_before_send", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -302,6 +317,8 @@ func TestServer(t *testing.T) { err := stream.Send(&pingv1.CumSumRequest{Number: 19}) assert.Equal(t, connect.CodeOf(err), connect.CodeCanceled, assert.Sprintf("%v", err)) assert.False(t, connect.IsWireError(err)) + assert.Nil(t, stream.CloseRequest()) + assert.Nil(t, stream.CloseResponse()) }) } testErrors := func(t *testing.T, client pingv1connect.PingServiceClient) { //nolint:thelper diff --git a/duplex_http_call.go b/duplex_http_call.go index 4e0407ee..7181dd65 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -178,7 +178,9 @@ func (d *duplexHTTPCall) CloseRead() error { if d.response == nil { return nil } - if _, err := discard(d.response.Body); err != nil { + if _, err := discard(d.response.Body); err != nil && + !errors.Is(err, context.Canceled) && + !errors.Is(err, context.DeadlineExceeded) { _ = d.response.Body.Close() return wrapIfRSTError(err) }