Skip to content

Commit

Permalink
send GoAway on Close
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Aug 26, 2024
1 parent e7338b0 commit d8cf4e7
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
2 changes: 2 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package yamux
import (
"encoding/binary"
"fmt"
"time"
)

const (
Expand Down Expand Up @@ -52,6 +53,7 @@ const (
// It's not an implementation choice, the value defined in the specification.
initialStreamWindow = 256 * 1024
maxStreamWindow = 16 * 1024 * 1024
goAwayWaitTime = 5 * time.Second
)

const (
Expand Down
25 changes: 21 additions & 4 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,14 @@ func (s *Session) AcceptStream() (*Stream, error) {
}

// Close is used to close the session and all streams.
// Attempts to send a GoAway before closing the connection.
// Attempts to send a GoAway before closing the connection. The GoAway may not actually be sent depending on the
// semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or
// if there's unread data in the kernel receive buffer.
func (s *Session) Close() error {
return s.close(true, goAwayNormal)
}

func (s *Session) close(sendGoAway bool, errCode uint32) error {
s.shutdownLock.Lock()
defer s.shutdownLock.Unlock()

Expand All @@ -297,10 +303,21 @@ func (s *Session) Close() error {
s.shutdownErr = ErrSessionShutdown
}
close(s.shutdownCh)
s.conn.Close()
s.stopKeepalive()
<-s.recvDoneCh

// wait for write loop to exit
_ = s.conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)) // if SetWriteDeadline errored, any blocked writes will be unblocked
<-s.sendDoneCh
if sendGoAway {
ga := s.goAway(errCode)
if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil {
_, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here
}
}

s.conn.SetWriteDeadline(time.Time{})
s.conn.Close()
<-s.recvDoneCh

s.streamLock.Lock()
defer s.streamLock.Unlock()
Expand All @@ -320,7 +337,7 @@ func (s *Session) exitErr(err error) {
s.shutdownErr = err
}
s.shutdownLock.Unlock()
s.Close()
s.close(false, 0)
}

// GoAway can be used to prevent accepting further
Expand Down

0 comments on commit d8cf4e7

Please sign in to comment.