diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 680c9eba0b17..f6bac0e8a00d 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -960,7 +960,12 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { } } if err := t.writeHeaderLocked(s); err != nil { - return status.Convert(err).Err() + switch e := err.(type) { + case ConnectionError: + return status.Error(codes.Unavailable, e.Desc) + default: + return status.Convert(err).Err() + } } return nil } diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 21aff27db1df..dabedc0bbfd9 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -45,6 +45,7 @@ import ( "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/leakcheck" "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/resolver" "google.golang.org/grpc/status" ) @@ -2136,6 +2137,69 @@ func (s) TestHeadersHTTPStatusGRPCStatus(t *testing.T) { } } +func (s) TestWriteHeaderConnectionError(t *testing.T) { + server, client, cancel := setUp(t, 0, notifyCall) + defer cancel() + defer server.stop() + + waitWhileTrue(t, func() (bool, error) { + server.mu.Lock() + defer server.mu.Unlock() + + if len(server.conns) == 0 { + return true, fmt.Errorf("timed-out while waiting for connection to be created on the server") + } + return false, nil + }) + + if len(server.conns) != 1 { + t.Fatal("Server must have an active connection for the client.") + } + + // Get the server transfort for the connecton to the client + var serverTransport *http2Server + server.mu.Lock() + for k := range server.conns { + serverTransport = k.(*http2Server) + } + notifyChan := make(chan struct{}) + server.h.notify = notifyChan + server.mu.Unlock() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cstream1, err := client.NewStream(ctx, &CallHdr{}) + if err != nil { + t.Fatalf("Client failed to create first stream. Err: %v", err) + } + + <-notifyChan // Wait server stream to be established + var sstream1 *Stream + // Access stream on the server + serverTransport.mu.Lock() + for _, v := range serverTransport.activeStreams { + if v.id == cstream1.id { + sstream1 = v + } + } + serverTransport.mu.Unlock() + if sstream1 == nil { + t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id) + } + + client.Close(fmt.Errorf("closed manually by test")) + + // Wait server transport to be closed + <-serverTransport.done + + // Write header on a closed server transport + err = serverTransport.WriteHeader(sstream1, metadata.MD{}) + st := status.Convert(err) + if st.Code() != codes.Unavailable { + t.Fatalf("Unailable status expected but got: %v", st.Code().String()) + } +} + func (s) TestPingPong1B(t *testing.T) { runPingPongTest(t, 1) }