Skip to content

Commit

Permalink
server: cleanup graceful shutdown routine
Browse files Browse the repository at this point in the history
Signed-off-by: He Xian <hexian000@outlook.com>
  • Loading branch information
hexian000 committed Dec 30, 2024
1 parent d0429b4 commit 1445c68
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 30 deletions.
5 changes: 1 addition & 4 deletions v3/forwarder/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,12 @@ type forwarder struct {
g routines.Group
conn map[net.Conn]struct{}
counter chan struct{}
closeCh chan struct{}
}

func New(maxConn int, g routines.Group) Forwarder {
return &forwarder{
conn: make(map[net.Conn]struct{}),
counter: make(chan struct{}, maxConn),
closeCh: make(chan struct{}),
g: g,
}
}
Expand Down Expand Up @@ -79,7 +77,7 @@ func (f *forwarder) connCopy(dst net.Conn, src net.Conn) {

func (f *forwarder) Forward(accepted net.Conn, dialed net.Conn) error {
select {
case <-f.closeCh:
case <-f.g.CloseC():
return routines.ErrClosed
case f.counter <- struct{}{}:
default:
Expand Down Expand Up @@ -113,7 +111,6 @@ func (f *forwarder) Count() int {
}

func (f *forwarder) Close() {
close(f.closeCh)
f.mu.Lock()
defer f.mu.Unlock()
for conn := range f.conn {
Expand Down
41 changes: 15 additions & 26 deletions v3/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,6 @@ func (s *Server) getAllTunnels() []*tunnel {
return tunnels
}

func (s *Server) stopAllTunnels() {
s.tunnelsMu.RLock()
defer s.tunnelsMu.RUnlock()
for _, t := range s.tunnels {
_ = t.Stop()
}
}

type ServerStats struct {
NumSessions int
NumStreams int
Expand Down Expand Up @@ -275,12 +267,7 @@ func (s *Server) loadTunnels(cfg *config.File) error {
if tuncfg.Disabled {
continue
}
t := &tunnel{
peerName: name, s: s,
mux: make(map[*yamux.Session]string),
closeSig: make(chan struct{}, 1),
redialSig: make(chan struct{}, 1),
}
t := newTunnel(name, s)
s.tunnels[name] = t
if err := t.Start(); err != nil {
return err
Expand Down Expand Up @@ -337,17 +324,9 @@ func (s *Server) Start() error {
return nil
}

func (s *Server) closeAllMux() {
s.muxMu.Lock()
defer s.muxMu.Unlock()
for mux := range s.mux {
ioClose(mux)
delete(s.mux, mux)
}
}

// Shutdown gracefully
func (s *Server) Shutdown() error {
// stop all listeners
if s.l != nil {
ioClose(s.l)
s.l = nil
Expand All @@ -356,11 +335,21 @@ func (s *Server) Shutdown() error {
ioClose(s.apiListener)
s.apiListener = nil
}
// cancel all contexts
s.ctx.close()
s.stopAllTunnels()
s.closeAllMux()
s.f.Close()
// stop all tunnels
s.g.Close()
// close all mux
func() {
s.muxMu.Lock()
defer s.muxMu.Unlock()
for mux := range s.mux {
ioClose(mux)
delete(s.mux, mux)
}
}()
// close all forwards
s.f.Close()
slog.Info("waiting for unfinished connections")
s.g.Wait()
return nil
Expand Down
9 changes: 9 additions & 0 deletions v3/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ type tunnel struct {
lastChanged time.Time
}

func newTunnel(peerName string, s *Server) *tunnel {
return &tunnel{
peerName: peerName, s: s,
mux: make(map[*yamux.Session]string),
closeSig: make(chan struct{}),
redialSig: make(chan struct{}, 1),
}
}

func (t *tunnel) getConfig() (*config.File, *tls.Config, *config.Tunnel) {
cfg, tlscfg := t.s.getConfig()
return cfg, tlscfg, cfg.GetTunnel(t.peerName)
Expand Down

0 comments on commit 1445c68

Please sign in to comment.