Skip to content

Commit

Permalink
chore(portforward): remove PIA dependency on storage package
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed May 2, 2024
1 parent e0a977c commit 6dd27e5
Show file tree
Hide file tree
Showing 18 changed files with 63 additions and 98 deletions.
9 changes: 6 additions & 3 deletions internal/models/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,20 @@ type Connection struct {
// Hostname is used for IPVanish, IVPN, Privado
// and Windscribe for TLS verification.
Hostname string `json:"hostname"`
// ServerName is used for PIA for port forwarding
ServerName string `json:"server_name,omitempty"`
// PubKey is the public key of the VPN server,
// used only for Wireguard.
PubKey string `json:"pubkey"`
// ServerName is used for PIA for port forwarding
ServerName string `json:"server_name,omitempty"`
// PortForward is used for PIA for port forwarding
PortForward bool `json:"port_forward"`
}

func (c *Connection) Equal(other Connection) bool {
return c.IP.Compare(other.IP) == 0 && c.Port == other.Port &&
c.Protocol == other.Protocol && c.Hostname == other.Hostname &&
c.ServerName == other.ServerName && c.PubKey == other.PubKey
c.PubKey == other.PubKey && c.ServerName == other.ServerName &&
c.PortForward == other.PortForward
}

// UpdateEmptyWith updates each field of the connection where the
Expand Down
15 changes: 9 additions & 6 deletions internal/portforward/service/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ import (
)

type Settings struct {
Enabled *bool
PortForwarder PortForwarder
Filepath string
Interface string // needed for PIA and ProtonVPN, tun0 for example
ServerName string // needed for PIA
ListeningPort uint16
Enabled *bool
PortForwarder PortForwarder
Filepath string
Interface string // needed for PIA and ProtonVPN, tun0 for example
ServerName string // needed for PIA
CanPortForward bool // needed for PIA
ListeningPort uint16
}

func (s Settings) Copy() (copied Settings) {
Expand All @@ -23,6 +24,7 @@ func (s Settings) Copy() (copied Settings) {
copied.Filepath = s.Filepath
copied.Interface = s.Interface
copied.ServerName = s.ServerName
copied.CanPortForward = s.CanPortForward
copied.ListeningPort = s.ListeningPort
return copied
}
Expand All @@ -33,6 +35,7 @@ func (s *Settings) OverrideWith(update Settings) {
s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath)
s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface)
s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName)
s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward)
s.ListeningPort = gosettings.OverrideWithComparable(s.ListeningPort, update.ListeningPort)
}

Expand Down
9 changes: 5 additions & 4 deletions internal/portforward/service/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
}

obj := utils.PortForwardObjects{
Logger: s.logger,
Gateway: gateway,
Client: s.client,
ServerName: s.settings.ServerName,
Logger: s.logger,
Gateway: gateway,
Client: s.client,
ServerName: s.settings.ServerName,
CanPortForward: s.settings.CanPortForward,
}
port, err := s.settings.PortForwarder.PortForward(ctx, obj)
if err != nil {
Expand Down
15 changes: 0 additions & 15 deletions internal/provider/common/mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion internal/provider/common/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,4 @@ import (
type Storage interface {
FilterServers(provider string, selection settings.ServerSelection) (
servers []models.Server, err error)
GetServerByName(provider, name string) (server models.Server, ok bool)
}
2 changes: 2 additions & 0 deletions internal/provider/custom/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func getOpenVPNConnection(extractor Extractor,
// Set the server name for PIA port forwarding code used
// together with the custom provider.
connection.ServerName = selection.Names[0]
connection.PortForward = true
}

return connection, nil
Expand All @@ -62,6 +63,7 @@ func getWireguardConnection(selection settings.ServerSelection) (
// Set the server name for PIA port forwarding code used
// together with the custom provider.
connection.ServerName = selection.Names[0]
connection.PortForward = true
}
return connection
}
11 changes: 2 additions & 9 deletions internal/provider/privateinternetaccess/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"strings"
"time"

"github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gluetun/internal/provider/utils"
"github.com/qdm12/golibs/format"
)
Expand All @@ -37,16 +36,10 @@ func (p *Provider) PortForward(ctx context.Context,

serverName := objects.ServerName

server, ok := p.storage.GetServerByName(providers.PrivateInternetAccess, serverName)
if !ok {
return 0, fmt.Errorf("%w: %s", ErrServerNameNotFound, serverName)
}

logger := objects.Logger

if !server.PortForward {
logger.Error("The server " + serverName +
" (region " + server.Region + ") does not support port forwarding")
if !objects.CanPortForward {
logger.Error("The server " + serverName + " does not support port forwarding")
return 0, nil
}

Expand Down
1 change: 0 additions & 1 deletion internal/provider/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ type Providers struct {
type Storage interface {
FilterServers(provider string, selection settings.ServerSelection) (
servers []models.Server, err error)
GetServerByName(provider, name string) (server models.Server, ok bool)
}

type Extractor interface {
Expand Down
15 changes: 8 additions & 7 deletions internal/provider/utils/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,14 @@ func GetConnection(provider string,
}

connection := models.Connection{
Type: selection.VPN,
IP: ip,
Port: port,
Protocol: protocol,
Hostname: hostname,
ServerName: server.ServerName,
PubKey: server.WgPubKey, // Wireguard
Type: selection.VPN,
IP: ip,
Port: port,
Protocol: protocol,
Hostname: hostname,
ServerName: server.ServerName,
PortForward: server.PortForward,
PubKey: server.WgPubKey, // Wireguard
}
connections = append(connections, connection)
}
Expand Down
7 changes: 3 additions & 4 deletions internal/provider/utils/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@ type PortForwardObjects struct {
Gateway netip.Addr
// Client is used to query the VPN gateway for Private Internet Access.
Client *http.Client
// ServerName is used by Private Internet Access for port forwarding,
// and to look up the server data from storage.
// TODO use server data directly to remove storage dependency for port
// forwarding implementation.
// ServerName is used by Private Internet Access for port forwarding.
ServerName string
// CanPortForward is used by Private Internet Access for port forwarding.
CanPortForward bool
}

type Routing interface {
Expand Down
23 changes: 0 additions & 23 deletions internal/storage/servers.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,6 @@ func (s *Storage) SetServers(provider string, servers []models.Server) (err erro
return nil
}

// GetServerByName returns the server for the given provider
// and server name. It returns `ok` as false if the server is
// not found. The returned server is also deep copied so it is
// safe for mutation and/or thread safe use.
func (s *Storage) GetServerByName(provider, name string) (
server models.Server, ok bool) {
if provider == providers.Custom {
return server, false
}

s.mergedMutex.RLock()
defer s.mergedMutex.RUnlock()

serversObject := s.getMergedServersObject(provider)
for _, server := range serversObject.Servers {
if server.ServerName == name {
return copyServer(server), true
}
}

return server, false
}

// GetServersCount returns the number of servers for the provider given.
func (s *Storage) GetServersCount(provider string) (count int) {
if provider == providers.Custom {
Expand Down
1 change: 0 additions & 1 deletion internal/updater/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ type Storage interface {
ServersAreEqual(provider string, servers []models.Server) (equal bool)
// Extra methods to match the provider.New storage interface
FilterServers(provider string, selection settings.ServerSelection) (filtered []models.Server, err error)
GetServerByName(provider string, name string) (server models.Server, ok bool)
}

type Unzipper interface {
Expand Down
1 change: 0 additions & 1 deletion internal/vpn/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ type PortForwarder interface {

type Storage interface {
FilterServers(provider string, selection settings.ServerSelection) (servers []models.Server, err error)
GetServerByName(provider, name string) (server models.Server, ok bool)
}

type NetLinker interface {
Expand Down
15 changes: 8 additions & 7 deletions internal/vpn/openvpn.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,38 @@ import (
func setupOpenVPN(ctx context.Context, fw Firewall,
openvpnConf OpenVPN, providerConf provider.Provider,
settings settings.VPN, ipv6Supported bool, starter command.Starter,
logger openvpn.Logger) (runner *openvpn.Runner, serverName string, err error) {
logger openvpn.Logger) (runner *openvpn.Runner, serverName string,
canPortForward bool, err error) {
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
if err != nil {
return nil, "", fmt.Errorf("finding a valid server connection: %w", err)
return nil, "", false, fmt.Errorf("finding a valid server connection: %w", err)
}

lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6Supported)

if err := openvpnConf.WriteConfig(lines); err != nil {
return nil, "", fmt.Errorf("writing configuration to file: %w", err)
return nil, "", false, fmt.Errorf("writing configuration to file: %w", err)
}

if *settings.OpenVPN.User != "" {
err := openvpnConf.WriteAuthFile(*settings.OpenVPN.User, *settings.OpenVPN.Password)
if err != nil {
return nil, "", fmt.Errorf("writing auth to file: %w", err)
return nil, "", false, fmt.Errorf("writing auth to file: %w", err)
}
}

if *settings.OpenVPN.KeyPassphrase != "" {
err := openvpnConf.WriteAskPassFile(*settings.OpenVPN.KeyPassphrase)
if err != nil {
return nil, "", fmt.Errorf("writing askpass file: %w", err)
return nil, "", false, fmt.Errorf("writing askpass file: %w", err)
}
}

if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil {
return nil, "", fmt.Errorf("allowing VPN connection through firewall: %w", err)
return nil, "", false, fmt.Errorf("allowing VPN connection through firewall: %w", err)
}

runner = openvpn.NewRunner(settings.OpenVPN, starter, logger)

return runner, connection.ServerName, nil
return runner, connection.ServerName, connection.PortForward, nil
}
7 changes: 4 additions & 3 deletions internal/vpn/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ func (l *Loop) startPortForwarding(data tunnelUpData) (err error) {
partialUpdate := portforward.Settings{
VPNIsUp: ptrTo(true),
Service: service.Settings{
PortForwarder: data.portForwarder,
Interface: data.vpnIntf,
ServerName: data.serverName,
PortForwarder: data.portForwarder,
Interface: data.vpnIntf,
ServerName: data.serverName,
CanPortForward: data.canPortForward,
},
}
return l.portForward.UpdateWith(partialUpdate)
Expand Down
12 changes: 7 additions & 5 deletions internal/vpn/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,27 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
Run(ctx context.Context, waitError chan<- error, tunnelReady chan<- struct{})
}
var serverName, vpnInterface string
var canPortForward bool
var err error
subLogger := l.logger.New(log.SetComponent(settings.Type))
if settings.Type == vpn.OpenVPN {
vpnInterface = settings.OpenVPN.Interface
vpnRunner, serverName, err = setupOpenVPN(ctx, l.fw,
vpnRunner, serverName, canPortForward, err = setupOpenVPN(ctx, l.fw,
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger)
} else { // Wireguard
vpnInterface = settings.Wireguard.Interface
vpnRunner, serverName, err = setupWireguard(ctx, l.netLinker, l.fw,
vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.fw,
providerConf, settings, l.ipv6Supported, subLogger)
}
if err != nil {
l.crashed(ctx, err)
continue
}
tunnelUpData := tunnelUpData{
serverName: serverName,
portForwarder: portForwarder,
vpnIntf: vpnInterface,
serverName: serverName,
canPortForward: canPortForward,
portForwarder: portForwarder,
vpnIntf: vpnInterface,
}

openvpnCtx, openvpnCancel := context.WithCancel(context.Background())
Expand Down
7 changes: 4 additions & 3 deletions internal/vpn/tunnelup.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ import (

type tunnelUpData struct {
// Port forwarding
vpnIntf string
serverName string
portForwarder PortForwarder
vpnIntf string
serverName string // used for PIA
canPortForward bool // used for PIA
portForwarder PortForwarder
}

func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
Expand Down
10 changes: 5 additions & 5 deletions internal/vpn/wireguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ import (
func setupWireguard(ctx context.Context, netlinker NetLinker,
fw Firewall, providerConf provider.Provider,
settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) (
wireguarder *wireguard.Wireguard, serverName string, err error) {
wireguarder *wireguard.Wireguard, serverName string, canPortForward bool, err error) {
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
if err != nil {
return nil, "", fmt.Errorf("finding a VPN server: %w", err)
return nil, "", false, fmt.Errorf("finding a VPN server: %w", err)
}

wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6Supported)
Expand All @@ -30,13 +30,13 @@ func setupWireguard(ctx context.Context, netlinker NetLinker,

wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger)
if err != nil {
return nil, "", fmt.Errorf("creating Wireguard: %w", err)
return nil, "", false, fmt.Errorf("creating Wireguard: %w", err)
}

err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
if err != nil {
return nil, "", fmt.Errorf("setting firewall: %w", err)
return nil, "", false, fmt.Errorf("setting firewall: %w", err)
}

return wireguarder, connection.ServerName, nil
return wireguarder, connection.ServerName, connection.PortForward, nil
}

0 comments on commit 6dd27e5

Please sign in to comment.