diff --git a/internal/vpn/errors.go b/internal/vpn/errors.go index 06b6264f83..2ac32ee559 100644 --- a/internal/vpn/errors.go +++ b/internal/vpn/errors.go @@ -1,7 +1,30 @@ package vpn -import "errors" +import ( + "errors" +) var ( errCouldFindDefaultNetworkGateway = errors.New("could not find default network gateway") ) + +// ErrorWithStderr is an error raised by the external process. +// `Err` is an actual error coming from `exec`, while `Stderr` contains +// stderr output of the process. +type ErrorWithStderr struct { + Err error + Stderr []byte +} + +// NewErrorWithStderr constructs new `ErrorWithStderr`. +func NewErrorWithStderr(err error, stderr []byte) *ErrorWithStderr { + return &ErrorWithStderr{ + Err: err, + Stderr: stderr, + } +} + +// Error implements `error`. +func (e *ErrorWithStderr) Error() string { + return e.Err.Error() + ": " + string(e.Stderr) +} diff --git a/internal/vpn/os.go b/internal/vpn/os.go index d703645925..916fd8ffc0 100644 --- a/internal/vpn/os.go +++ b/internal/vpn/os.go @@ -1,10 +1,13 @@ package vpn import ( + "bytes" "fmt" + "io" "net" "os" "os/exec" + "strings" ) // LocalNetworkInterfaceIPs gets IPs of all local interfaces. @@ -83,14 +86,19 @@ func parseCIDR(ipCIDR string) (ipStr, netmask string, err error) { //nolint:unparam func run(bin string, args ...string) error { + fullCmd := bin + " " + strings.Join(args, " ") + cmd := exec.Command(bin, args...) //nolint:gosec - cmd.Stderr = os.Stderr + stderrBuf := bytes.NewBuffer(nil) + + cmd.Stderr = io.MultiWriter(os.Stderr, stderrBuf) cmd.Stdout = os.Stdout cmd.Stdin = os.Stdin if err := cmd.Run(); err != nil { - return fmt.Errorf("error running command %s: %w", bin, err) + return NewErrorWithStderr(fmt.Errorf("error running command \"%s\": %w", fullCmd, err), + stderrBuf.Bytes()) } return nil diff --git a/internal/vpn/os_darwin.go b/internal/vpn/os_darwin.go index a06dbc4b78..591e703718 100644 --- a/internal/vpn/os_darwin.go +++ b/internal/vpn/os_darwin.go @@ -3,8 +3,10 @@ package vpn import ( + "errors" "fmt" "strconv" + "strings" ) // SetupTUN sets the allocated TUN interface up, setting its IP, gateway, netmask and MTU. @@ -24,7 +26,16 @@ func AddRoute(ipCIDR, gateway string) error { return fmt.Errorf("error parsing IP CIDR: %w", err) } - return run("route", "add", "-net", ip, gateway, netmask) + err = run("route", "add", "-net", ip, gateway, netmask) + + var e *ErrorWithStderr + if errors.As(err, &e) { + if strings.Contains(string(e.Stderr), "File exists") { + return nil + } + } + + return err } // DeleteRoute removes route to `ipCIDR` through the `gateway` from the OS routing table. diff --git a/internal/vpn/os_linux.go b/internal/vpn/os_linux.go index 6a08a3107c..60817f483f 100644 --- a/internal/vpn/os_linux.go +++ b/internal/vpn/os_linux.go @@ -3,8 +3,10 @@ package vpn import ( + "errors" "fmt" "strconv" + "strings" ) // SetupTUN sets the allocated TUN interface up, setting its IP, gateway, netmask and MTU. @@ -35,7 +37,16 @@ func SetupTUN(ifcName, ipCIDR, gateway string, mtu int) error { // AddRoute adds route to `ip` with `netmask` through the `gateway` to the OS routing table. func AddRoute(ip, gateway string) error { - return run("ip", "r", "add", ip, "via", gateway) + err := run("ip", "r", "add", ip, "via", gateway) + + var e *ErrorWithStderr + if errors.As(err, &e) { + if strings.Contains(string(e.Stderr), "File exists") { + return nil + } + } + + return err } // DeleteRoute removes route to `ip` with `netmask` through the `gateway` from the OS routing table. diff --git a/internal/vpn/os_server.go b/internal/vpn/os_server.go index 8daa14e442..aba342ea7d 100644 --- a/internal/vpn/os_server.go +++ b/internal/vpn/os_server.go @@ -11,13 +11,18 @@ var ( errServerMethodsNotSupported = errors.New("server related methods are not supported for this OS") ) -// AllowSSH allows all SSH traffic (via default 22 port) between `src` and `dst`. -func AllowSSH(_, _ net.IP, _ []net.IP) error { +// GetIPTablesForwardPolicy gets current policy for iptables `forward` chain. +func GetIPTablesForwardPolicy() (string, error) { + return "", errServerMethodsNotSupported +} + +// SetIPTablesForwardPolicy sets `policy` for iptables `forward` chain. +func SetIPTablesForwardPolicy(policy string) error { return errServerMethodsNotSupported } -// BlockSSH blocks all SSH traffic (via default 22 port) between `src` and `dst`. -func BlockSSH(_, _ net.IP, _ []net.IP) error { +// SetIPTablesForwardAcceptPolicy sets ACCEPT policy for iptables `forward` chain. +func SetIPTablesForwardAcceptPolicy() error { return errServerMethodsNotSupported } diff --git a/internal/vpn/os_server_linux.go b/internal/vpn/os_server_linux.go index 0731873575..eaa410b2a1 100644 --- a/internal/vpn/os_server_linux.go +++ b/internal/vpn/os_server_linux.go @@ -7,20 +7,49 @@ import ( "fmt" "net" "os/exec" + "strings" ) const ( - defaultNetworkInterfaceCMD = "ip addr | awk '/state UP/ {print $2}' | sed 's/.$//'" - getIPv4ForwardingCMD = "sysctl net.ipv4.ip_forward" - getIPv6ForwardingCMD = "sysctl net.ipv6.conf.all.forwarding" - setIPv4ForwardingCMDFmt = "sysctl -w net.ipv4.ip_forward=%s" - setIPv6ForwardingCMDFmt = "sysctl -w net.ipv6.conf.all.forwarding=%s" - enableIPMasqueradingCMDFmt = "iptables -t nat -A POSTROUTING -o %s -j MASQUERADE" - disableIPMasqueradingCMDFmt = "iptables -t nat -D POSTROUTING -o %s -j MASQUERADE" - blockIPToLocalNetCMDFmt = "iptables -I FORWARD -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP && iptables -I INPUT -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP" - allowIPToLocalNetCMDFmt = "iptables -D FORWARD -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP && iptables -D INPUT -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP" + defaultNetworkInterfaceCMD = "ip addr | awk '/state UP/ {print $2}' | sed 's/.$//'" + getIPv4ForwardingCMD = "sysctl net.ipv4.ip_forward" + getIPv6ForwardingCMD = "sysctl net.ipv6.conf.all.forwarding" + setIPv4ForwardingCMDFmt = "sysctl -w net.ipv4.ip_forward=%s" + setIPv6ForwardingCMDFmt = "sysctl -w net.ipv6.conf.all.forwarding=%s" + getIPTablesForwardPolicyCMD = "iptables -L | grep \"Chain FORWARD\" | tr -d '()' | awk '{print $4}'" + setIPTablesForwardPolicyCMDFmt = "iptables --policy FORWARD %s" + enableIPMasqueradingCMDFmt = "iptables -t nat -A POSTROUTING -o %s -j MASQUERADE" + disableIPMasqueradingCMDFmt = "iptables -t nat -D POSTROUTING -o %s -j MASQUERADE" + blockIPToLocalNetCMDFmt = "iptables -I FORWARD -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP && iptables -I INPUT -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP" + allowIPToLocalNetCMDFmt = "iptables -D FORWARD -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP && iptables -D INPUT -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP" ) +// GetIPTablesForwardPolicy gets current policy for iptables `forward` chain. +func GetIPTablesForwardPolicy() (string, error) { + outputBytes, err := exec.Command("sh", "-c", getIPTablesForwardPolicyCMD).Output() + if err != nil { + return "", fmt.Errorf("error running command %s: %w", getIPTablesForwardPolicyCMD, err) + } + + return strings.TrimRight(string(outputBytes), "\n"), nil +} + +// SetIPTablesForwardPolicy sets `policy` for iptables `forward` chain. +func SetIPTablesForwardPolicy(policy string) error { + cmd := fmt.Sprintf(setIPTablesForwardPolicyCMDFmt, policy) + if err := exec.Command("sh", "-c", cmd).Run(); err != nil { //nolint:gosec + return fmt.Errorf("error running command %s: %w", cmd, err) + } + + return nil +} + +// SetIPTablesForwardAcceptPolicy sets ACCEPT policy for iptables `forward` chain. +func SetIPTablesForwardAcceptPolicy() error { + const policy = "ACCEPT" + return SetIPTablesForwardPolicy(policy) +} + // AllowIPToLocalNetwork allows all the packets coming from `source` // to private IP ranges. func AllowIPToLocalNetwork(src, dst net.IP) error { diff --git a/internal/vpn/server.go b/internal/vpn/server.go index 3038b242ea..34c7fb6669 100644 --- a/internal/vpn/server.go +++ b/internal/vpn/server.go @@ -22,6 +22,7 @@ type Server struct { defaultNetworkInterfaceIPs []net.IP ipv4ForwardingVal string ipv6ForwardingVal string + iptablesForwardPolicy string } // NewServer creates VPN server instance. @@ -58,10 +59,18 @@ func NewServer(cfg ServerConfig, l logrus.FieldLogger) (*Server, error) { l.Infoln("Old IP forwarding values:") l.Infof("IPv4: %s, IPv6: %s", ipv4ForwardingVal, ipv6ForwardingVal) + iptablesForwarPolicy, err := GetIPTablesForwardPolicy() + if err != nil { + return nil, fmt.Errorf("error getting iptables forward policy: %w", err) + } + + l.Infof("Old iptables forward policy: %s", iptablesForwarPolicy) + s.defaultNetworkInterface = defaultNetworkIfc s.defaultNetworkInterfaceIPs = defaultNetworkIfcIPs s.ipv4ForwardingVal = ipv4ForwardingVal s.ipv6ForwardingVal = ipv6ForwardingVal + s.iptablesForwardPolicy = iptablesForwarPolicy return s, nil } @@ -111,6 +120,21 @@ func (s *Server) Serve(l net.Listener) error { } }() + if err := SetIPTablesForwardAcceptPolicy(); err != nil { + serveErr = fmt.Errorf("error settings iptables forward policy to ACCEPT") + return + } + + s.log.Infoln("Set iptables forward policy to ACCEPT") + + defer func() { + if err := SetIPTablesForwardPolicy(s.iptablesForwardPolicy); err != nil { + s.log.WithError(err).Errorf("Error setting iptables forward policy to %s", s.iptablesForwardPolicy) + } else { + s.log.Infof("Restored iptables forward policy to %s", s.iptablesForwardPolicy) + } + }() + s.lisMx.Lock() s.lis = l s.lisMx.Unlock()