diff --git a/session.go b/session.go index 62ea2c3..bc5e1a3 100644 --- a/session.go +++ b/session.go @@ -102,11 +102,15 @@ type Session struct { // recvDoneCh is closed when recv() exits to avoid a race // between stream registration and stream shutdown recvDoneCh chan struct{} + // recvErr is the error the receive loop ended with + recvErr error // sendDoneCh is closed when send() exits to avoid a race // between returning from a Stream.Write and exiting from the send loop // (which may be reading a buffer on-load-from Stream.Write). sendDoneCh chan struct{} + // sendErr is the error the send loop ended with + sendErr error // client is true if we're the client and our stream IDs should be odd. client bool @@ -288,10 +292,18 @@ func (s *Session) AcceptStream() (*Stream, error) { // semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or // if there's unread data in the kernel receive buffer. func (s *Session) Close() error { - return s.close(true, goAwayNormal) + return s.closeWithGoAway(goAwayNormal) } -func (s *Session) close(sendGoAway bool, errCode uint32) error { +// CloseWithError is used to close the session and all streams after sending a GoAway message with errCode. +// The GoAway may not actually be sent depending on the semantics of the underlying net.Conn. +// For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel +// receive buffer. +func (s *Session) CloseWithError(errCode uint32) error { + return s.closeWithGoAway(errCode) +} + +func (s *Session) closeWithGoAway(errCode uint32) error { s.shutdownLock.Lock() defer s.shutdownLock.Unlock() @@ -308,14 +320,12 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error { // wait for write loop to exit _ = s.conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)) // if SetWriteDeadline errored, any blocked writes will be unblocked <-s.sendDoneCh - if sendGoAway { - ga := s.goAway(errCode) - if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil { - _, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here - } + ga := s.goAway(errCode) + if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil { + _, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here } - s.conn.SetWriteDeadline(time.Time{}) + s.conn.Close() <-s.recvDoneCh @@ -329,15 +339,37 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error { return nil } -// exitErr is used to handle an error that is causing the -// session to terminate. -func (s *Session) exitErr(err error) { +func (s *Session) closeWithoutGoAway(err error) error { s.shutdownLock.Lock() + defer s.shutdownLock.Unlock() + if s.shutdown { + return nil + } + s.shutdown = true if s.shutdownErr == nil { s.shutdownErr = err } - s.shutdownLock.Unlock() - s.close(false, 0) + + s.conn.Close() + <-s.recvDoneCh + // Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code + // received in a GoAway frame received just before the RST that closed the sendLoop + if _, ok := s.recvErr.(*GoAwayError); ok { + s.shutdownErr = s.recvErr + } + close(s.shutdownCh) + + s.stopKeepalive() + <-s.sendDoneCh + + s.streamLock.Lock() + defer s.streamLock.Unlock() + for id, stream := range s.streams { + stream.forceClose() + delete(s.streams, id) + stream.memorySpan.Done() + } + return nil } // GoAway can be used to prevent accepting further @@ -468,7 +500,12 @@ func (s *Session) startKeepalive() { if err != nil { s.logger.Printf("[ERR] yamux: keepalive failed: %v", err) - s.exitErr(ErrKeepAliveTimeout) + s.shutdownLock.Lock() + if s.shutdownErr == nil { + s.shutdownErr = ErrKeepAliveTimeout + } + s.shutdownLock.Unlock() + s.closeWithGoAway(goAwayNormal) } }) } @@ -533,7 +570,7 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err // send is a long running goroutine that sends data func (s *Session) send() { if err := s.sendLoop(); err != nil { - s.exitErr(err) + s.closeWithoutGoAway(err) } } @@ -661,7 +698,7 @@ func (s *Session) sendLoop() (err error) { // recv is a long running goroutine that accepts new data func (s *Session) recv() { if err := s.recvLoop(); err != nil { - s.exitErr(err) + s.closeWithoutGoAway(err) } } @@ -683,7 +720,10 @@ func (s *Session) recvLoop() (err error) { err = fmt.Errorf("panic in yamux receive loop: %s", rerr) } }() - defer close(s.recvDoneCh) + defer func() { + s.recvErr = err + close(s.recvDoneCh) + }() var hdr header for { // fmt.Printf("ReadFull from %#v\n", s.reader) @@ -799,17 +839,17 @@ func (s *Session) handleGoAway(hdr header) error { switch code { case goAwayNormal: atomic.SwapInt32(&s.remoteGoAway, 1) + // Don't close connection on normal go away. Let the existing streams + // complete gracefully. + return nil case goAwayProtoErr: s.logger.Printf("[ERR] yamux: received protocol error go away") - return fmt.Errorf("yamux protocol error") case goAwayInternalErr: s.logger.Printf("[ERR] yamux: received internal error go away") - return fmt.Errorf("remote yamux internal error") default: - s.logger.Printf("[ERR] yamux: received unexpected go away") - return fmt.Errorf("unexpected go away received") + s.logger.Printf("[ERR] yamux: received go away with error code: %d", code) } - return nil + return &GoAwayError{Remote: true, ErrorCode: code} } // incomingStream is used to create a new incoming stream diff --git a/session_test.go b/session_test.go index 974b6d5..cf15790 100644 --- a/session_test.go +++ b/session_test.go @@ -3,6 +3,7 @@ package yamux import ( "bytes" "context" + "errors" "fmt" "io" "math/rand" @@ -650,6 +651,35 @@ func TestGoAway(t *testing.T) { default: t.Fatalf("err: %v", err) } + time.Sleep(50 * time.Millisecond) + } + t.Fatalf("expected GoAway error") +} + +func TestCloseWithError(t *testing.T) { + // This test is noisy. + conf := testConf() + conf.LogOutput = io.Discard + + client, server := testClientServerConfig(conf) + defer client.Close() + defer server.Close() + + if err := server.CloseWithError(42); err != nil { + t.Fatalf("err: %v", err) + } + + for i := 0; i < 100; i++ { + s, err := client.Open(context.Background()) + if err == nil { + s.Close() + time.Sleep(50 * time.Millisecond) + continue + } + if !errors.Is(err, &GoAwayError{ErrorCode: 42, Remote: true}) { + t.Fatalf("err: %v", err) + } + return } t.Fatalf("expected GoAway error") }