Skip to content

Commit

Permalink
feat: implement context-based timeout for commands
Browse files Browse the repository at this point in the history
- Add context package import
- Replace timeout channel with context-based timeout
- Improve error message to include context timeout error
- Update test to match new error message format
- Add new test for command timeout functionality

Signed-off-by: Bo-Yi Wu <appleboy.tw@gmail.com>
  • Loading branch information
appleboy committed Dec 6, 2024
1 parent dc56456 commit 13a382d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
8 changes: 5 additions & 3 deletions easyssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package easyssh

import (
"bufio"
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -357,7 +358,8 @@ func (ssh_conf *MakeConfig) Stream(command string, timeout ...time.Duration) (<-
if len(timeout) > 0 {
executeTimeout = timeout[0]
}
timeoutChan := time.After(executeTimeout)
ctxTimeout, cancel := context.WithTimeout(context.Background(), executeTimeout)
defer cancel()
res := make(chan struct{}, 1)
var resWg sync.WaitGroup
resWg.Add(2)
Expand Down Expand Up @@ -398,8 +400,8 @@ func (ssh_conf *MakeConfig) Stream(command string, timeout ...time.Duration) (<-
case <-res:
errChan <- session.Wait()
doneChan <- true
case <-timeoutChan:
errChan <- fmt.Errorf("Run Command Timeout")
case <-ctxTimeout.Done():
errChan <- fmt.Errorf("Run Command Timeout: %v", ctxTimeout.Err())
doneChan <- false
}
}(stdoutScanner, stderrScanner, stdoutChan, stderrChan, doneChan, errChan)
Expand Down
20 changes: 18 additions & 2 deletions easyssh_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package easyssh

import (
"context"
"os"
"os/user"
"path"
Expand All @@ -20,7 +21,6 @@ func getHostPublicKeyFile(keypath string) (ssh.PublicKey, error) {
}

pubkey, _, _, _, err = ssh.ParseAuthorizedKey(buf)

if err != nil {
return nil, err
}
Expand Down Expand Up @@ -169,7 +169,7 @@ func TestRunCommand(t *testing.T) {
assert.Equal(t, "", errStr)
assert.False(t, isTimeout)
assert.Error(t, err)
assert.Equal(t, "Run Command Timeout", err.Error())
assert.Equal(t, "Run Command Timeout: "+context.DeadlineExceeded.Error(), err.Error())

// test exit code
outStr, errStr, isTimeout, err = ssh.Run("exit 1")
Expand Down Expand Up @@ -496,3 +496,19 @@ func TestSudoCommand(t *testing.T) {
assert.True(t, isTimeout)
assert.NoError(t, err)
}

func TestCommandTimeout(t *testing.T) {
ssh := &MakeConfig{
Server: "localhost",
User: "root",
Port: "22",
KeyPath: "./tests/.ssh/id_rsa",
}

outStr, errStr, isTimeout, err := ssh.Run("whoami; sleep 2", 1*time.Second)
assert.Equal(t, "root\n", outStr)
assert.Equal(t, "", errStr)
assert.False(t, isTimeout)
assert.NotNil(t, err)
assert.Equal(t, "Run Command Timeout: "+context.DeadlineExceeded.Error(), err.Error())
}

0 comments on commit 13a382d

Please sign in to comment.