Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix VPN server #616

Merged
merged 7 commits into from
Dec 1, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions internal/vpn/os_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,18 @@ var (
errServerMethodsNotSupported = errors.New("server related methods are not supported for this OS")
)

// AllowSSH allows all SSH traffic (via default 22 port) between `src` and `dst`.
func AllowSSH(_, _ net.IP, _ []net.IP) error {
// GetIPTablesForwardPolicy gets current policy for iptables `forward` chain.
func GetIPTablesForwardPolicy() (string, error) {
return "", errServerMethodsNotSupported
}

// SetIPTablesForwardPolicy sets `policy` for iptables `forward` chain.
func SetIPTablesForwardPolicy(policy string) error {
return errServerMethodsNotSupported
}

// BlockSSH blocks all SSH traffic (via default 22 port) between `src` and `dst`.
func BlockSSH(_, _ net.IP, _ []net.IP) error {
// SetIPTablesForwardAcceptPolicy sets ACCEPT policy for iptables `forward` chain.
func SetIPTablesForwardAcceptPolicy() error {
return errServerMethodsNotSupported
}

Expand Down
47 changes: 38 additions & 9 deletions internal/vpn/os_server_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,49 @@ import (
"fmt"
"net"
"os/exec"
"strings"
)

const (
defaultNetworkInterfaceCMD = "ip addr | awk '/state UP/ {print $2}' | sed 's/.$//'"
getIPv4ForwardingCMD = "sysctl net.ipv4.ip_forward"
getIPv6ForwardingCMD = "sysctl net.ipv6.conf.all.forwarding"
setIPv4ForwardingCMDFmt = "sysctl -w net.ipv4.ip_forward=%s"
setIPv6ForwardingCMDFmt = "sysctl -w net.ipv6.conf.all.forwarding=%s"
enableIPMasqueradingCMDFmt = "iptables -t nat -A POSTROUTING -o %s -j MASQUERADE"
disableIPMasqueradingCMDFmt = "iptables -t nat -D POSTROUTING -o %s -j MASQUERADE"
blockIPToLocalNetCMDFmt = "iptables -I FORWARD -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP && iptables -I INPUT -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP"
allowIPToLocalNetCMDFmt = "iptables -D FORWARD -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP && iptables -D INPUT -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP"
defaultNetworkInterfaceCMD = "ip addr | awk '/state UP/ {print $2}' | sed 's/.$//'"
getIPv4ForwardingCMD = "sysctl net.ipv4.ip_forward"
getIPv6ForwardingCMD = "sysctl net.ipv6.conf.all.forwarding"
setIPv4ForwardingCMDFmt = "sysctl -w net.ipv4.ip_forward=%s"
setIPv6ForwardingCMDFmt = "sysctl -w net.ipv6.conf.all.forwarding=%s"
getIPTablesForwardPolicyCMD = "iptables -L | grep \"Chain FORWARD\" | tr -d '()' | awk '{print $4}'"
setIPTablesForwardPolicyCMDFmt = "iptables --policy FORWARD %s"
enableIPMasqueradingCMDFmt = "iptables -t nat -A POSTROUTING -o %s -j MASQUERADE"
disableIPMasqueradingCMDFmt = "iptables -t nat -D POSTROUTING -o %s -j MASQUERADE"
blockIPToLocalNetCMDFmt = "iptables -I FORWARD -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP && iptables -I INPUT -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP"
allowIPToLocalNetCMDFmt = "iptables -D FORWARD -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP && iptables -D INPUT -d 192.168.0.0/16,172.16.0.0/12,10.0.0.0/8 -s %s -j DROP"
)

// GetIPTablesForwardPolicy gets current policy for iptables `forward` chain.
func GetIPTablesForwardPolicy() (string, error) {
outputBytes, err := exec.Command("sh", "-c", getIPTablesForwardPolicyCMD).Output()
if err != nil {
return "", fmt.Errorf("error running command %s: %w", getIPTablesForwardPolicyCMD, err)
}

return strings.TrimRight(string(outputBytes), "\n"), nil
}

// SetIPTablesForwardPolicy sets `policy` for iptables `forward` chain.
func SetIPTablesForwardPolicy(policy string) error {
cmd := fmt.Sprintf(setIPTablesForwardPolicyCMDFmt, policy)
if err := exec.Command("sh", "-c", cmd).Run(); err != nil { //nolint:gosec
return fmt.Errorf("error running command %s: %w", cmd, err)
}

return nil
}

// SetIPTablesForwardAcceptPolicy sets ACCEPT policy for iptables `forward` chain.
func SetIPTablesForwardAcceptPolicy() error {
const policy = "ACCEPT"
return SetIPTablesForwardPolicy(policy)
}

// AllowIPToLocalNetwork allows all the packets coming from `source`
// to private IP ranges.
func AllowIPToLocalNetwork(src, dst net.IP) error {
Expand Down
24 changes: 24 additions & 0 deletions internal/vpn/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type Server struct {
defaultNetworkInterfaceIPs []net.IP
ipv4ForwardingVal string
ipv6ForwardingVal string
iptablesForwardPolicy string
}

// NewServer creates VPN server instance.
Expand Down Expand Up @@ -58,10 +59,18 @@ func NewServer(cfg ServerConfig, l logrus.FieldLogger) (*Server, error) {
l.Infoln("Old IP forwarding values:")
l.Infof("IPv4: %s, IPv6: %s", ipv4ForwardingVal, ipv6ForwardingVal)

iptablesForwarPolicy, err := GetIPTablesForwardPolicy()
if err != nil {
return nil, fmt.Errorf("error getting iptables forward policy: %w", err)
}

l.Infof("Old iptables forward policy: %s", iptablesForwarPolicy)

s.defaultNetworkInterface = defaultNetworkIfc
s.defaultNetworkInterfaceIPs = defaultNetworkIfcIPs
s.ipv4ForwardingVal = ipv4ForwardingVal
s.ipv6ForwardingVal = ipv6ForwardingVal
s.iptablesForwardPolicy = iptablesForwarPolicy

return s, nil
}
Expand Down Expand Up @@ -111,6 +120,21 @@ func (s *Server) Serve(l net.Listener) error {
}
}()

if err := SetIPTablesForwardAcceptPolicy(); err != nil {
serveErr = fmt.Errorf("error settings iptables forward policy to ACCEPT")
return
}

s.log.Infoln("Set iptables forward policy to ACCEPT")

defer func() {
if err := SetIPTablesForwardPolicy(s.iptablesForwardPolicy); err != nil {
s.log.WithError(err).Errorf("Error setting iptables forward policy to %s", s.iptablesForwardPolicy)
} else {
s.log.Infof("Restored iptables forward policy to %s", s.iptablesForwardPolicy)
}
}()

s.lisMx.Lock()
s.lis = l
s.lisMx.Unlock()
Expand Down