diff --git a/internal/protocol/tcp_server.go b/internal/protocol/tcp_server.go index 186b2965c..8cecddeb5 100644 --- a/internal/protocol/tcp_server.go +++ b/internal/protocol/tcp_server.go @@ -5,6 +5,7 @@ import ( "net" "runtime" "strings" + "sync" "github.com/nsqio/nsq/internal/lg" ) @@ -16,6 +17,8 @@ type TCPHandler interface { func TCPServer(listener net.Listener, handler TCPHandler, logf lg.AppLogFunc) error { logf(lg.INFO, "TCP: listening on %s", listener.Addr()) + var wg sync.WaitGroup + for { clientConn, err := listener.Accept() if err != nil { @@ -30,9 +33,17 @@ func TCPServer(listener net.Listener, handler TCPHandler, logf lg.AppLogFunc) er } break } - go handler.Handle(clientConn) + + wg.Add(1) + go func() { + handler.Handle(clientConn) + wg.Done() + }() } + // wait to return until all handler goroutines complete + wg.Wait() + logf(lg.INFO, "TCP: closing %s", listener.Addr()) return nil diff --git a/nsqlookupd/lookup_protocol_v1_test.go b/nsqlookupd/lookup_protocol_v1_test.go index 5ff524bd9..b3cfca60d 100644 --- a/nsqlookupd/lookup_protocol_v1_test.go +++ b/nsqlookupd/lookup_protocol_v1_test.go @@ -40,6 +40,8 @@ func testIOLoopReturnsClientErr(t *testing.T, fakeConn test.FakeNetConn) { test.Nil(t, err) prot := &LookupProtocolV1{ctx: &Context{nsqlookupd: nsqlookupd}} + nsqlookupd.tcpServer = &tcpServer{ctx: prot.ctx} + errChan := make(chan error) testIOLoop := func() { errChan <- prot.IOLoop(fakeConn) diff --git a/nsqlookupd/nsqlookupd.go b/nsqlookupd/nsqlookupd.go index 11dc3b547..72c3ad4f5 100644 --- a/nsqlookupd/nsqlookupd.go +++ b/nsqlookupd/nsqlookupd.go @@ -18,6 +18,7 @@ type NSQLookupd struct { opts *Options tcpListener net.Listener httpListener net.Listener + tcpServer *tcpServer waitGroup util.WaitGroupWrapper DB *RegistrationDB } @@ -63,9 +64,9 @@ func (l *NSQLookupd) Main() error { }) } - tcpServer := &tcpServer{ctx: ctx} + l.tcpServer = &tcpServer{ctx: ctx} l.waitGroup.Wrap(func() { - exitFunc(protocol.TCPServer(l.tcpListener, tcpServer, l.logf)) + exitFunc(protocol.TCPServer(l.tcpListener, l.tcpServer, l.logf)) }) httpServer := newHTTPServer(ctx) l.waitGroup.Wrap(func() { @@ -89,6 +90,10 @@ func (l *NSQLookupd) Exit() { l.tcpListener.Close() } + if l.tcpServer != nil { + l.tcpServer.CloseAll() + } + if l.httpListener != nil { l.httpListener.Close() } diff --git a/nsqlookupd/tcp.go b/nsqlookupd/tcp.go index 0050762aa..24b00de89 100644 --- a/nsqlookupd/tcp.go +++ b/nsqlookupd/tcp.go @@ -3,12 +3,14 @@ package nsqlookupd import ( "io" "net" + "sync" "github.com/nsqio/nsq/internal/protocol" ) type tcpServer struct { - ctx *Context + ctx *Context + conns sync.Map } func (p *tcpServer) Handle(clientConn net.Conn) { @@ -41,9 +43,19 @@ func (p *tcpServer) Handle(clientConn net.Conn) { return } + p.conns.Store(clientConn.RemoteAddr(), clientConn) + err = prot.IOLoop(clientConn) if err != nil { p.ctx.nsqlookupd.logf(LOG_ERROR, "client(%s) - %s", clientConn.RemoteAddr(), err) - return } + + p.conns.Delete(clientConn.RemoteAddr()) +} + +func (p *tcpServer) CloseAll() { + p.conns.Range(func(k, v interface{}) bool { + v.(net.Conn).Close() + return true + }) }