-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsshmgr.go
174 lines (144 loc) · 3.71 KB
/
sshmgr.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
package sshmgr
import (
"errors"
"sync"
"sync/atomic"
"time"
"github.com/brunotm/sshmgr/locker"
"github.com/pkg/sftp"
)
// errManagerClosed is returned by SSHClient once the manager has been closed.
var errManagerClosed = errors.New("manager closed")
// Manager hands out shared ssh and sftp clients and removes the ones that
// are no longer referenced from a background collection loop.
type Manager struct {
	// clientTTL is how long, in seconds, an unreferenced client may stay
	// in the manager after it was last accessed.
	clientTTL int64
	// gcInterval is the period between two collection passes.
	gcInterval time.Duration
	// closeChan is closed to signal shutdown.
	closeChan chan struct{}
	// locker provides a lock per client id.
	locker *locker.Locker

	// mtx guards clients.
	mtx     sync.RWMutex
	clients map[string]*Client
}
// New builds a Manager and starts its background collection goroutine.
//
// clientTTL is the maximum time an unreferenced client is kept in the
// manager after it was last accessed; it is stored with second granularity.
// The last access time of a client is updated when the client is released.
//
// gcInterval is the interval at which the manager tries to remove unused
// clients.
func New(clientTTL, gcInterval time.Duration) (manager *Manager) {
	manager = new(Manager)
	manager.clientTTL = int64(clientTTL.Seconds())
	manager.gcInterval = gcInterval
	manager.locker = locker.New()
	manager.clients = make(map[string]*Client)
	manager.closeChan = make(chan struct{})

	// The collection loop runs until closeChan is closed.
	go manager.gc()

	return manager
}
// Close signals the manager to shut down.
//
// Closing closeChan makes the background collection goroutine close and
// remove every managed client, and makes subsequent SSHClient calls return
// errManagerClosed. The teardown itself happens asynchronously in that
// goroutine; Close does not wait for it.
//
// Close is safe to call more than once: only the first call closes the
// channel, later calls are no-ops.
func (m *Manager) Close() {
	// The mutex serializes concurrent Close calls so the channel is
	// closed exactly once (closing a closed channel panics).
	m.mtx.Lock()
	defer m.mtx.Unlock()

	select {
	case <-m.closeChan:
		// Already closed, nothing to do.
	default:
		close(m.closeChan)
	}
}
// getClient looks up the client registered under id while holding the read
// lock. It returns nil when no client is registered for id.
func (m *Manager) getClient(id string) (client *Client) {
	m.mtx.RLock()
	defer m.mtx.RUnlock()

	return m.clients[id]
}
// delClient removes the client registered under id, if any.
//
// Deleting mutates the map, so the exclusive lock is required: a read lock
// would let several writers (and readers) touch the map at the same time.
func (m *Manager) delClient(id string) {
	m.mtx.Lock()
	defer m.mtx.Unlock()

	delete(m.clients, id)
}
// setClient registers client under id, replacing any previous entry.
//
// Storing mutates the map, so the exclusive lock is required: a read lock
// would let several writers (and readers) touch the map at the same time.
func (m *Manager) setClient(id string, client *Client) {
	m.mtx.Lock()
	defer m.mtx.Unlock()

	m.clients[id] = client
}
// SSHClient returns an active managed client for config, creating a new one
// on demand.
//
// The lookup, liveness probe and creation all run while holding the lock
// for the config id. An already registered client is probed with a "sshmgr"
// request; when the probe fails the entry is dropped from the manager and a
// new client is created in its place.
//
// The returned client has its reference count incremented and its
// connection deadline set to now plus config.ConnDeadline. Clients must be
// closed after usage so they can be removed when there are no references.
//
// It returns errManagerClosed when the manager has been closed, or the
// error from newClient when a new client cannot be created.
func (m *Manager) SSHClient(config ClientConfig) (client *Client, err error) {
	select {
	case <-m.closeChan:
		return nil, errManagerClosed
	default:
	}

	// Derive the key once so the lock, lookup, delete and insert are
	// guaranteed to refer to the same entry.
	id := config.id()
	m.locker.Lock(id)
	defer m.locker.Unlock(id)

	if client = m.getClient(id); client != nil {
		// Probe the existing client before handing it out again.
		if _, _, err = client.client.SendRequest("sshmgr", true, nil); err == nil {
			client.incr()
			// NOTE(review): the SetDeadline error is ignored here and
			// below; confirm this is an intentional best effort.
			client.conn.SetDeadline(time.Now().Add(config.ConnDeadline))
			return client, nil
		}

		// NOTE(review): the stale client is only unregistered, not closed;
		// confirm its remaining holders are expected to close it.
		m.delClient(id)
	}

	if client, err = newClient(config); err != nil {
		return nil, err
	}

	// Take the caller's reference before the client becomes visible in the
	// manager, then set the current deadline.
	client.incr()
	m.setClient(id, client)
	client.conn.SetDeadline(time.Now().Add(config.ConnDeadline))
	return client, nil
}
// SFTPClient creates an SFTP session on top of an active managed client,
// creating the underlying client on demand.
//
// The returned SFTPClient keeps the managed client it was built on.
// Clients must be closed after usage so they can be removed when they have
// no references.
//
// It returns the error from SSHClient when no client can be obtained, or
// the error from sftp.NewClient when the session cannot be created.
func (m *Manager) SFTPClient(config ClientConfig) (session *SFTPClient, err error) {
	// Get a client for this config; this takes a reference on it.
	client, err := m.SSHClient(config)
	if err != nil {
		return nil, err
	}

	// Create a SFTP session.
	sftpClient, err := sftp.NewClient(client.client)
	if err != nil {
		// The caller never receives the client, so the reference taken by
		// SSHClient must be given back here; otherwise the reference count
		// can never drop to zero and the client is never collected.
		// NOTE(review): assumes Client.Close releases one reference, as the
		// "must be closed after usage" contract states — confirm.
		client.Close()
		return nil, err
	}

	msftp := &SFTPClient{}
	msftp.client = client
	msftp.Client = sftpClient
	return msftp, nil
}
// gc is the background collection loop.
//
// Every gcInterval it runs a regular collection pass. When closeChan is
// closed it runs one final shutdown pass and returns.
//
// A non-positive gcInterval disables the periodic passes instead of
// panicking (time.NewTicker panics for durations <= 0, and a panic in this
// goroutine could not be recovered by the caller); the shutdown pass still
// runs.
func (m *Manager) gc() {
	// A nil channel blocks forever in select, which leaves only the
	// shutdown case active when no ticker is created.
	var tick <-chan time.Time
	if m.gcInterval > 0 {
		ticker := time.NewTicker(m.gcInterval)
		defer ticker.Stop()
		tick = ticker.C
	}

	for {
		select {
		case <-m.closeChan:
			m.collect(true)
			return
		case <-tick:
			m.collect(false)
		}
	}
}
// collect removes clients from the manager and closes them.
//
// With shutdown set, every registered client is removed and closed.
// Otherwise only clients whose reference count is zero and whose last
// access time is at least clientTTL seconds in the past are removed.
//
// Locks are taken in the same order as in SSHClient (per-id lock first,
// then the map mutex) so the two cannot deadlock against each other, and
// the per-id lock is released on every path, including shutdown.
func (m *Manager) collect(shutdown bool) {
	now := time.Now().Unix()

	// Snapshot the ids so the map mutex is not held while waiting for
	// per-id locks.
	m.mtx.RLock()
	ids := make([]string, 0, len(m.clients))
	for id := range m.clients {
		ids = append(ids, id)
	}
	m.mtx.RUnlock()

	for _, id := range ids {
		func() {
			m.locker.Lock(id)
			defer m.locker.Unlock(id)

			m.mtx.Lock()
			// Re-read under the locks: the entry may have been replaced
			// or removed since the snapshot was taken.
			client := m.clients[id]
			remove := client != nil && (shutdown ||
				(client.refcount() == 0 &&
					now-atomic.LoadInt64(&client.atime) >= m.clientTTL))
			if remove {
				delete(m.clients, id)
			}
			m.mtx.Unlock()

			// Close outside the map mutex but still under the per-id lock.
			if remove {
				client.Close()
			}
		}()
	}
}