Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix VPN server #616

Merged
merged 7 commits into from
Dec 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion internal/vpn/errors.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,30 @@
package vpn

import "errors"
import (
"errors"
)

var (
errCouldFindDefaultNetworkGateway = errors.New("could not find default network gateway")
)

// ErrorWithStderr is an error raised by the external process.
// `Err` is an actual error coming from `exec`, while `Stderr` contains
// stderr output of the process.
type ErrorWithStderr struct {
Err error
Stderr []byte
}

// NewErrorWithStderr constructs new `ErrorWithStderr`.
func NewErrorWithStderr(err error, stderr []byte) *ErrorWithStderr {
return &ErrorWithStderr{
Err: err,
Stderr: stderr,
}
}

// Error implements `error`.
func (e *ErrorWithStderr) Error() string {
return e.Err.Error() + ": " + string(e.Stderr)
}
12 changes: 10 additions & 2 deletions internal/vpn/os.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package vpn

import (
"bytes"
"fmt"
"io"
"net"
"os"
"os/exec"
"strings"
)

// LocalNetworkInterfaceIPs gets IPs of all local interfaces.
Expand Down Expand Up @@ -83,14 +86,19 @@ func parseCIDR(ipCIDR string) (ipStr, netmask string, err error) {

//nolint:unparam
func run(bin string, args ...string) error {
fullCmd := bin + " " + strings.Join(args, " ")

cmd := exec.Command(bin, args...) //nolint:gosec

cmd.Stderr = os.Stderr
stderrBuf := bytes.NewBuffer(nil)

cmd.Stderr = io.MultiWriter(os.Stderr, stderrBuf)
cmd.Stdout = os.Stdout
cmd.Stdin = os.Stdin

if err := cmd.Run(); err != nil {
return fmt.Errorf("error running command %s: %w", bin, err)
return NewErrorWithStderr(fmt.Errorf("error running command \"%s\": %w", fullCmd, err),
stderrBuf.Bytes())
}

return nil
Expand Down
13 changes: 12 additions & 1 deletion internal/vpn/os_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
package vpn

import (
"errors"
"fmt"
"strconv"
"strings"
)

// SetupTUN sets the allocated TUN interface up, setting its IP, gateway, netmask and MTU.
Expand All @@ -24,7 +26,16 @@ func AddRoute(ipCIDR, gateway string) error {
return fmt.Errorf("error parsing IP CIDR: %w", err)
}

return run("route", "add", "-net", ip, gateway, netmask)
err = run("route", "add", "-net", ip, gateway, netmask)

var e *ErrorWithStderr
if errors.As(err, &e) {
if strings.Contains(string(e.Stderr), "File exists") {
return nil
}
}

return err
}

// DeleteRoute removes route to `ipCIDR` through the `gateway` from the OS routing table.
Expand Down
13 changes: 12 additions & 1 deletion internal/vpn/os_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
package vpn

import (
"errors"
"fmt"
"strconv"
"strings"
)

// SetupTUN sets the allocated TUN interface up, setting its IP, gateway, netmask and MTU.
Expand Down Expand Up @@ -35,7 +37,16 @@ func SetupTUN(ifcName, ipCIDR, gateway string, mtu int) error {

// AddRoute adds route to `ip` with `netmask` through the `gateway` to the OS routing table.
func AddRoute(ip, gateway string) error {
return run("ip", "r", "add", ip, "via", gateway)
err := run("ip", "r", "add", ip, "via", gateway)

var e *ErrorWithStderr
if errors.As(err, &e) {
if strings.Contains(string(e.Stderr), "File exists") {
return nil
}
}

return err
}

// DeleteRoute removes route to `ip` with `netmask` through the `gateway` from the OS routing table.
Expand Down
13 changes: 9 additions & 4 deletions internal/vpn/os_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,18 @@ var (
errServerMethodsNotSupported = errors.New("server related methods are not supported for this OS")
)

// AllowSSH allows all SSH traffic (via default 22 port) between `src` and `dst`.
func AllowSSH(_, _ net.IP, _ []net.IP) error {
// GetIPTablesForwardPolicy gets current policy for iptables `forward` chain.
func GetIPTablesForwardPolicy() (string, error) {
return "", errServerMethodsNotSupported
}

// SetIPTablesForwardPolicy sets `policy` for iptables `forward` chain.
func SetIPTablesForwardPolicy(policy string) error {
return errServerMethodsNotSupported
}

// BlockSSH blocks all SSH traffic (via default 22 port) between `src` and `dst`.
func BlockSSH(_, _ net.IP, _ []net.IP) error {
// SetIPTablesForwardAcceptPolicy sets ACCEPT policy for iptables `forward` chain.
func SetIPTablesForwardAcceptPolicy() error {
return errServerMethodsNotSupported
}

Expand Down
47 changes: 38 additions & 9 deletions internal/vpn/os_server_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,49 @@ import (
"fmt"
"net"
"os/exec"
"strings"
)

const (
defaultNetworkInterfaceCMD = "ip addr | awk '/state UP/ {print $2}' | sed 's/.$//'"
getIPv4ForwardingCMD = "sysctl net.ipv4.ip_forward"
getIPv6ForwardingCMD = "sysctl net.ipv6.conf.all.forwarding"
setIPv4ForwardingCMDFmt = "sysctl -w net.ipv4.ip_forward=%s"
setIPv6ForwardingCMDFmt = "sysctl -w net.ipv6.conf.all.forwarding=%s"
enableIPMasqueradingCMDFmt = "iptables -t nat -A POSTROUTING -o %s -j MASQUERADE"
disableIPMasqueradingCMDFmt = "iptables -t nat -D POSTROUTING -o %s -j MASQUERADE"
blockIPToLocalNetCMDFmt = "iptables -I FORWARD -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP && iptables -I INPUT -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP"
allowIPToLocalNetCMDFmt = "iptables -D FORWARD -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP && iptables -D INPUT -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP"
defaultNetworkInterfaceCMD = "ip addr | awk '/state UP/ {print $2}' | sed 's/.$//'"
getIPv4ForwardingCMD = "sysctl net.ipv4.ip_forward"
getIPv6ForwardingCMD = "sysctl net.ipv6.conf.all.forwarding"
setIPv4ForwardingCMDFmt = "sysctl -w net.ipv4.ip_forward=%s"
setIPv6ForwardingCMDFmt = "sysctl -w net.ipv6.conf.all.forwarding=%s"
getIPTablesForwardPolicyCMD = "iptables -L | grep \"Chain FORWARD\" | tr -d '()' | awk '{print $4}'"
setIPTablesForwardPolicyCMDFmt = "iptables --policy FORWARD %s"
enableIPMasqueradingCMDFmt = "iptables -t nat -A POSTROUTING -o %s -j MASQUERADE"
disableIPMasqueradingCMDFmt = "iptables -t nat -D POSTROUTING -o %s -j MASQUERADE"
blockIPToLocalNetCMDFmt = "iptables -I FORWARD -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP && iptables -I INPUT -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP"
allowIPToLocalNetCMDFmt = "iptables -D FORWARD -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP && iptables -D INPUT -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP"
)

// GetIPTablesForwardPolicy gets current policy for iptables `forward` chain.
func GetIPTablesForwardPolicy() (string, error) {
outputBytes, err := exec.Command("sh", "-c", getIPTablesForwardPolicyCMD).Output()
if err != nil {
return "", fmt.Errorf("error running command %s: %w", getIPTablesForwardPolicyCMD, err)
}

return strings.TrimRight(string(outputBytes), "\n"), nil
}

// SetIPTablesForwardPolicy sets `policy` for iptables `forward` chain.
func SetIPTablesForwardPolicy(policy string) error {
cmd := fmt.Sprintf(setIPTablesForwardPolicyCMDFmt, policy)
if err := exec.Command("sh", "-c", cmd).Run(); err != nil { //nolint:gosec
return fmt.Errorf("error running command %s: %w", cmd, err)
}

return nil
}

// SetIPTablesForwardAcceptPolicy sets ACCEPT policy for iptables `forward` chain.
func SetIPTablesForwardAcceptPolicy() error {
const policy = "ACCEPT"
return SetIPTablesForwardPolicy(policy)
}

// AllowIPToLocalNetwork allows all the packets coming from `source`
// to private IP ranges.
func AllowIPToLocalNetwork(src, dst net.IP) error {
Expand Down
24 changes: 24 additions & 0 deletions internal/vpn/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type Server struct {
defaultNetworkInterfaceIPs []net.IP
ipv4ForwardingVal string
ipv6ForwardingVal string
iptablesForwardPolicy string
}

// NewServer creates VPN server instance.
Expand Down Expand Up @@ -58,10 +59,18 @@ func NewServer(cfg ServerConfig, l logrus.FieldLogger) (*Server, error) {
l.Infoln("Old IP forwarding values:")
l.Infof("IPv4: %s, IPv6: %s", ipv4ForwardingVal, ipv6ForwardingVal)

iptablesForwarPolicy, err := GetIPTablesForwardPolicy()
if err != nil {
return nil, fmt.Errorf("error getting iptables forward policy: %w", err)
}

l.Infof("Old iptables forward policy: %s", iptablesForwarPolicy)

s.defaultNetworkInterface = defaultNetworkIfc
s.defaultNetworkInterfaceIPs = defaultNetworkIfcIPs
s.ipv4ForwardingVal = ipv4ForwardingVal
s.ipv6ForwardingVal = ipv6ForwardingVal
s.iptablesForwardPolicy = iptablesForwarPolicy

return s, nil
}
Expand Down Expand Up @@ -111,6 +120,21 @@ func (s *Server) Serve(l net.Listener) error {
}
}()

if err := SetIPTablesForwardAcceptPolicy(); err != nil {
serveErr = fmt.Errorf("error settings iptables forward policy to ACCEPT")
return
}

s.log.Infoln("Set iptables forward policy to ACCEPT")

defer func() {
if err := SetIPTablesForwardPolicy(s.iptablesForwardPolicy); err != nil {
s.log.WithError(err).Errorf("Error setting iptables forward policy to %s", s.iptablesForwardPolicy)
} else {
s.log.Infof("Restored iptables forward policy to %s", s.iptablesForwardPolicy)
}
}()

s.lisMx.Lock()
s.lis = l
s.lisMx.Unlock()
Expand Down