Skip to content

Commit

Permalink
improve expect feature
Browse files Browse the repository at this point in the history
  • Loading branch information
lonnywong committed Dec 16, 2023
1 parent 25937a2 commit c0cc872
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 48 deletions.
13 changes: 11 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,16 +229,25 @@ _`~/` 代表 HOME 目录。在 Windows 中,请将下文的 `~/` 替换成 `C:\

```
Host auto
#!! ExpectCount 2 # 配置自动交互的次数,默认是 0 即无自动交互
#!! ExpectCount 3 # 配置自动交互的次数,默认是 0 即无自动交互
#!! ExpectTimeout 30 # 配置自动交互的超时时间(单位:秒),默认是 30 秒
#!! ExpectPattern1 *password # 配置第一个自动交互的匹配表达式
# 配置第一个自动输入(密文),填 tssh --enc-secret 编码后的字符串,会自动发送 \r 回车
#!! ExpectSendPass1 d7983b4a8ac204bd073ed04741913befd4fbf813ad405d7404cb7d779536f8b87e71106d7780b2
#!! ExpectPattern2 $ # 配置第二个自动交互的匹配表达式
#!! ExpectPattern2 hostname*$ # 配置第二个自动交互的匹配表达式
#!! ExpectSendText2 echo tssh expect\r # 配置第二个自动输入(明文),需要指定 \r 才会发送回车
# 以上 ExpectSendPass? 和 ExpectSendText? 只要二选一即可,若都配置则 ExpectSendPass? 的优先级更高
# --------------------------------------------------
# 在每个 ExpectPattern 匹配之前,可以配置一个或多个可选的匹配,用法如下:
#!! ExpectPattern3 hostname*$ # 配置第三个自动交互的匹配表达式
#!! ExpectSendText3 ssh xxx\r # 配置第三个自动输入,也可以换成 ExpectSendPass3 然后配置密文
#!! ExpectCaseSendText3 yes/no y\r # 在 ExpectPattern3 匹配之前,若遇到 yes/no 则发送 y 并回车
#!! ExpectCaseSendText3 y/n yes\r # 在 ExpectPattern3 匹配之前,若遇到 y/n 则发送 yes 并回车
#!! ExpectCaseSendPass3 token d7... # 在 ExpectPattern3 匹配之前,若遇到 token 则解码并发送 d7...
```

使用 `tssh --debug` 登录,可以看到 `expect` 捕获到的输出,以及其匹配结果和自动输入的交互。

## 记住密码

- 为了兼容标准 ssh ,密码可以单独配置在 `~/.ssh/password` 中,也可以在 `~/.ssh/config` 中加上 `#!!` 前缀。
Expand Down
240 changes: 197 additions & 43 deletions tssh/expect.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"strconv"
"strings"
"time"
"unicode"
)

const kDefaultExpectTimeout = 30
Expand All @@ -42,11 +43,12 @@ func decodeExpectText(text string) string {
state := byte(0)
for _, c := range text {
if state == 0 {
if c == '\\' {
switch c {
case '\\':
state = '\\'
continue
default:
buf.WriteRune(c)
}
buf.WriteRune(c)
continue
}
state = 0
Expand All @@ -68,88 +70,231 @@ func decodeExpectText(text string) string {
return buf.String()
}

func quoteExpectPattern(pattern string) string {
var buf strings.Builder
for _, c := range pattern {
switch c {
case '*':
buf.WriteString(".*")
case '?', '(', ')', '[', ']', '{', '}', '.', '+', ',', '-', '^', '$', '|', '\\':
buf.WriteRune('\\')
buf.WriteRune(c)
default:
buf.WriteRune(c)
}
}
return buf.String()
}

type caseSend struct {
pattern string
display string
input []byte
re *regexp.Regexp
buffer strings.Builder
}

type caseSendList struct {
writer io.Writer
list []*caseSend
}

func (c *caseSendList) splitConfig(config string) (string, string, error) {
index := strings.IndexFunc(config, unicode.IsSpace)
if index <= 0 {
return "", "", fmt.Errorf("invalid expect case send: %s", config)
}
pattern := strings.TrimSpace(config[:index])
send := strings.TrimSpace(config[index+1:])
if pattern == "" || send == "" {
return "", "", fmt.Errorf("invalid expect case send: %s", config)
}
return pattern, send, nil
}

func (c *caseSendList) addCase(re *regexp.Regexp, pattern, display, input string) {
c.list = append(c.list, &caseSend{
pattern: pattern,
display: display,
input: []byte(input),
re: re,
})
}

func (c *caseSendList) addCaseSendPass(config string) error {
pattern, secret, err := c.splitConfig(config)
if err != nil {
return err
}
expr := quoteExpectPattern(pattern)
re, err := regexp.Compile(expr)
if err != nil {
return fmt.Errorf("compile expect expr [%s] failed: %v", expr, err)
}
pass, err := decodeSecret(secret)
if err != nil {
return fmt.Errorf("decode secret [%s] failed: %v", secret, err)
}
c.addCase(re, pattern, strings.Repeat("*", len(pass))+"\\r", pass+"\r")
return nil
}

func (c *caseSendList) addCaseSendText(config string) error {
pattern, text, err := c.splitConfig(config)
if err != nil {
return err
}
expr := quoteExpectPattern(pattern)
re, err := regexp.Compile(expr)
if err != nil {
return fmt.Errorf("compile expect expr [%s] failed: %v", expr, err)
}
c.addCase(re, pattern, text, decodeExpectText(text))
return nil
}

func (c *caseSendList) handleOutput(output string) {
for _, cs := range c.list {
cs.buffer.WriteString(output)
if cs.re.MatchString(cs.buffer.String()) {
debug("expect case match: %s", cs.pattern)
debug("expect case send: %s", cs.display)
if err := writeAll(c.writer, cs.input); err != nil {
warning("expect send input failed: %v", err)
}
cs.buffer.Reset()
} else {
debug("expect case not match: %s", cs.pattern)
}
}
}

type sshExpect struct {
outputChan chan []byte
outputBuffer strings.Builder
expectContext context.Context
ctx context.Context
out chan []byte
err chan []byte
}

func (e *sshExpect) wrapOutput(reader io.Reader, writer io.Writer) {
for {
func (e *sshExpect) captureOutput(reader io.Reader, ch chan<- []byte) ([]byte, error) {
defer close(ch)
for e.ctx.Err() == nil {
buffer := make([]byte, 32*1024)
n, err := reader.Read(buffer)
if n > 0 {
buf := buffer[:n]
if e.expectContext.Err() != nil {
if err := writeAll(writer, buf); err != nil {
warning("expect wrap output write failed: %v", err)
}
break
select {
case <-e.ctx.Done():
return buf, nil
case ch <- buf:
}
e.outputChan <- buf
}
if err == io.EOF {
return
return nil, err
}
if err != nil {
warning("expect wrap output read failed: %v", err)
warning("expect read output failed: %v", err)
return nil, err
}
}
return nil, nil
}

func (e *sshExpect) wrapOutput(reader io.Reader, writer io.Writer, ch chan []byte) {
buf, err := e.captureOutput(reader, ch)
if err != nil {
return
}
for data := range ch {
if err := writeAll(writer, data); err != nil {
warning("expect write output failed: %v", err)
return
}
}
if buf != nil {
if err := writeAll(writer, buf); err != nil {
warning("expect write output failed: %v", err)
return
}
}
if _, err := io.Copy(writer, reader); err != nil && err != io.EOF {
warning("expect wrap output failed: %v", err)
warning("expect copy output failed: %v", err)
}
}

func (e *sshExpect) waitForPattern(pattern string) error {
expr := strings.ReplaceAll(pattern, "*", ".*")
func (e *sshExpect) waitForPattern(pattern string, caseSends *caseSendList) error {
expr := quoteExpectPattern(pattern)
re, err := regexp.Compile(expr)
if err != nil {
warning("compile expect expr [%s] failed: %v", expr, err)
return err
}
e.outputBuffer.Reset()
var builder strings.Builder
for {
var buf []byte
select {
case <-e.expectContext.Done():
case <-e.ctx.Done():
warning("expect timeout")
return e.expectContext.Err()
case buf := <-e.outputChan:
output := string(buf)
debug("expect output: %s", strconv.QuoteToASCII(output))
e.outputBuffer.WriteString(output)
return e.ctx.Err()
case buf = <-e.out:
case buf = <-e.err:
}
if re.MatchString(e.outputBuffer.String()) {
output := strconv.QuoteToASCII(string(buf))
debug("expect output: %s", output)
caseSends.handleOutput(output[1 : len(output)-1])
builder.WriteString(output[1 : len(output)-1])
if re.MatchString(builder.String()) {
debug("expect match: %s", pattern)
return nil
// cleanup for next expect
for {
select {
case buf = <-e.out:
case buf = <-e.err:
default:
return nil
}
debug("expect output: %s", strconv.QuoteToASCII(string(buf)))
}
} else {
debug("expect not match: %s", pattern)
}
}
}

func (e *sshExpect) execInteractions(args *sshArgs, writer io.Writer, expectCount uint32) {
func (e *sshExpect) execInteractions(alias string, writer io.Writer, expectCount uint32) {
for i := uint32(1); i <= expectCount; i++ {
pattern := getExOptionConfig(args, fmt.Sprintf("ExpectPattern%d", i))
pattern := getExConfig(alias, fmt.Sprintf("ExpectPattern%d", i))
debug("expect pattern %d: %s", i, pattern)
if pattern != "" {
if err := e.waitForPattern(pattern); err != nil {
caseSends := &caseSendList{writer: writer}
for _, cfg := range getAllExConfig(alias, fmt.Sprintf("ExpectCaseSendPass%d", i)) {
if err := caseSends.addCaseSendPass(cfg); err != nil {
warning("Invalid ExpectCaseSendPass%d: %v", i, err)
}
}
for _, cfg := range getAllExConfig(alias, fmt.Sprintf("ExpectCaseSendText%d", i)) {
if err := caseSends.addCaseSendText(cfg); err != nil {
warning("Invalid ExpectCaseSendText%d: %v", i, err)
}
}
if err := e.waitForPattern(pattern, caseSends); err != nil {
return
}
}
if e.expectContext.Err() != nil {
if e.ctx.Err() != nil {
return
}
var input string
pass := getExOptionConfig(args, fmt.Sprintf("ExpectSendPass%d", i))
if pass != "" {
secret, err := decodeSecret(pass)
secret := getExConfig(alias, fmt.Sprintf("ExpectSendPass%d", i))
if secret != "" {
pass, err := decodeSecret(secret)
if err != nil {
warning("decode secret [%s] failed: %v", pass, err)
warning("decode secret [%s] failed: %v", secret, err)
return
}
debug("expect send %d: %s", i, strings.Repeat("*", len(secret)))
input = secret + "\r"
debug("expect send %d: %s\\r", i, strings.Repeat("*", len(pass)))
input = pass + "\r"
} else {
text := getExOptionConfig(args, fmt.Sprintf("ExpectSendText%d", i))
text := getExConfig(alias, fmt.Sprintf("ExpectSendText%d", i))
if text == "" {
continue
}
Expand Down Expand Up @@ -208,11 +353,20 @@ func execExpectInteractions(args *sshArgs, serverIn io.Writer,
}
defer cancel()

expect := &sshExpect{outputChan: make(chan []byte, 1), expectContext: ctx}
go expect.wrapOutput(serverOut, outWriter)
go expect.wrapOutput(serverErr, errWriter)
expect := &sshExpect{
ctx: ctx,
out: make(chan []byte, 10),
err: make(chan []byte, 10),
}
go expect.wrapOutput(serverOut, outWriter, expect.out)
go expect.wrapOutput(serverErr, errWriter, expect.err)

expect.execInteractions(args.Destination, serverIn, expectCount)

expect.execInteractions(args, serverIn, expectCount)
if ctx.Err() == context.DeadlineExceeded {
// enter for shell prompt if timeout
_, _ = serverIn.Write([]byte("\r"))
}

return outReader, errReader
}
7 changes: 4 additions & 3 deletions tssh/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@ func expandTokens(str string, args *sshArgs, param *loginParam, tokens string) s
state := byte(0)
for _, c := range str {
if state == 0 {
if c == '%' {
switch c {
case '%':
state = '%'
continue
default:
buf.WriteRune(c)
}
buf.WriteRune(c)
continue
}
state = 0
Expand Down

0 comments on commit c0cc872

Please sign in to comment.