From e83b1e66ec295c52e0d84fed92e865e935fe8e94 Mon Sep 17 00:00:00 2001 From: "Benjamin A. Stockwell" Date: Wed, 18 Sep 2019 17:23:00 -0500 Subject: [PATCH] nsqlookupd: synchronize goroutines to avoid t.Log races --- internal/protocol/tcp_server.go | 14 +++++++++++++- nsqlookupd/lookup_protocol_v1.go | 28 ++++++++++++++++++++++++--- nsqlookupd/lookup_protocol_v1_test.go | 2 +- nsqlookupd/nsqlookupd.go | 9 ++++++--- nsqlookupd/tcp.go | 5 +++-- 5 files changed, 48 insertions(+), 10 deletions(-) diff --git a/internal/protocol/tcp_server.go b/internal/protocol/tcp_server.go index 186b2965c..0a686035a 100644 --- a/internal/protocol/tcp_server.go +++ b/internal/protocol/tcp_server.go @@ -5,6 +5,7 @@ import ( "net" "runtime" "strings" + "sync" "github.com/nsqio/nsq/internal/lg" ) @@ -16,6 +17,8 @@ type TCPHandler interface { func TCPServer(listener net.Listener, handler TCPHandler, logf lg.AppLogFunc) error { logf(lg.INFO, "TCP: listening on %s", listener.Addr()) + var wg sync.WaitGroup + for { clientConn, err := listener.Accept() if err != nil { @@ -30,9 +33,18 @@ func TCPServer(listener net.Listener, handler TCPHandler, logf lg.AppLogFunc) er } break } - go handler.Handle(clientConn) + + wg.Add(1) + + go func() { + defer wg.Done() + handler.Handle(clientConn) + }() } + // wait to return until all handler goroutines complete + wg.Wait() + logf(lg.INFO, "TCP: closing %s", listener.Addr()) return nil diff --git a/nsqlookupd/lookup_protocol_v1.go b/nsqlookupd/lookup_protocol_v1.go index dbeb78c2a..552221f7a 100644 --- a/nsqlookupd/lookup_protocol_v1.go +++ b/nsqlookupd/lookup_protocol_v1.go @@ -18,7 +18,8 @@ import ( ) type LookupProtocolV1 struct { - ctx *Context + ctx *Context + exitChan chan struct{} } func (p *LookupProtocolV1) IOLoop(conn net.Conn) error { @@ -27,9 +28,30 @@ func (p *LookupProtocolV1) IOLoop(conn net.Conn) error { client := NewClientV1(conn) reader := bufio.NewReader(client) + readChan := make(chan string) + errChan := make(chan error) + for { - line, err = reader.ReadString('\n') - if err != nil { + exitLoop := false + + // do this read in a goroutine so we can exit this loop if needed + go func() { + line, e := reader.ReadString('\n') + readChan <- line + errChan <- e + }() + + // now wait until we either get an exit signal or a readline completes + select { + case <-p.exitChan: + exitLoop = true + break + case line = <-readChan: + err = <-errChan + break + } + + if exitLoop || err != nil { break } diff --git a/nsqlookupd/lookup_protocol_v1_test.go b/nsqlookupd/lookup_protocol_v1_test.go index 5ff524bd9..19e43a239 100644 --- a/nsqlookupd/lookup_protocol_v1_test.go +++ b/nsqlookupd/lookup_protocol_v1_test.go @@ -38,7 +38,7 @@ func testIOLoopReturnsClientErr(t *testing.T, fakeConn test.FakeNetConn) { nsqlookupd, err := New(opts) test.Nil(t, err) - prot := &LookupProtocolV1{ctx: &Context{nsqlookupd: nsqlookupd}} + prot := &LookupProtocolV1{ctx: &Context{nsqlookupd: nsqlookupd}, exitChan: make(chan struct{})} errChan := make(chan error) testIOLoop := func() { diff --git a/nsqlookupd/nsqlookupd.go b/nsqlookupd/nsqlookupd.go index 11dc3b547..adf231b80 100644 --- a/nsqlookupd/nsqlookupd.go +++ b/nsqlookupd/nsqlookupd.go @@ -20,6 +20,7 @@ type NSQLookupd struct { httpListener net.Listener waitGroup util.WaitGroupWrapper DB *RegistrationDB + exitChan chan struct{} } func New(opts *Options) (*NSQLookupd, error) { @@ -29,8 +30,9 @@ func New(opts *Options) (*NSQLookupd, error) { opts.Logger = log.New(os.Stderr, opts.LogPrefix, log.Ldate|log.Ltime|log.Lmicroseconds) } l := &NSQLookupd{ - opts: opts, - DB: NewRegistrationDB(), + opts: opts, + DB: NewRegistrationDB(), + exitChan: make(chan struct{}), } l.logf(LOG_INFO, version.String("nsqlookupd")) @@ -63,7 +65,7 @@ func (l *NSQLookupd) Main() error { }) } - tcpServer := &tcpServer{ctx: ctx} + tcpServer := &tcpServer{ctx: ctx, exitChan: l.exitChan} l.waitGroup.Wrap(func() { exitFunc(protocol.TCPServer(l.tcpListener, tcpServer, l.logf)) }) @@ -85,6 +87,7 @@ func (l *NSQLookupd) RealHTTPAddr() *net.TCPAddr { } func (l *NSQLookupd) Exit() { + close(l.exitChan) if l.tcpListener != nil { l.tcpListener.Close() } diff --git a/nsqlookupd/tcp.go b/nsqlookupd/tcp.go index 0050762aa..ac071c7ae 100644 --- a/nsqlookupd/tcp.go +++ b/nsqlookupd/tcp.go @@ -8,7 +8,8 @@ import ( ) type tcpServer struct { - ctx *Context + ctx *Context + exitChan chan struct{} } func (p *tcpServer) Handle(clientConn net.Conn) { @@ -32,7 +33,7 @@ func (p *tcpServer) Handle(clientConn net.Conn) { var prot protocol.Protocol switch protocolMagic { case " V1": - prot = &LookupProtocolV1{ctx: p.ctx} + prot = &LookupProtocolV1{ctx: p.ctx, exitChan: p.exitChan} default: protocol.SendResponse(clientConn, []byte("E_BAD_PROTOCOL")) clientConn.Close()