diff --git a/v3/forwarder/forwarder.go b/v3/forwarder/forwarder.go index 2fbe7f3..6f95c86 100644 --- a/v3/forwarder/forwarder.go +++ b/v3/forwarder/forwarder.go @@ -29,14 +29,12 @@ type forwarder struct { g routines.Group conn map[net.Conn]struct{} counter chan struct{} - closeCh chan struct{} } func New(maxConn int, g routines.Group) Forwarder { return &forwarder{ conn: make(map[net.Conn]struct{}), counter: make(chan struct{}, maxConn), - closeCh: make(chan struct{}), g: g, } } @@ -79,7 +77,7 @@ func (f *forwarder) connCopy(dst net.Conn, src net.Conn) { func (f *forwarder) Forward(accepted net.Conn, dialed net.Conn) error { select { - case <-f.closeCh: + case <-f.g.CloseC(): return routines.ErrClosed case f.counter <- struct{}{}: default: @@ -113,7 +111,6 @@ func (f *forwarder) Count() int { } func (f *forwarder) Close() { - close(f.closeCh) f.mu.Lock() defer f.mu.Unlock() for conn := range f.conn { diff --git a/v3/server.go b/v3/server.go index 0edb007..3c5e1cc 100644 --- a/v3/server.go +++ b/v3/server.go @@ -107,14 +107,6 @@ func (s *Server) getAllTunnels() []*tunnel { return tunnels } -func (s *Server) stopAllTunnels() { - s.tunnelsMu.RLock() - defer s.tunnelsMu.RUnlock() - for _, t := range s.tunnels { - _ = t.Stop() - } -} - type ServerStats struct { NumSessions int NumStreams int @@ -275,12 +267,7 @@ func (s *Server) loadTunnels(cfg *config.File) error { if tuncfg.Disabled { continue } - t := &tunnel{ - peerName: name, s: s, - mux: make(map[*yamux.Session]string), - closeSig: make(chan struct{}, 1), - redialSig: make(chan struct{}, 1), - } + t := newTunnel(name, s) s.tunnels[name] = t if err := t.Start(); err != nil { return err @@ -337,17 +324,9 @@ func (s *Server) Start() error { return nil } -func (s *Server) closeAllMux() { - s.muxMu.Lock() - defer s.muxMu.Unlock() - for mux := range s.mux { - ioClose(mux) - delete(s.mux, mux) - } -} - // Shutdown gracefully func (s *Server) Shutdown() error { + // stop all listeners if s.l != nil { ioClose(s.l) s.l = nil @@ -356,11 +335,21 @@ func (s *Server) Shutdown() error { ioClose(s.apiListener) s.apiListener = nil } + // cancel all contexts s.ctx.close() - s.stopAllTunnels() - s.closeAllMux() - s.f.Close() + // stop all tunnels s.g.Close() + // close all mux + func() { + s.muxMu.Lock() + defer s.muxMu.Unlock() + for mux := range s.mux { + ioClose(mux) + delete(s.mux, mux) + } + }() + // close all forwards + s.f.Close() slog.Info("waiting for unfinished connections") s.g.Wait() return nil diff --git a/v3/tunnel.go b/v3/tunnel.go index 4945237..b35faea 100644 --- a/v3/tunnel.go +++ b/v3/tunnel.go @@ -34,6 +34,15 @@ type tunnel struct { lastChanged time.Time } +func newTunnel(peerName string, s *Server) *tunnel { + return &tunnel{ + peerName: peerName, s: s, + mux: make(map[*yamux.Session]string), + closeSig: make(chan struct{}), + redialSig: make(chan struct{}, 1), + } +} + func (t *tunnel) getConfig() (*config.File, *tls.Config, *config.Tunnel) { cfg, tlscfg := t.s.getConfig() return cfg, tlscfg, cfg.GetTunnel(t.peerName)