Skip to content

Commit

Permalink
fix: command response processing issues (#40)
Browse files Browse the repository at this point in the history
Fix command response processing issues introduced by notification handling.

This includes:

- Random hangs due to writes to channels with no readers.
- Potential race when reusing client response buffer.
- Close of notify channel while handler was still running.
- Eliminated race on IsConnected check in ExecCmd.
Also:

- Use just \n for keep-alive to reduce data sent across the wire.
- Wrap more errors to improve error reporting.
- Remove io.(Reader|Writer) wrapping as that breaks scanner io.EOF handling.

---------

Co-authored-by: HalloTschuess <hallo.ich.f@gmail.com>
  • Loading branch information
stevenh and HalloTschuess authored Apr 25, 2023
1 parent 6cd984d commit 01bb4ee
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 68 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ jobs:
outformat: out-format
with:
version: ${{ matrix.golangci }}
args: "--%outformat% colored-line-number"
args: "--%outformat% colored-line-number --timeout 2m"
skip-pkg-cache: true
skip-build-cache: true

Expand Down
183 changes: 138 additions & 45 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ package ts3

import (
"bufio"
"errors"
"fmt"
"io"
"net"
"regexp"
"strings"
"sync"
"time"

"golang.org/x/crypto/ssh"
Expand All @@ -26,13 +26,18 @@ const (
// startBufSize is the initial size of allocation for the parse buffer.
startBufSize = 4096

// keepAliveData is the keepalive data.
keepAliveData = " \n"
// responseErrTimeout is the timeout used for sending response errors.
responseErrTimeout = time.Millisecond * 100
)

var (
// respTrailerRe is the regexp which matches a server response to a command.
respTrailerRe = regexp.MustCompile(`^error id=(\d+) msg=([^ ]+)(.*)`)

// keepAliveData is data, ignored by the server, which is sent to ensure
// the connection is kept alive.
keepAliveData = []byte(" \n")

// DefaultTimeout is the default read / write / dial timeout for Clients.
DefaultTimeout = 10 * time.Second

Expand All @@ -52,6 +57,11 @@ type Connection interface {
Connect(addr string, timeout time.Duration) error
}

// response is the outcome of a single command, passed from messageHandler
// to ExecCmd over Client.response.
type response struct {
// err is the error reported for the command; it is nil on success.
err error
// lines holds the data lines received before the response trailer.
// It is nil when the response carried no data.
lines []string
}

// Client is a TeamSpeak 3 ServerQuery client.
type Client struct {
conn Connection
Expand All @@ -62,11 +72,13 @@ type Client struct {
maxBufSize int
notifyBufSize int
work chan string
err chan error
response chan response
notify chan Notification
disconnect chan struct{}
res []string
closing chan struct{} // closing is closed to indicate we're closing our connection.
done chan struct{} // done is closed once we've seen a fatal error.
doneOnce sync.Once
connectHeader string
wg sync.WaitGroup

Server *ServerMethods
}
Expand Down Expand Up @@ -151,8 +163,9 @@ func NewClient(addr string, options ...func(c *Client) error) (*Client, error) {
maxBufSize: MaxParseTokenSize,
notifyBufSize: DefaultNotifyBufSize,
work: make(chan string),
err: make(chan error),
disconnect: make(chan struct{}),
response: make(chan response),
closing: make(chan struct{}),
done: make(chan struct{}),
connectHeader: DefaultConnectHeader,
}
for _, f := range options {
Expand Down Expand Up @@ -183,7 +196,7 @@ func NewClient(addr string, options ...func(c *Client) error) (*Client, error) {

// Read the connection header
if !c.scanner.Scan() {
return nil, c.scanErr()
return nil, fmt.Errorf("client: header: %w", c.scanErr())
}

if l := c.scanner.Text(); l != c.connectHeader {
Expand All @@ -192,30 +205,67 @@ func NewClient(addr string, options ...func(c *Client) error) (*Client, error) {

// Slurp the banner
if !c.scanner.Scan() {
return nil, c.scanErr()
return nil, fmt.Errorf("client: banner: %w", c.scanErr())
}

if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
return nil, fmt.Errorf("client: set read deadline: %w", err)
}

// Start handlers
c.wg.Add(2)
go c.messageHandler()
go c.workHandler()

return c, nil
}

// fatalError reports whether err is non-nil.
// For a non-nil err it also makes sure c.done has been closed before
// returning true; a nil err leaves the client untouched and yields false.
func (c *Client) fatalError(err error) bool {
	if err != nil {
		// Any non-nil error is treated as fatal for the client.
		c.closeDone()
		return true
	}

	return false
}

// closeDone closes c.done, guarding the close with c.doneOnce so that
// repeated calls never close the channel more than once.
func (c *Client) closeDone() {
	c.doneOnce.Do(func() { close(c.done) })
}

// messageHandler scans incoming lines and handles them accordingly.
// - Notifications are sent to c.notify.
// - ExecCmd responses are sent to c.response.
// If a fatal error occurs it stops processing and exits.
func (c *Client) messageHandler() {
defer func() {
close(c.notify)
c.wg.Done()
}()

buf := make([]string, 0, 10)
for {
if c.scanner.Scan() {
line := c.scanner.Text()
//nolint: gocritic
if line == "error id=0 msg=ok" {
c.err <- nil
var resp response
// Avoid creating a new buf if there was no data in the response.
if len(buf) > 0 {
resp.lines = buf
buf = make([]string, 0, 10)
}
c.response <- resp
} else if matches := respTrailerRe.FindStringSubmatch(line); len(matches) == 4 {
c.err <- NewError(matches)
c.response <- response{err: NewError(matches)}
// Avoid creating a new buf if there was no data in the response.
if len(buf) > 0 {
buf = make([]string, 0, 10)
}
} else if strings.Index(line, "notify") == 0 {
if n, err := decodeNotification(line); err == nil {
// non-blocking write
Expand All @@ -225,40 +275,70 @@ func (c *Client) messageHandler() {
}
}
} else {
c.res = append(c.res, line)
// Partial response.
buf = append(buf, line)
}
} else {
err := c.scanErr()
c.err <- err
if errors.Is(err, io.ErrUnexpectedEOF) {
close(c.disconnect)
return
if err := c.scanErr(); c.fatalError(err) {
c.responseErr(err)
} else {
// Ensure that done is closed as scanner has seen an io.EOF.
c.closeDone()
}
return
}
}
}

// responseErr delivers err on c.response, giving up after
// responseErrTimeout. The timeout stops the send from blocking forever
// when more than one error is raised while a single ExecCmd call is
// being processed.
func (c *Client) responseErr(err error) {
	timer := time.NewTimer(responseErrTimeout)
	defer timer.Stop()

	resp := response{err: err}
	select {
	case <-timer.C:
		// The send did not complete within the timeout; drop the error.
	case c.response <- resp:
	}
}

// workHandler handles commands and keepAlive messages.
func (c *Client) workHandler() {
defer c.wg.Done()

for {
select {
case w := <-c.work:
c.process(w)
if err := c.write([]byte(w)); c.fatalError(err) {
// Command send failed, inform the caller.
c.responseErr(err)
return
}
case <-time.After(c.keepAlive):
c.process(keepAliveData)
case <-c.disconnect:
// Send a keep alive to prevent the connection from timing out.
if err := c.write(keepAliveData); c.fatalError(err) {
// We don't send to c.response as no ExecCmd is expecting a
// response and the next caller will get an error.
return
}
case <-c.done:
return
}
}
}

func (c *Client) process(data string) {
// write writes data to the client's connection with the configured timeout,
// returning any error.
func (c *Client) write(data []byte) error {
if err := c.conn.SetWriteDeadline(time.Now().Add(c.timeout)); err != nil {
c.err <- err
return fmt.Errorf("set deadline: %w", err)
}
if _, err := c.conn.Write([]byte(data)); err != nil {
c.err <- err
if _, err := c.conn.Write(data); err != nil {
return fmt.Errorf("write: %w", err)
}

return nil
}

// Exec executes cmd on the server and returns the response.
Expand All @@ -268,37 +348,36 @@ func (c *Client) Exec(cmd string) ([]string, error) {

// ExecCmd executes cmd on the server and returns the response.
func (c *Client) ExecCmd(cmd *Cmd) ([]string, error) {
if !c.IsConnected() {
select {
case c.work <- cmd.String():
case <-c.done:
return nil, ErrNotConnected
}

c.work <- cmd.String()

var resp response
select {
case err := <-c.err:
if err != nil {
return nil, err
case resp = <-c.response:
if resp.err != nil {
return nil, resp.err
}
case <-time.After(c.timeout):
return nil, ErrTimeout
}

res := c.res
c.res = nil

if cmd.response != nil {
if err := DecodeResponse(res, cmd.response); err != nil {
if err := DecodeResponse(resp.lines, cmd.response); err != nil {
return nil, err
}
}

return res, nil
return resp.lines, nil
}

// IsConnected returns whether the client is connected.
// IsConnected returns true if the client is connected,
// false otherwise.
func (c *Client) IsConnected() bool {
select {
case <-c.disconnect:
case <-c.done:
return false
default:
return true
Expand All @@ -307,8 +386,10 @@ func (c *Client) IsConnected() bool {

// Close closes the connection to the server.
func (c *Client) Close() error {
defer close(c.notify)
defer c.wg.Wait()

// Signal we're expecting EOF.
close(c.closing)
_, err := c.Exec("quit")
err2 := c.conn.Close()

Expand All @@ -321,11 +402,23 @@ func (c *Client) Close() error {
return nil
}

// scanError returns the error from the scanner if non-nil,
// `io.ErrUnexpectedEOF` otherwise.
// scanErr returns nil if c is closing; otherwise, if the scanner reports a
// non-nil error it is returned wrapped, else `io.ErrUnexpectedEOF` is returned.
// Callers must have seen c.scanner.Scan() return false.
func (c *Client) scanErr() error {
if err := c.scanner.Err(); err != nil {
return fmt.Errorf("client: scan: %w", err)
select {
case <-c.closing:
// We know we're closing the connection so ignore any errors
// and return nil. This prevents spurious errors being returned
// to the caller.
return nil
default:
if err := c.scanner.Err(); err != nil {
return fmt.Errorf("scan: %w", err)
}

// As caller has seen c.scanner.Scan() return false
// this must have been triggered by an unexpected EOF.
return io.ErrUnexpectedEOF
}
return io.ErrUnexpectedEOF
}
11 changes: 3 additions & 8 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,13 @@ func (c *sshConnection) Connect(addr string, timeout time.Duration) error {

// Read implements io.Reader.
func (c *sshConnection) Read(p []byte) (n int, err error) {
if n, err = c.channel.Read(p); err != nil {
return n, fmt.Errorf("ssh connection: read: %w", err)
}
return n, nil
// Don't wrap as it needs to return raw EOF as per https://pkg.go.dev/io#Reader
return c.channel.Read(p) //nolint: wrapcheck
}

// Write implements io.Writer.
func (c *sshConnection) Write(p []byte) (n int, err error) {
if n, err = c.channel.Write(p); err != nil {
return n, fmt.Errorf("ssh connection: write: %w", err)
}
return n, nil
return c.channel.Write(p) //nolint: wrapcheck
}

// Close implements io.Closer.
Expand Down
18 changes: 4 additions & 14 deletions mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,11 +393,7 @@ func (c *sshServerShell) Read(b []byte) (int, error) {
return 0, err
}

n, err := ch.Read(b)
if err != nil {
return n, fmt.Errorf("mock ssh shell: channel read: %w", err)
}
return n, nil
return ch.Read(b) //nolint: wrapcheck
}

// Write writes to the ssh channel.
Expand All @@ -407,11 +403,7 @@ func (c *sshServerShell) Write(b []byte) (int, error) {
return 0, err
}

n, err := ch.Write(b)
if err != nil {
return n, fmt.Errorf("mock ssh shell: channel write: %w", err)
}
return n, nil
return ch.Write(b) //nolint: wrapcheck
}

// Close closes the ssh channel and connection.
Expand All @@ -420,8 +412,6 @@ func (c *sshServerShell) Close() error {
c.closed = true
c.mtx.Unlock()
c.cond.Broadcast()
if err := c.Conn.Close(); err != nil {
return fmt.Errorf("mock ssh shell: close: %w", err)
}
return nil

return c.Conn.Close() //nolint: wrapcheck
}
Loading

0 comments on commit 01bb4ee

Please sign in to comment.