Skip to content

Commit

Permalink
internal/transport/http2_server: properly convert ConnectionError to …
Browse files Browse the repository at this point in the history
…Unavailable status in WriteHeader
  • Loading branch information
msenarista committed Dec 26, 2023
1 parent 4f03f3f commit f67a784
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 1 deletion.
7 changes: 6 additions & 1 deletion internal/transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,12 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
}
}
if err := t.writeHeaderLocked(s); err != nil {
return status.Convert(err).Err()
switch e := err.(type) {
case ConnectionError:
return status.Error(codes.Unavailable, e.Desc)
default:
return status.Convert(err).Err()
}
}
return nil
}
Expand Down
64 changes: 64 additions & 0 deletions internal/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/status"
)
Expand Down Expand Up @@ -2136,6 +2137,69 @@ func (s) TestHeadersHTTPStatusGRPCStatus(t *testing.T) {
}
}

func (s) TestWriteHeaderConnectionError(t *testing.T) {
server, client, cancel := setUp(t, 0, notifyCall)
defer cancel()
defer server.stop()

waitWhileTrue(t, func() (bool, error) {
server.mu.Lock()
defer server.mu.Unlock()

if len(server.conns) == 0 {
return true, fmt.Errorf("timed-out while waiting for connection to be created on the server")
}
return false, nil
})

if len(server.conns) != 1 {
t.Fatal("Server must have an active connection for the client.")
}

// Get the server transfort for the connecton to the client
var serverTransport *http2Server
server.mu.Lock()
for k := range server.conns {
serverTransport = k.(*http2Server)
}
notifyChan := make(chan struct{})
server.h.notify = notifyChan
server.mu.Unlock()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
cstream1, err := client.NewStream(ctx, &CallHdr{})
if err != nil {
t.Fatalf("Client failed to create first stream. Err: %v", err)
}

<-notifyChan // Wait server stream to be established
var sstream1 *Stream
// Access stream on the server
serverTransport.mu.Lock()
for _, v := range serverTransport.activeStreams {
if v.id == cstream1.id {
sstream1 = v
}
}
serverTransport.mu.Unlock()
if sstream1 == nil {
t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id)
}

client.Close(fmt.Errorf("closed manually by test"))

// Wait server transport to be closed
<-serverTransport.done

// Write header on a closed server transport
err = serverTransport.WriteHeader(sstream1, metadata.MD{})
st := status.Convert(err)
if st.Code() != codes.Unavailable {
t.Fatalf("Unailable status expected but got: %v", st.Code().String())
}
}

func (s) TestPingPong1B(t *testing.T) {
runPingPongTest(t, 1)
}
Expand Down

0 comments on commit f67a784

Please sign in to comment.