Skip to content

Commit

Permalink
optimised finding server cert (#4148)
Browse files Browse the repository at this point in the history
* optimised finding server cert

* make sure `close(done)` invoked only once

* remove sleep

* resolve IDE warning

* refactor for findServerCert
  • Loading branch information
qfrank authored Oct 18, 2023
1 parent 0881d8c commit 3326362
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 33 deletions.
7 changes: 7 additions & 0 deletions common/devices.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
package common

import "runtime"

const (
AndroidPlatform = "android"
WindowsPlatform = "windows"
)

func OperatingSystemIs(targetOS string) bool {
return runtime.GOOS == targetOS
}
3 changes: 2 additions & 1 deletion protocol/common/shard.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package common

import (
"github.com/waku-org/go-waku/waku/v2/protocol/relay"

"github.com/status-im/status-go/protocol/protobuf"
"github.com/status-im/status-go/protocol/transport"
"github.com/waku-org/go-waku/waku/v2/protocol/relay"
)

const MainStatusShardCluster = 16
Expand Down
5 changes: 2 additions & 3 deletions server/ips.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package server

import (
"net"
"runtime"

"go.uber.org/zap"

Expand Down Expand Up @@ -98,7 +97,7 @@ func getAndroidLocalIP() ([][]net.IP, error) {
func getLocalAddresses() ([][]net.IP, error) {
// TODO until we can resolve Android errors when calling net.Interfaces() just return the outbound local address.
// Sorry Android
if runtime.GOOS == common.AndroidPlatform {
if common.OperatingSystemIs(common.AndroidPlatform) {
return getAndroidLocalIP()
}

Expand Down Expand Up @@ -192,7 +191,7 @@ func getAllAvailableNetworks() ([]net.IPNet, error) {
// that returns a reachable server's address to be used by local pairing client.
func FindReachableAddressesForPairingClient(serverIps []net.IP) ([]net.IP, error) {
// TODO until we can resolve Android errors when calling net.Interfaces() just noop. Sorry Android
if runtime.GOOS == common.AndroidPlatform {
if common.OperatingSystemIs(common.AndroidPlatform) {
return serverIps, nil
}

Expand Down
3 changes: 2 additions & 1 deletion server/pairing/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ func getServerCert(URL *url.URL) (*x509.Certificate, error) {
InsecureSkipVerify: true, // nolint: gosec // Only skip verify to get the server's TLS cert. DO NOT skip for any other reason.
}

conn, err := tls.Dial("tcp", URL.Host, conf)
// one second should be enough to get the server's TLS cert in LAN?
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", URL.Host, conf)
if err != nil {
return nil, err
}
Expand Down
77 changes: 52 additions & 25 deletions server/pairing/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ import (
"encoding/pem"
"fmt"
"io"
"net"
"net/http"
"net/http/cookiejar"
"net/url"
"time"

"go.uber.org/zap"

Expand All @@ -38,31 +38,52 @@ type BaseClient struct {
challengeTaker *ChallengeTaker
}

func findServerCert(c *ConnectionParams) (*url.URL, *x509.Certificate, error) {
netIps, err := server.FindReachableAddressesForPairingClient(c.netIPs)
if err != nil {
return nil, nil, err
}
func findServerCert(c *ConnectionParams, reachableIPs []net.IP) (*url.URL, *x509.Certificate, error) {
var baseAddress *url.URL
var serverCert *x509.Certificate
var certErrs error
for _, ip := range netIps {
u := c.BuildURL(ip)

serverCert, err = getServerCert(u)
if err != nil {
var certErr string
if certErrs != nil {
certErr = certErrs.Error()

type connectionError struct {
ip net.IP
err error
}
errCh := make(chan connectionError, len(reachableIPs))

type result struct {
u *url.URL
cert *x509.Certificate
}
successCh := make(chan result, 1) // as we close on the first success

for _, ip := range reachableIPs {
go func(ip net.IP) {
u := c.BuildURL(ip)
cert, err := getServerCert(u)
if err != nil {
errCh <- connectionError{ip: ip, err: fmt.Errorf("connecting to '%s' failed: %s", u, err.Error())}
return
}
// If no error, send the results to the success channel
successCh <- result{u: u, cert: cert}
}(ip)
}

// Keep track of error counts
errorCount := 0
var combinedErrors string
for {
select {
case success := <-successCh:
baseAddress = success.u
serverCert = success.cert
return baseAddress, serverCert, nil
case ipErr := <-errCh:
errorCount++
combinedErrors += fmt.Sprintf("IP %s: %s; ", ipErr.ip, ipErr.err)
if errorCount == len(reachableIPs) {
return nil, nil, fmt.Errorf(combinedErrors)
}
certErrs = fmt.Errorf("%sconnecting to '%s' failed: %s; ", certErr, u, err.Error())
continue
}

baseAddress = u
break
}
return baseAddress, serverCert, certErrs
}

// NewBaseClient returns a fully qualified BaseClient from the given ConnectionParams
Expand All @@ -71,13 +92,19 @@ func NewBaseClient(c *ConnectionParams, logger *zap.Logger) (*BaseClient, error)
var serverCert *x509.Certificate
var certErrs error

netIPs, err := server.FindReachableAddressesForPairingClient(c.netIPs)
if err != nil {
logger.Error("[local pair client] failed to find reachable addresses", zap.Error(err), zap.Any("netIPs", netIPs))
signal.SendLocalPairingEvent(Event{Type: EventConnectionError, Error: err.Error(), Action: ActionConnect})
return nil, err
}

maxRetries := 3
for i := 0; i < maxRetries; i++ {
baseAddress, serverCert, certErrs = findServerCert(c)
baseAddress, serverCert, certErrs = findServerCert(c, netIPs)
if serverCert == nil {
certErrs = fmt.Errorf("failed to connect to any of given addresses. %w", certErrs)
time.Sleep(1 * time.Second)
logger.Warn("failed to connect to any of given addresses. Retrying...", zap.Error(certErrs))
logger.Warn("failed to connect to any of given addresses. Retrying...", zap.Error(certErrs), zap.Any("netIPs", netIPs), zap.Int("retry", i+1))
} else {
break
}
Expand All @@ -92,7 +119,7 @@ func NewBaseClient(c *ConnectionParams, logger *zap.Logger) (*BaseClient, error)
// No error on the dial out then the URL.Host is accessible
signal.SendLocalPairingEvent(Event{Type: EventConnectionSuccess, Action: ActionConnect})

err := verifyCert(serverCert, c.publicKey)
err = verifyCert(serverCert, c.publicKey)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion server/pairing/peers.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (p *PeerNotifier) handler(hello *peers.LocalPairingPeerHello) {

func (p *PeerNotifier) Search() error {
// TODO until we can resolve Android errors when calling net.Interfaces() just noop. Sorry Android
if runtime.GOOS == common.AndroidPlatform {
if common.OperatingSystemIs(common.AndroidPlatform) {
return nil
}

Expand Down
5 changes: 3 additions & 2 deletions t/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ import (
"os"
"path"
"path/filepath"
"runtime"
"strconv"
"strings"
"testing"
"time"

"github.com/status-im/status-go/common"

_ "github.com/stretchr/testify/suite" // required to register testify flags

"github.com/status-im/status-go/logutils"
Expand Down Expand Up @@ -222,7 +223,7 @@ func WaitClosed(c <-chan struct{}, d time.Duration) error {
func MakeTestNodeConfig(networkID int) (*params.NodeConfig, error) {
testDir := filepath.Join(TestDataDir, TestNetworkNames[networkID])

if runtime.GOOS == "windows" {
if common.OperatingSystemIs(common.WindowsPlatform) {
testDir = filepath.ToSlash(testDir)
}

Expand Down

0 comments on commit 3326362

Please sign in to comment.