diff --git a/const.go b/const.go index c1a2deb..e737d85 100644 --- a/const.go +++ b/const.go @@ -3,6 +3,7 @@ package yamux import ( "encoding/binary" "fmt" + "time" ) const ( @@ -52,6 +53,7 @@ const ( // It's not an implementation choice, the value defined in the specification. initialStreamWindow = 256 * 1024 maxStreamWindow = 16 * 1024 * 1024 + goAwayWaitTime = 5 * time.Second ) const ( diff --git a/session.go b/session.go index c4cd1bd..62ea2c3 100644 --- a/session.go +++ b/session.go @@ -284,8 +284,14 @@ func (s *Session) AcceptStream() (*Stream, error) { } // Close is used to close the session and all streams. -// Attempts to send a GoAway before closing the connection. +// Attempts to send a GoAway before closing the connection. The GoAway may not actually be sent depending on the +// semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or +// if there's unread data in the kernel receive buffer. func (s *Session) Close() error { + return s.close(true, goAwayNormal) +} + +func (s *Session) close(sendGoAway bool, errCode uint32) error { s.shutdownLock.Lock() defer s.shutdownLock.Unlock() @@ -297,10 +303,21 @@ func (s *Session) Close() error { s.shutdownErr = ErrSessionShutdown } close(s.shutdownCh) - s.conn.Close() s.stopKeepalive() - <-s.recvDoneCh + + // wait for write loop to exit + _ = s.conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)) // if SetWriteDeadline errored, any blocked writes will be unblocked <-s.sendDoneCh + if sendGoAway { + ga := s.goAway(errCode) + if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil { + _, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here + } + } + + s.conn.SetWriteDeadline(time.Time{}) + s.conn.Close() + <-s.recvDoneCh s.streamLock.Lock() defer s.streamLock.Unlock() @@ -320,7 +337,7 @@ func (s *Session) exitErr(err error) { s.shutdownErr = err } s.shutdownLock.Unlock() - s.Close() + s.close(false, 0) } // GoAway can be used to prevent accepting further