-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathsftp.go
142 lines (130 loc) · 3.28 KB
/
sftp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
package backend
import (
"fmt"
"io"
"net"
"os"
"path/filepath"
"strings"
"time"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
// SFTP is the set of remote-file operations used to manage
// authorized_keys files on a host.
type SFTP interface {
	// Connect opens a session to host as user, authenticating with the
	// private key stored in keyfile.
	Connect(keyfile, host, user string) error
	// Read returns the contents of filename.
	Read(filename string) ([]byte, error)
	// Write stores data in filename.
	Write(filename, data string) error
	// Close ends the session opened by Connect.
	Close()
}
// SFTPConn implements the SFTP interface on top of an *sftp.Client.
//
// It also carries a built-in test double: when mock is set, no network
// traffic is generated and Read/Write operate on an in-memory SFTPMockHost.
type SFTPConn struct {
	// host is the mock host selected by Connect; its File is what Read
	// returns and Write replaces in mock mode.
	host SFTPMockHost
	// client is the SFTP session that Connect opens for real hosts.
	client *sftp.Client
	// mock makes Connect, Read, Write and Close skip the network entirely.
	mock bool
	// alias is the key under which Write stores host back into testHosts;
	// expected, when non-empty, is the only data Write accepts in mock mode.
	alias, expected string
	// testHosts maps the host string given to Connect to its mock state.
	testHosts map[string]SFTPMockHost
	// testError, when set, makes Write (and Read in mock mode) fail.
	testError bool
}
// SFTPMockHost is the built-in stand-in for a remote host used in tests.
// In mock mode, File holds the content that Read returns and Write replaces.
type SFTPMockHost struct {
	Host, User, File string
}
// Connect opens an SFTP session to host, authenticating as user with the
// private key stored in keyfile. A leading "~/" in keyfile is expanded to
// the current user's home directory. host is passed straight to ssh.Dial
// over TCP, so it presumably needs to include a port ("host:port").
//
// If the connection is in mock mode, or host is already present in the mock
// host table, no network connection is made. Instead the matching mock host
// (the zero value when host is unknown) is selected for later Read/Write
// calls and the mock host table is created if it was nil.
// NOTE(review): when mock is false but host is in testHosts, no client is
// created, so a later Read/Write would dereference a nil client — confirm
// tests always set mock alongside testHosts.
//
// SECURITY: the host key callback accepts any server key, so the connection
// is not protected against man-in-the-middle attacks. Consider verifying
// against known_hosts (golang.org/x/crypto/ssh/knownhosts).
func (s *SFTPConn) Connect(keyfile, host, user string) error {
	// Mock/test path first: it never touches the key file, so it must not
	// depend on $HOME being resolvable.
	if serv, ok := s.testHosts[host]; ok || s.mock {
		s.host = serv
		s.alias = host
		if s.testHosts == nil {
			s.testHosts = map[string]SFTPMockHost{}
		}
		return nil
	}
	if strings.HasPrefix(keyfile, "~/") {
		home, err := os.UserHomeDir()
		if err != nil {
			return fmt.Errorf("unable to expand %q: %w", keyfile, err)
		}
		keyfile = filepath.Join(home, keyfile[2:])
	}
	key, err := os.ReadFile(keyfile)
	if err != nil {
		return fmt.Errorf("unable to read private key: %w", err)
	}
	signer, err := ssh.ParsePrivateKey(key)
	if err != nil {
		return fmt.Errorf("unable to parse private key: %w", err)
	}
	config := &ssh.ClientConfig{
		User:    user,
		Auth:    []ssh.AuthMethod{ssh.PublicKeys(signer)},
		Timeout: 3 * time.Second,
		// Accepts every host key; see the SECURITY note above.
		HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
			return nil
		},
	}
	connection, err := ssh.Dial("tcp", host, config)
	if err != nil {
		return fmt.Errorf("unable to connect to %s: %w", host, err)
	}
	client, err := sftp.NewClient(connection)
	if err != nil {
		// Nothing else holds the SSH connection, so close it here or it leaks.
		connection.Close()
		return fmt.Errorf("unable to start sftp session on %s: %w", host, err)
	}
	s.client = client
	return nil
}
// GetHosts exposes the mock host table for tests, keyed by the host string
// given to Connect. The map itself is returned, not a copy.
func (s *SFTPConn) GetHosts() map[string]SFTPMockHost {
	hosts := s.testHosts
	return hosts
}
// SetError toggles simulated failures for tests. While the flag is set,
// Write always fails and Read fails in mock mode.
func (s *SFTPConn) SetError(willError bool) {
	s.testError = willError
}
// Write replaces the contents of filename on the remote host with data.
//
// Empty data is rejected rather than written, so the file is never
// truncated to nothing. While the test error flag is set (see SetError),
// Write always fails.
//
// In mock mode nothing is sent over the network. If an expected value is
// configured, data must match it. Otherwise data is stored as the File of
// the mock host selected by Connect and recorded in the mock host table.
//
// For real hosts the file must already exist: it is opened without
// os.O_CREATE and truncated before writing. Errors from closing the file
// are returned, because a failed close on a written file should not be
// reported as success.
func (s *SFTPConn) Write(filename, data string) (err error) {
	if data == "" {
		return fmt.Errorf("empty data, not writing it")
	}
	if s.testError {
		return fmt.Errorf("test error writing file")
	}
	if s.mock {
		if s.expected != "" && data != s.expected {
			return fmt.Errorf("data is not as expected: '%s' instead of '%s'", data, s.expected)
		}
		s.host.File = data
		// Connect normally creates the table; guard against Write being
		// called on a mock that was never connected.
		if s.testHosts == nil {
			s.testHosts = map[string]SFTPMockHost{}
		}
		s.testHosts[s.alias] = s.host
		return nil
	}
	f, err := s.client.OpenFile(filename, os.O_RDWR|os.O_TRUNC)
	if err != nil {
		return err
	}
	// Close issues a request to the server and can fail; keep the first error.
	defer func() {
		if cerr := f.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()
	_, err = f.Write([]byte(data))
	return err
}
// Read fetches the whole content of filename from the remote host.
//
// In mock mode no network access happens. It returns the File of the mock
// host selected by Connect, or an error while the test error flag is set.
func (s *SFTPConn) Read(filename string) ([]byte, error) {
	if !s.mock {
		remote, err := s.client.Open(filename)
		if err != nil {
			return nil, err
		}
		defer remote.Close()
		return io.ReadAll(remote)
	}
	if s.testError {
		return nil, fmt.Errorf("test error reading file")
	}
	return []byte(s.host.File), nil
}
// Close releases the SFTP session opened by Connect.
//
// It does nothing in mock mode, or when no session was established (for
// example because Connect failed). This makes it safe to defer right after
// calling Connect.
//
// NOTE(review): Connect does not keep the *ssh.Client returned by ssh.Dial,
// so only the SFTP session is closed here. The underlying SSH connection is
// presumably left open until the process exits; consider storing it on
// SFTPConn and closing it as well.
func (s *SFTPConn) Close() {
	if s.mock || s.client == nil {
		return
	}
	// The SFTP interface's Close has no error result, so a failure here
	// cannot be reported to callers and is deliberately ignored.
	_ = s.client.Close()
}