From 3ce0910ac6006af72865a14bc59ea33b5781d60b Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Thu, 23 Feb 2023 21:58:11 +0000 Subject: [PATCH 1/6] fix: command response processing issues Fix command response processing issues introduced by notification handling. This includes: * Random hangs due to writes to channels with no readers. * Potential race when reusing client response buffer. * Close of notify channel while handler was still running. * Eliminated race on IsConnected check in ExecCmd. Also: * Use just \n for keep-alive to reduce data sent across the wire. * Wrap more errors to improve error reporting. * Remove io.(Reader|Writer) wrapping as that breaks scanner io.EOF handling. --- client.go | 183 +++++++++++++++++++++++++++++++++++++------------- connection.go | 11 +-- 2 files changed, 141 insertions(+), 53 deletions(-) diff --git a/client.go b/client.go index 2f769f3..f0e88c3 100644 --- a/client.go +++ b/client.go @@ -2,12 +2,12 @@ package ts3 import ( "bufio" - "errors" "fmt" "io" "net" "regexp" "strings" + "sync" "time" "golang.org/x/crypto/ssh" @@ -26,13 +26,18 @@ const ( // startBufSize is the initial size of allocation for the parse buffer. startBufSize = 4096 - // keepAliveData is the keepalive data. - keepAliveData = " \n" + // responseErrTimeout is the timeout use for sending response errors. + responseErrTimeout = time.Millisecond * 100 ) var ( + // respTrailerRe is the regexp which matches a server response to a command. respTrailerRe = regexp.MustCompile(`^error id=(\d+) msg=([^ ]+)(.*)`) + // keepAliveData is data which will be ignored by the server used to ensure + // the connection is kept alive. + keepAliveData = []byte("\n") + // DefaultTimeout is the default read / write / dial timeout for Clients. DefaultTimeout = 10 * time.Second @@ -52,6 +57,11 @@ type Connection interface { Connect(addr string, timeout time.Duration) error } +type response struct { + err error + lines []string +} + // Client is a TeamSpeak 3 ServerQuery client. 
type Client struct { conn Connection @@ -62,11 +72,13 @@ type Client struct { maxBufSize int notifyBufSize int work chan string - err chan error + response chan response notify chan Notification - disconnect chan struct{} - res []string + closing chan struct{} // closing is closed to indicate we're closing our connection. + done chan struct{} // done is closed once we've seen a fatal error. + doneOnce sync.Once connectHeader string + wg sync.WaitGroup Server *ServerMethods } @@ -151,8 +163,9 @@ func NewClient(addr string, options ...func(c *Client) error) (*Client, error) { maxBufSize: MaxParseTokenSize, notifyBufSize: DefaultNotifyBufSize, work: make(chan string), - err: make(chan error), - disconnect: make(chan struct{}), + response: make(chan response), + closing: make(chan struct{}), + done: make(chan struct{}), connectHeader: DefaultConnectHeader, } for _, f := range options { @@ -183,7 +196,7 @@ func NewClient(addr string, options ...func(c *Client) error) (*Client, error) { // Read the connection header if !c.scanner.Scan() { - return nil, c.scanErr() + return nil, fmt.Errorf("client: header: %w", c.scanErr()) } if l := c.scanner.Text(); l != c.connectHeader { @@ -192,7 +205,7 @@ func NewClient(addr string, options ...func(c *Client) error) (*Client, error) { // Slurp the banner if !c.scanner.Scan() { - return nil, c.scanErr() + return nil, fmt.Errorf("client: banner: %w", c.scanErr()) } if err := c.conn.SetReadDeadline(time.Time{}); err != nil { @@ -200,22 +213,56 @@ func NewClient(addr string, options ...func(c *Client) error) (*Client, error) { } // Start handlers + c.wg.Add(2) go c.messageHandler() go c.workHandler() return c, nil } +// fatalError returns false if err is nil otherwise it ensures +// that done is closed and returns true. +func (c *Client) fatalError(err error) bool { + if err == nil { + return false + } + + c.closeDone() + return true +} + +// closeDone safely closes c.done. 
+func (c *Client) closeDone() { + c.doneOnce.Do(func() { + close(c.done) + }) +} + // messageHandler scans incoming lines and handles them accordingly. +// - Notifications are sent to c.notify. +// - ExecCmd responses are sent to c.response. +// If a fatal error occurs it stops processing and exits. func (c *Client) messageHandler() { + defer c.wg.Done() + + buf := make([]string, 0, 10) for { if c.scanner.Scan() { line := c.scanner.Text() - //nolint: gocritic if line == "error id=0 msg=ok" { - c.err <- nil + var resp response + // Avoid creating a new buf if there was no data in the response. + if len(buf) > 0 { + resp.lines = buf + buf = make([]string, 0, 10) + } + c.response <- resp } else if matches := respTrailerRe.FindStringSubmatch(line); len(matches) == 4 { - c.err <- NewError(matches) + c.response <- response{err: NewError(matches)} + // Avoid creating a new buf if there was no data in the response. + if len(buf) > 0 { + buf = make([]string, 0, 10) + } } else if strings.Index(line, "notify") == 0 { if n, err := decodeNotification(line); err == nil { // non-blocking write @@ -225,40 +272,70 @@ func (c *Client) messageHandler() { } } } else { - c.res = append(c.res, line) + // Partial response. + buf = append(buf, line) } } else { - err := c.scanErr() - c.err <- err - if errors.Is(err, io.ErrUnexpectedEOF) { - close(c.disconnect) - return + if err := c.scanErr(); c.fatalError(err) { + c.responseErr(err) + } else { + // Ensure that done is closed as scanner has seen an io.EOF. + c.closeDone() } + return } } } +// responseErr sends err to c.response with a timeout to ensure it +// doesn't block forever when multiple errors occur during the +// processing of a single ExecCmd call. +func (c *Client) responseErr(err error) { + t := time.NewTimer(responseErrTimeout) + defer t.Stop() + + select { + case c.response <- response{err: err}: + case <-t.C: + } +} + // workHandler handles commands and keepAlive messages. 
func (c *Client) workHandler() { + defer c.wg.Done() + for { select { case w := <-c.work: - c.process(w) + if err := c.write([]byte(w)); c.fatalError(err) { + // Command send failed, inform the caller. + c.responseErr(err) + return + } case <-time.After(c.keepAlive): - c.process(keepAliveData) - case <-c.disconnect: + // Send a keep alive to prevent the connection from timing out. + if err := c.write(keepAliveData); c.fatalError(err) { + // We don't send to c.response as no ExecCmd is expecting a + // response and the next caller will get an error. + return + } + case <-c.done: return } } } -func (c *Client) process(data string) { +// write writes data to the clients connection with the configured timeout +// returning any error. +func (c *Client) write(data []byte) error { if err := c.conn.SetWriteDeadline(time.Now().Add(c.timeout)); err != nil { - c.err <- err + return fmt.Errorf("set deadline: %w", err) } - if _, err := c.conn.Write([]byte(data)); err != nil { - c.err <- err + if _, err := c.conn.Write(data); err != nil { + return fmt.Errorf("write: %w", err) } + + return nil } // Exec executes cmd on the server and returns the response. @@ -268,37 +345,36 @@ func (c *Client) Exec(cmd string) ([]string, error) { // ExecCmd executes cmd on the server and returns the response. 
func (c *Client) ExecCmd(cmd *Cmd) ([]string, error) { - if !c.IsConnected() { + select { + case c.work <- cmd.String(): + case <-c.done: return nil, ErrNotConnected } - c.work <- cmd.String() - + var resp response select { - case err := <-c.err: - if err != nil { - return nil, err + case resp = <-c.response: + if resp.err != nil { + return nil, resp.err } case <-time.After(c.timeout): return nil, ErrTimeout } - res := c.res - c.res = nil - if cmd.response != nil { - if err := DecodeResponse(res, cmd.response); err != nil { + if err := DecodeResponse(resp.lines, cmd.response); err != nil { return nil, err } } - return res, nil + return resp.lines, nil } -// IsConnected returns whether the client is connected. +// IsConnected returns true if the client is connected, +// false otherwise. func (c *Client) IsConnected() bool { select { - case <-c.disconnect: + case <-c.done: return false default: return true @@ -307,8 +383,13 @@ func (c *Client) IsConnected() bool { // Close closes the connection to the server. func (c *Client) Close() error { - defer close(c.notify) + defer func() { + c.wg.Wait() + close(c.notify) + }() + // Signal we're expecting EOF. + close(c.closing) _, err := c.Exec("quit") err2 := c.conn.Close() @@ -321,11 +402,23 @@ func (c *Client) Close() error { return nil } -// scanError returns the error from the scanner if non-nil, -// `io.ErrUnexpectedEOF` otherwise. +// scanErr returns nil if c is closing else if the scanner returns a +// non-nil error it is returned, otherwise returns `io.ErrUnexpectedEOF`. +// Callers must have seen c.scanner.Scan() return false. func (c *Client) scanErr() error { - if err := c.scanner.Err(); err != nil { - return fmt.Errorf("client: scan: %w", err) + select { + case <-c.closing: + // We know we're closing the connection so ignore any errors + // and return nil. This prevents spurious errors being returned + // to the caller. 
+ return nil + default: + if err := c.scanner.Err(); err != nil { + return fmt.Errorf("scan: %w", err) + } + + // As caller has seen c.scanner.Scan() return false + // this must have been triggered by an unexpected EOF. + return io.ErrUnexpectedEOF } - return io.ErrUnexpectedEOF } diff --git a/connection.go b/connection.go index 8f89c16..06081fe 100644 --- a/connection.go +++ b/connection.go @@ -89,18 +89,13 @@ func (c *sshConnection) Connect(addr string, timeout time.Duration) error { // Read implements io.Reader. func (c *sshConnection) Read(p []byte) (n int, err error) { - if n, err = c.channel.Read(p); err != nil { - return n, fmt.Errorf("ssh connection: read: %w", err) - } - return n, nil + // Don't wrap as it needs to return raw EOF as per https://pkg.go.dev/io#Reader + return c.channel.Read(p) //nolint: wrapcheck } // Write implements io.Writer. func (c *sshConnection) Write(p []byte) (n int, err error) { - if n, err = c.channel.Write(p); err != nil { - return n, fmt.Errorf("ssh connection: write: %w", err) - } - return n, nil + return c.channel.Write(p) //nolint: wrapcheck } // Close implements io.Closer. From 9cb8b6d0af81562b2c28b1859b6c42c3d18fbd4d Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Fri, 24 Feb 2023 18:45:25 +0000 Subject: [PATCH 2/6] ci: increase golangci-lint timeout on windows Increase the golangci-lint timeout on Windows as it constantly takes longer than Linux and Mac leading to timeout failures. 
--- .github/workflows/go.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 5b4909a..a8a4515 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -63,7 +63,7 @@ jobs: outformat: out-format with: version: ${{ matrix.golangci }} - args: "--%outformat% colored-line-number" + args: "--%outformat% colored-line-number --timeout 2m" skip-pkg-cache: true skip-build-cache: true From b83f5d2965064e6209588f0328c710085838b3bf Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Sun, 26 Feb 2023 12:52:05 +0000 Subject: [PATCH 3/6] fix: notify close and keep alive Move notify close so it's done as soon as the sender exits. Revert keep alive change as '\n' isn't enough to prevent connection timeout. --- client.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index f0e88c3..a7dd3ae 100644 --- a/client.go +++ b/client.go @@ -36,7 +36,7 @@ var ( // keepAliveData is data which will be ignored by the server used to ensure // the connection is kept alive. - keepAliveData = []byte("\n") + keepAliveData = []byte(" \n") // DefaultTimeout is the default read / write / dial timeout for Clients. DefaultTimeout = 10 * time.Second @@ -243,7 +243,10 @@ func (c *Client) closeDone() { // - ExecCmd responses are sent to c.response. // If a fatal error occurs it stops processing and exits. func (c *Client) messageHandler() { - defer c.wg.Done() + defer func() { + close(c.notify) + c.wg.Done() + }() buf := make([]string, 0, 10) for { @@ -383,10 +386,7 @@ func (c *Client) IsConnected() bool { // Close closes the connection to the server. func (c *Client) Close() error { - defer func() { - c.wg.Wait() - close(c.notify) - }() + defer c.wg.Wait() // Signal we're expecting EOF. 
close(c.closing) From 5aed7724b3aa2b89dba07f0731b74b3e026a4655 Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Sun, 26 Feb 2023 13:03:54 +0000 Subject: [PATCH 4/6] chore: document notification close Document that notification channel will be closed. --- notification.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/notification.go b/notification.go index 087ed96..dc667c4 100644 --- a/notification.go +++ b/notification.go @@ -40,6 +40,10 @@ type Notification struct { // Notifications returns a read-only channel that outputs received notifications. // +// The channel will be closed when no more notifications will be sent so +// consumers should either range over the returned channel or use the multi +// value version of receive so they can detect when the channel is closed. +// // If you subscribe to server and channel events you will receive duplicate // `cliententerview` and `clientleftview` notifications. // Sending a private message from the client results in a `textmessage` From 667b9c98ff6e1d224c51c23c3369c3803f307d4b Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Mon, 27 Feb 2023 14:56:34 +0000 Subject: [PATCH 5/6] fix: comment typo Fix a comment typo for responseErrTimeout Co-authored-by: HalloTschuess --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index a7dd3ae..74802ff 100644 --- a/client.go +++ b/client.go @@ -26,7 +26,7 @@ const ( // startBufSize is the initial size of allocation for the parse buffer. startBufSize = 4096 - // responseErrTimeout is the timeout use for sending response errors. + // responseErrTimeout is the timeout used for sending response errors. 
responseErrTimeout = time.Millisecond * 100 ) From d7eb851e52f311bd018687f4d57035bdcf80d03b Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Mon, 27 Feb 2023 15:11:43 +0000 Subject: [PATCH 6/6] fix: remove wrap on io methods Remove wrap on Write, Read and Close methods in mock servers to ensure we don't trigger unexpected behaviour such as io.EOF check failures. --- mockserver_test.go | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/mockserver_test.go b/mockserver_test.go index 7b2c35c..726e346 100644 --- a/mockserver_test.go +++ b/mockserver_test.go @@ -393,11 +393,7 @@ func (c *sshServerShell) Read(b []byte) (int, error) { return 0, err } - n, err := ch.Read(b) - if err != nil { - return n, fmt.Errorf("mock ssh shell: channel read: %w", err) - } - return n, nil + return ch.Read(b) //nolint: wrapcheck } // Write writes to the ssh channel. @@ -407,11 +403,7 @@ func (c *sshServerShell) Write(b []byte) (int, error) { return 0, err } - n, err := ch.Write(b) - if err != nil { - return n, fmt.Errorf("mock ssh shell: channel write: %w", err) - } - return n, nil + return ch.Write(b) //nolint: wrapcheck } // Close closes the ssh channel and connection. @@ -420,8 +412,6 @@ func (c *sshServerShell) Close() error { c.closed = true c.mtx.Unlock() c.cond.Broadcast() - if err := c.Conn.Close(); err != nil { - return fmt.Errorf("mock ssh shell: close: %w", err) - } - return nil + + return c.Conn.Close() //nolint: wrapcheck }