-
Notifications
You must be signed in to change notification settings - Fork 9
/
main.go
343 lines (301 loc) · 9.33 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
package main
import (
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"log"
"net"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"
)
var (
	// Rules is the thread-safe set of forwarding rules loaded from the config file.
	Rules CSafeRule
	// ConfigFileName is the path of the JSON config file; it defaults to
	// "rules.json" and can be changed with the -config flag.
	ConfigFileName = "rules.json"
	// SimultaneousConnections holds the thread-safe active-connection counters,
	// one per rule.
	SimultaneousConnections CSafeConnections
	// Verbose is the log level (higher means more output); set with -verbose.
	Verbose = 1
	// SaveBeforeExit, when true, makes the program write the config file one
	// last time just before it exits; -no-exit-save turns it off.
	SaveBeforeExit = true
)

// Version is the version string of the program.
const Version = "1.5.0 / Build 14"
// CSafeConnections counts the active connections of every rule. The slice is
// indexed by rule position (parallel to Rules.Rules) and must only be read or
// written while mu is held.
type CSafeConnections struct {
	mu                      sync.RWMutex
	SimultaneousConnections []int
}
// CSafeRule guards the forwarding rules with a read/write mutex so they can
// be read and have their quotas updated from many goroutines at once.
type CSafeRule struct {
	mu    sync.RWMutex
	Rules []Rule
}
// Rule is a single port-forwarding rule as stored in the config file. The
// fields carry no JSON tags, so they are encoded and decoded under their Go
// names.
type Rule struct {
	// Name is a free-form label; it does not have any effect on anything.
	Name string
	// Listen is the local TCP port to listen on (bound on all interfaces).
	Listen uint16
	// Forward is the destination address handed to net.Dial("tcp", ...),
	// presumably in "host:port" form.
	Forward string
	// Quota is the number of bytes this rule may still transfer; every
	// forwarded byte (in either direction) is subtracted from it. Once it is
	// below 0 the rule stops forwarding new connections.
	// It is an int64 rather than a uint64 so that overshooting the quota makes
	// it negative instead of wrapping around to 2^64-1.
	Quota int64
	// ExpireDate is a Unix timestamp in seconds. Once time.Now().Unix() is
	// past it, the rule stops forwarding new connections.
	// 0 means that this rule does not expire.
	ExpireDate int64
	// Simultaneous is the maximum number of concurrent client connections.
	// Internally each connection is counted twice (once per copy direction),
	// so the active counter is compared against Simultaneous*2.
	// 0 means that there is no limit.
	Simultaneous int
}
// Config is the layout of the JSON config file.
type Config struct {
	// SaveDuration is the interval, in seconds, between periodic saves of the
	// config file. A value of 0 selects the default of 600 seconds.
	SaveDuration int
	// Timeout is the inactivity timeout of forwarded connections, in seconds;
	// connection deadlines are refreshed on every read and write.
	// Values equal or lower than 0 disable the timeout.
	Timeout int64
	// Rules holds all of the forwarding rules.
	Rules []Rule
}
var (
	// EnableTimeOut reports whether connection timeouts are applied. main
	// clears it when the configured Timeout is 0 or negative.
	EnableTimeOut = true
	// TimeoutDuration is the configured timeout converted to a time.Duration.
	TimeoutDuration time.Duration
)
// main parses the command-line flags, loads the JSON config file, starts one
// TCP listener goroutine per rule, saves the rules back to the config file
// every SaveDuration seconds and blocks until SIGINT or SIGTERM is received.
func main() {
	{ // Parse arguments
		flag.StringVar(&ConfigFileName, "config", "rules.json", "The config filename")
		flag.IntVar(&Verbose, "verbose", 1, "Verbose level: 0->None(Mostly Silent), 1->Quota reached, expiry date and typical errors, 2->Connection flood 3->Timeout drops 4->All logs and errors")
		help := flag.Bool("h", false, "Show help")
		// The flag is phrased negatively ("disable the save"), so it is
		// inverted once after parsing.
		noExitSave := flag.Bool("no-exit-save", false, "Set this argument to disable the save of rules before exiting")
		flag.Parse()
		if *help {
			fmt.Println("Created by Hirbod Behnam")
			fmt.Println("Source at https://github.com/HirbodBehnam/PortForwarder")
			fmt.Println("Version", Version)
			flag.PrintDefaults()
			os.Exit(0)
		}
		if Verbose != 0 {
			fmt.Println("Verbose mode on level", Verbose)
		}
		SaveBeforeExit = !*noExitSave
	}
	// Read config file
	var conf Config
	{
		confF, err := os.ReadFile(ConfigFileName)
		if err != nil {
			panic("Cannot read the config file. (io Error) " + err.Error())
		}
		if err = json.Unmarshal(confF, &conf); err != nil {
			panic("Cannot read the config file. (Parse Error) " + err.Error())
		}
		Rules.Rules = conf.Rules
		SimultaneousConnections.SimultaneousConnections = make([]int, len(Rules.Rules))
		if conf.Timeout <= 0 {
			logVerbose(1, "Disabled timeout")
			EnableTimeOut = false
		} else {
			TimeoutDuration = time.Duration(conf.Timeout) * time.Second
			logVerbose(3, "Set timeout to", TimeoutDuration)
		}
		// Resolve the save interval here, before any goroutine copies conf, so
		// conf is never written while saveConfig(conf) reads it concurrently.
		// A negative interval would make the save loop spin, so it also falls
		// back to the default.
		if conf.SaveDuration <= 0 {
			conf.SaveDuration = 600
		}
	}
	// Start listeners
	for index, rule := range Rules.Rules {
		go func(i int, loopRule Rule) {
			if loopRule.Quota < 0 { // If the quota is already reached why listen for connections?
				log.Println("Skip enabling forward on port", loopRule.Listen, "because the quota is reached.")
				return
			}
			if loopRule.ExpireDate != 0 && loopRule.ExpireDate < time.Now().Unix() { // Same thing goes with expire date
				log.Println("Skip enabling forward on port", loopRule.Listen, "because this rule is expired.")
				return
			}
			log.Println("Forwarding from", loopRule.Listen, "port to", loopRule.Forward)
			ln, err := net.Listen("tcp", ":"+strconv.Itoa(int(loopRule.Listen))) // Initialize the listener
			if err != nil {
				panic(err) // This will terminate the program
			}
			// Release the port once this rule stops; otherwise the kernel keeps
			// completing handshakes into a backlog nobody accepts and clients hang.
			defer ln.Close()
			for {
				conn, err := ln.Accept() // Blocks until a client connects
				// Read the live quota and expiry under the rules lock.
				Rules.mu.RLock()
				quotaReached := Rules.Rules[i].Quota < 0
				expired := Rules.Rules[i].ExpireDate != 0 && Rules.Rules[i].ExpireDate < time.Now().Unix()
				Rules.mu.RUnlock()
				if quotaReached || expired {
					if quotaReached {
						logVerbose(1, "Quota reached for port", loopRule.Listen, "pointing to", loopRule.Forward)
					} else {
						logVerbose(1, "Expire date reached for port", loopRule.Listen, "pointing to", loopRule.Forward)
					}
					if err == nil {
						_ = conn.Close()
					}
					saveConfig(conf) // Force write the config file
					return
				}
				if err != nil {
					logVerbose(1, "Error on accepting connection:", err.Error())
					continue
				}
				go handleRequest(conn, i, loopRule)
			}
		}(index, rule)
	}
	// Save config file in intervals
	go func() {
		ticker := time.NewTicker(time.Duration(conf.SaveDuration) * time.Second)
		defer ticker.Stop()
		for range ticker.C {
			saveConfig(conf)
		}
	}()
	// Wait for an interrupt or termination signal (https://gobyexample.com/signals)
	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
	log.Println("Ctrl + C to stop")
	<-sigs
	if SaveBeforeExit {
		saveConfig(conf) // Save the config file one last time before exiting
	}
	log.Println("Exiting")
}
// saveConfigMu serializes saveConfig. The periodic saver, listener goroutines
// whose rule just stopped, and main on exit may all save at the same time,
// and concurrent os.WriteFile calls on the same path can interleave.
var saveConfigMu sync.Mutex

// saveConfig writes config to ConfigFileName as JSON, with its Rules replaced
// by the current global rules (so updated quotas are persisted). Failures are
// logged at verbose level 1; the existing file is never overwritten with an
// empty payload when encoding fails.
func saveConfig(config Config) {
	saveConfigMu.Lock()
	defer saveConfigMu.Unlock()

	// Snapshot the rules while holding saveConfigMu, so a later save can never
	// write an older snapshot over a newer one.
	Rules.mu.RLock()
	config.Rules = Rules.Rules
	b, err := json.Marshal(config)
	Rules.mu.RUnlock()
	if err != nil {
		logVerbose(1, "Error encoding rules: ", err)
		return
	}
	if err := os.WriteFile(ConfigFileName, b, 0644); err != nil {
		logVerbose(1, "Error re-writing rules: ", err)
		return
	}
	logVerbose(4, "Saved the config")
}
// handleRequest forwards one accepted client connection according to rule r.
// index is the rule's position in Rules.Rules and in the connection counters.
// r is a copy of the rule made when its listener started, so it can be read
// here without locking.
//
// If r.Simultaneous is non-zero and that many connections are already active,
// the client is closed immediately. Otherwise a TCP connection to r.Forward is
// dialed and two copyIO goroutines pump data in each direction.
func handleRequest(conn net.Conn, index int, r Rule) {
	// Check the limit and reserve the slot under one write lock. Checking under
	// a read lock and incrementing later (after the dial) let concurrent
	// handlers all pass the check and exceed the limit.
	SimultaneousConnections.mu.Lock()
	active := SimultaneousConnections.SimultaneousConnections[index]
	if r.Simultaneous != 0 && active >= r.Simultaneous*2 { // 0 means no limit
		SimultaneousConnections.mu.Unlock()
		logVerbose(2, "Blocking new connection for port", r.Listen, "because the connection limit is reached. The current active connections count is", active/2)
		_ = conn.Close()
		return
	}
	// Two units per connection: one per copyIO goroutine, each of which
	// decrements the counter by one when it finishes.
	SimultaneousConnections.SimultaneousConnections[index] += 2
	SimultaneousConnections.mu.Unlock()

	// Open a connection to the remote host, honouring the configured timeout.
	var proxy net.Conn
	var err error
	if EnableTimeOut {
		proxy, err = net.DialTimeout("tcp", r.Forward, TimeoutDuration)
	} else {
		proxy, err = net.Dial("tcp", r.Forward)
	}
	if err != nil {
		logVerbose(1, "Error on dialing remote host:", err.Error())
		_ = conn.Close()
		// Release the reserved slot; no copyIO goroutine will do it.
		SimultaneousConnections.mu.Lock()
		SimultaneousConnections.SimultaneousConnections[index] -= 2
		SimultaneousConnections.mu.Unlock()
		return
	}
	logVerbose(4, "Accepting a connection from", conn.RemoteAddr(), "; Now", active+2, "SimultaneousConnections")
	go copyIO(conn, proxy, index) // client -> server
	go copyIO(proxy, conn, index) // server -> client
}
// copyIO streams data from src to dest until src is exhausted or an error
// occurs, then closes both connections. index selects the rule whose quota is
// charged with the transferred bytes and whose active-connection counter is
// decremented by one (each forwarded connection runs two copyIO calls).
func copyIO(src, dest net.Conn, index int) {
	defer src.Close()
	defer dest.Close()

	var copied int64 // bytes written to dest
	var copyErr error
	if !EnableTimeOut {
		// Without a timeout the standard io.Copy is sufficient.
		copied, copyErr = io.Copy(dest, src)
	} else {
		copied, copyErr = copyBuffer(dest, src)
	}

	if copyErr != nil {
		text := copyErr.Error()
		switch {
		case strings.Contains(text, "i/o timeout"):
			logVerbose(3, "A connection timed out from", src.RemoteAddr(), "to", dest.RemoteAddr())
		case !strings.HasPrefix(text, "cannot set timeout for"):
			logVerbose(4, "Error on copyBuffer:", text)
		case strings.HasSuffix(text, "use of closed network connection"):
			// Typically the opposite direction already finished and closed the
			// connection, so this is only reported at the most verbose level.
			logVerbose(4, text)
		default:
			logVerbose(1, text)
		}
	}

	// Charge the transferred bytes against the rule's quota.
	Rules.mu.Lock()
	Rules.Rules[index].Quota -= copied
	Rules.mu.Unlock()

	SimultaneousConnections.mu.Lock()
	counters := SimultaneousConnections.SimultaneousConnections
	counters[index]-- // runs once per direction
	logVerbose(4, "Closing a connection from", src.RemoteAddr(), "; Connections Now:", counters[index])
	SimultaneousConnections.mu.Unlock()
}
// copyBuffer copies from src to dst like io.Copy. Before every read it moves
// src's deadline, and before every write dst's deadline, to now+TimeoutDuration,
// so a connection that sees no traffic for TimeoutDuration fails with a timeout.
// SetDeadline (read and write) is used rather than SetReadDeadline or
// SetWriteDeadline. The copy in the opposite direction also calls SetDeadline
// on the same connections, so activity either way pushes back both deadlines.
//
// It returns the number of bytes written to dst and the first error
// encountered; reaching EOF on src is not an error. Deadline failures are
// reported as "cannot set timeout for src: ..." or "cannot set timeout for
// dest: ..." and wrap the underlying error, so errors.Is/As still see it.
func copyBuffer(dst, src net.Conn) (written int64, err error) {
	buf := make([]byte, 32*1024) // 32 KiB, the same size io.Copy uses
	for {
		if err = src.SetDeadline(time.Now().Add(TimeoutDuration)); err != nil {
			return written, fmt.Errorf("cannot set timeout for src: %w", err)
		}
		nr, er := src.Read(buf)
		if nr > 0 {
			if err = dst.SetDeadline(time.Now().Add(TimeoutDuration)); err != nil {
				return written, fmt.Errorf("cannot set timeout for dest: %w", err)
			}
			nw, ew := dst.Write(buf[:nr])
			if nw < 0 || nw > nr {
				// A misbehaving Write; treat it like io.Copy does.
				nw = 0
				if ew == nil {
					ew = errors.New("invalid write result")
				}
			}
			written += int64(nw)
			if ew != nil {
				return written, ew
			}
			if nw != nr {
				return written, io.ErrShortWrite
			}
		}
		if er != nil {
			if errors.Is(er, io.EOF) {
				return written, nil // a clean end of stream is not an error
			}
			return written, er
		}
	}
}
// logVerbose prints msg with log.Println when the global Verbose level is at
// least level. The arguments are spread into Println, so they are printed
// space-separated instead of as a single bracketed slice.
func logVerbose(level int, msg ...interface{}) {
	if Verbose < level {
		return
	}
	log.Println(msg...)
}