Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: command response processing issues #40

Merged
merged 6 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ jobs:
outformat: out-format
with:
version: ${{ matrix.golangci }}
args: "--%outformat% colored-line-number"
args: "--%outformat% colored-line-number --timeout 2m"
skip-pkg-cache: true
skip-build-cache: true

Expand Down
183 changes: 138 additions & 45 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ package ts3

import (
"bufio"
"errors"
"fmt"
"io"
"net"
"regexp"
"strings"
"sync"
"time"

"golang.org/x/crypto/ssh"
Expand All @@ -26,13 +26,18 @@ const (
// startBufSize is the initial size of allocation for the parse buffer.
startBufSize = 4096

// keepAliveData is the keepalive data.
keepAliveData = " \n"
// responseErrTimeout is the timeout used for sending response errors.
responseErrTimeout = time.Millisecond * 100
)

var (
// respTrailerRe is the regexp which matches a server response to a command.
respTrailerRe = regexp.MustCompile(`^error id=(\d+) msg=([^ ]+)(.*)`)

// keepAliveData is data which will be ignored by the server used to ensure
// the connection is kept alive.
keepAliveData = []byte(" \n")

// DefaultTimeout is the default read / write / dial timeout for Clients.
DefaultTimeout = 10 * time.Second

Expand All @@ -52,6 +57,11 @@ type Connection interface {
Connect(addr string, timeout time.Duration) error
}

type response struct {
err error
lines []string
}

// Client is a TeamSpeak 3 ServerQuery client.
type Client struct {
conn Connection
Expand All @@ -62,11 +72,13 @@ type Client struct {
maxBufSize int
notifyBufSize int
work chan string
err chan error
response chan response
notify chan Notification
disconnect chan struct{}
res []string
closing chan struct{} // closing is closed to indicate we're closing our connection.
done chan struct{} // done is closed once we're seen a fatal error.
doneOnce sync.Once
connectHeader string
wg sync.WaitGroup

Server *ServerMethods
}
Expand Down Expand Up @@ -151,8 +163,9 @@ func NewClient(addr string, options ...func(c *Client) error) (*Client, error) {
maxBufSize: MaxParseTokenSize,
notifyBufSize: DefaultNotifyBufSize,
work: make(chan string),
err: make(chan error),
disconnect: make(chan struct{}),
response: make(chan response),
closing: make(chan struct{}),
done: make(chan struct{}),
connectHeader: DefaultConnectHeader,
}
for _, f := range options {
Expand Down Expand Up @@ -183,7 +196,7 @@ func NewClient(addr string, options ...func(c *Client) error) (*Client, error) {

// Read the connection header
if !c.scanner.Scan() {
return nil, c.scanErr()
return nil, fmt.Errorf("client: header: %w", c.scanErr())
}

if l := c.scanner.Text(); l != c.connectHeader {
Expand All @@ -192,30 +205,67 @@ func NewClient(addr string, options ...func(c *Client) error) (*Client, error) {

// Slurp the banner
if !c.scanner.Scan() {
return nil, c.scanErr()
return nil, fmt.Errorf("client: banner: %w", c.scanErr())
}

if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
return nil, fmt.Errorf("client: set read deadline: %w", err)
}

// Start handlers
c.wg.Add(2)
go c.messageHandler()
go c.workHandler()

return c, nil
}

// fatalError returns false if err is nil otherwise it ensures
// that done is closed and returns true.
func (c *Client) fatalError(err error) bool {
if err == nil {
return false
}

c.closeDone()
return true
}

// closeDone safely closes c.done.
func (c *Client) closeDone() {
c.doneOnce.Do(func() {
close(c.done)
})
}

// messageHandler scans incoming lines and handles them accordingly.
// - Notifications are sent to c.notify.
// - ExecCmd responses are sent to c.response.
// If a fatal error occurs it stops processing and exits.
func (c *Client) messageHandler() {
defer func() {
close(c.notify)
c.wg.Done()
}()

buf := make([]string, 0, 10)
for {
if c.scanner.Scan() {
line := c.scanner.Text()
//nolint: gocritic
if line == "error id=0 msg=ok" {
c.err <- nil
var resp response
// Avoid creating a new buf if there was no data in the response.
if len(buf) > 0 {
resp.lines = buf
buf = make([]string, 0, 10)
}
c.response <- resp
} else if matches := respTrailerRe.FindStringSubmatch(line); len(matches) == 4 {
c.err <- NewError(matches)
c.response <- response{err: NewError(matches)}
// Avoid creating a new buf if there was no data in the response.
if len(buf) > 0 {
buf = make([]string, 0, 10)
}
} else if strings.Index(line, "notify") == 0 {
if n, err := decodeNotification(line); err == nil {
// non-blocking write
Expand All @@ -225,40 +275,70 @@ func (c *Client) messageHandler() {
}
}
} else {
c.res = append(c.res, line)
// Partial response.
buf = append(buf, line)
}
} else {
err := c.scanErr()
c.err <- err
if errors.Is(err, io.ErrUnexpectedEOF) {
close(c.disconnect)
return
if err := c.scanErr(); c.fatalError(err) {
c.responseErr(err)
} else {
// Ensure that done is closed as scanner has seen an io.EOF.
c.closeDone()
}
return
}
}
}

// responseErr sends err to c.response with a timeout to ensure it
// doesn't block forever when multiple errors occur during the
// processing of a single ExecCmd call.
func (c *Client) responseErr(err error) {
t := time.NewTimer(responseErrTimeout)
defer t.Stop()

select {
case c.response <- response{err: err}:
case <-t.C:
}
}

// workHandler handles commands and keepAlive messages.
func (c *Client) workHandler() {
defer c.wg.Done()

for {
select {
case w := <-c.work:
c.process(w)
if err := c.write([]byte(w)); c.fatalError(err) {
// Command send failed, inform the caller.
c.responseErr(err)
return
}
case <-time.After(c.keepAlive):
c.process(keepAliveData)
case <-c.disconnect:
// Send a keep alive to prevent the connection from timing out.
if err := c.write(keepAliveData); c.fatalError(err) {
// We don't send to c.response as no ExecCmd is expecting a
// response and the next caller will get an error.
return
}
case <-c.done:
return
}
}
}

func (c *Client) process(data string) {
// write writes data to the clients connection with the configured timeout
// returning any error.
func (c *Client) write(data []byte) error {
if err := c.conn.SetWriteDeadline(time.Now().Add(c.timeout)); err != nil {
c.err <- err
return fmt.Errorf("set deadline: %w", err)
}
if _, err := c.conn.Write([]byte(data)); err != nil {
c.err <- err
if _, err := c.conn.Write(data); err != nil {
return fmt.Errorf("write: %w", err)
}

return nil
}

// Exec executes cmd on the server and returns the response.
Expand All @@ -268,37 +348,36 @@ func (c *Client) Exec(cmd string) ([]string, error) {

// ExecCmd executes cmd on the server and returns the response.
func (c *Client) ExecCmd(cmd *Cmd) ([]string, error) {
if !c.IsConnected() {
select {
case c.work <- cmd.String():
case <-c.done:
return nil, ErrNotConnected
}

c.work <- cmd.String()

var resp response
select {
case err := <-c.err:
if err != nil {
return nil, err
case resp = <-c.response:
if resp.err != nil {
return nil, resp.err
}
case <-time.After(c.timeout):
return nil, ErrTimeout
}

res := c.res
c.res = nil

if cmd.response != nil {
if err := DecodeResponse(res, cmd.response); err != nil {
if err := DecodeResponse(resp.lines, cmd.response); err != nil {
return nil, err
}
}

return res, nil
return resp.lines, nil
}

// IsConnected returns whether the client is connected.
// IsConnected returns true if the client is connected,
// false otherwise.
func (c *Client) IsConnected() bool {
select {
case <-c.disconnect:
case <-c.done:
return false
default:
return true
Expand All @@ -307,8 +386,10 @@ func (c *Client) IsConnected() bool {

// Close closes the connection to the server.
func (c *Client) Close() error {
defer close(c.notify)
defer c.wg.Wait()

// Signal we're expecting EOF.
close(c.closing)
_, err := c.Exec("quit")
err2 := c.conn.Close()

Expand All @@ -321,11 +402,23 @@ func (c *Client) Close() error {
return nil
}

// scanError returns the error from the scanner if non-nil,
// `io.ErrUnexpectedEOF` otherwise.
// scanError returns nil if c is closing else if the scanner returns a
// non-nil error it is returned, otherwise returns `io.ErrUnexpectedEOF`.
// Callers must have seen c.scanner.Scan() return false.
func (c *Client) scanErr() error {
if err := c.scanner.Err(); err != nil {
return fmt.Errorf("client: scan: %w", err)
select {
case <-c.closing:
// We know we're closing the connection so ignore any errors
// an return nil. This prevents spurious errors being returned
// to the caller.
return nil
default:
if err := c.scanner.Err(); err != nil {
return fmt.Errorf("scan: %w", err)
}

// As caller has seen c.scanner.Scan() return false
// this must have been triggered by an unexpected EOF.
return io.ErrUnexpectedEOF
}
return io.ErrUnexpectedEOF
}
11 changes: 3 additions & 8 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,13 @@ func (c *sshConnection) Connect(addr string, timeout time.Duration) error {

// Read implements io.Reader.
func (c *sshConnection) Read(p []byte) (n int, err error) {
if n, err = c.channel.Read(p); err != nil {
return n, fmt.Errorf("ssh connection: read: %w", err)
}
return n, nil
// Don't wrap as it needs to return raw EOF as per https://pkg.go.dev/io#Reader
return c.channel.Read(p) //nolint: wrapcheck
}

// Write implements io.Writer.
func (c *sshConnection) Write(p []byte) (n int, err error) {
if n, err = c.channel.Write(p); err != nil {
return n, fmt.Errorf("ssh connection: write: %w", err)
}
return n, nil
return c.channel.Write(p) //nolint: wrapcheck
}

// Close implements io.Closer.
Expand Down
18 changes: 4 additions & 14 deletions mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,11 +393,7 @@ func (c *sshServerShell) Read(b []byte) (int, error) {
return 0, err
}

n, err := ch.Read(b)
if err != nil {
return n, fmt.Errorf("mock ssh shell: channel read: %w", err)
}
return n, nil
return ch.Read(b) //nolint: wrapcheck
}

// Write writes to the ssh channel.
Expand All @@ -407,11 +403,7 @@ func (c *sshServerShell) Write(b []byte) (int, error) {
return 0, err
}

n, err := ch.Write(b)
if err != nil {
return n, fmt.Errorf("mock ssh shell: channel write: %w", err)
}
return n, nil
return ch.Write(b) //nolint: wrapcheck
}

// Close closes the ssh channel and connection.
Expand All @@ -420,8 +412,6 @@ func (c *sshServerShell) Close() error {
c.closed = true
c.mtx.Unlock()
c.cond.Broadcast()
if err := c.Conn.Close(); err != nil {
return fmt.Errorf("mock ssh shell: close: %w", err)
}
return nil

return c.Conn.Close() //nolint: wrapcheck
}
Loading