Skip to content

Commit

Permalink
Merge pull request #813 from i-hate-nicknames/fix/stcpr-redial-loop
Browse files Browse the repository at this point in the history
Fix stcpr transport establishment issues
  • Loading branch information
jdknives authored Jun 21, 2021
2 parents 7299c7e + 6ab08cc commit e23328e
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 60 deletions.
32 changes: 27 additions & 5 deletions pkg/snet/directtp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/skycoin/skywire/pkg/snet/directtp/tphandshake"
"github.com/skycoin/skywire/pkg/snet/directtp/tplistener"
"github.com/skycoin/skywire/pkg/snet/directtp/tptypes"
"github.com/skycoin/skywire/pkg/util/netutil"
)

const (
Expand Down Expand Up @@ -136,7 +137,14 @@ func (c *client) Serve() error {
c.log.Errorf("Failed to extract port from addr %v: %v", err)
return
}

hasPublic, err := netutil.HasPublicIP()
if err != nil {
c.log.Errorf("Failed to check for public IP: %v", err)
}
if !hasPublic {
c.log.Infof("Not binding STCPR: no public IP address found")
return
}
if err := c.conf.AddressResolver.BindSTCPR(context.Background(), port); err != nil {
c.log.Errorf("Failed to bind STCPR: %v", err)
return
Expand Down Expand Up @@ -265,7 +273,7 @@ func (c *client) Dial(ctx context.Context, rPK cipher.PubKey, rPort uint16) (*tp

c.log.Infof("Resolved PK %v to visor data %v", rPK, visorData)

conn, err := c.dialVisor(visorData)
conn, err := c.dialVisor(ctx, visorData)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -313,6 +321,20 @@ func (c *client) dial(addr string) (net.Conn, error) {
}
}

func (c *client) dialContext(ctx context.Context, addr string) (net.Conn, error) {
dialer := net.Dialer{}
switch c.conf.Type {
case tptypes.STCP, tptypes.STCPR:
return dialer.DialContext(ctx, "tcp", addr)

case tptypes.SUDPH:
return c.dialUDPWithTimeout(addr)

default:
return nil, ErrUnknownTransportType
}
}

func (c *client) listen(addr string) (net.Listener, error) {
switch c.conf.Type {
case tptypes.STCP, tptypes.STCPR:
Expand Down Expand Up @@ -405,7 +427,7 @@ func (c *client) dialUDPWithTimeout(addr string) (net.Conn, error) {
}
}

func (c *client) dialVisor(visorData arclient.VisorData) (net.Conn, error) {
func (c *client) dialVisor(ctx context.Context, visorData arclient.VisorData) (net.Conn, error) {
if visorData.IsLocal {
for _, host := range visorData.Addresses {
addr := net.JoinHostPort(host, visorData.Port)
Expand All @@ -416,7 +438,7 @@ func (c *client) dialVisor(visorData arclient.VisorData) (net.Conn, error) {
}
}

conn, err := c.dial(addr)
conn, err := c.dialContext(ctx, addr)
if err == nil {
return conn, nil
}
Expand All @@ -434,7 +456,7 @@ func (c *client) dialVisor(visorData arclient.VisorData) (net.Conn, error) {
}
}

return c.dial(addr)
return c.dialContext(ctx, addr)
}

// Listen creates a new listener for sudp.
Expand Down
23 changes: 10 additions & 13 deletions pkg/transport/entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ type Entry struct {
// ID is the Transport ID that uniquely identifies the Transport.
ID uuid.UUID `json:"t_id"`

// Edges contains the public keys of the Transport's edge nodes (should only have 2 edges and the least-significant edge should come first).
// Edges contains the public keys of the Transport's edge nodes
// (should only have 2 edges and the first edge is transport original initiator).
Edges [2]cipher.PubKey `json:"edges"`

// Type represents the transport type.
Expand All @@ -47,20 +48,16 @@ type Entry struct {
}

// MakeEntry creates a new transport entry
func MakeEntry(pk1, pk2 cipher.PubKey, tpType string, public bool, label Label) Entry {
return Entry{
ID: MakeTransportID(pk1, pk2, tpType),
Edges: SortEdges(pk1, pk2),
func MakeEntry(initiator, target cipher.PubKey, tpType string, public bool, label Label) Entry {
entry := Entry{
ID: MakeTransportID(initiator, target, tpType),
Type: tpType,
Public: public,
Label: label,
}
}

// SetEdges sets edges of Entry
func (e *Entry) SetEdges(localPK, remotePK cipher.PubKey) {
e.ID = MakeTransportID(localPK, remotePK, e.Type)
e.Edges = SortEdges(localPK, remotePK)
entry.Edges[0] = initiator
entry.Edges[1] = target
return entry
}

// RemoteEdge returns the remote edge's public key.
Expand Down Expand Up @@ -106,8 +103,8 @@ func (e *Entry) String() string {
res += fmt.Sprintf("\ttype: %s\n", e.Type)
res += fmt.Sprintf("\tid: %s\n", e.ID)
res += "\tedges:\n"
res += fmt.Sprintf("\t\tedge 1: %s\n", e.Edges[0])
res += fmt.Sprintf("\t\tedge 2: %s\n", e.Edges[1])
res += fmt.Sprintf("\t\tedge 1 (initiator): %s\n", e.Edges[0])
res += fmt.Sprintf("\t\tedge 2 (target): %s\n", e.Edges[1])
return res
}

Expand Down
15 changes: 0 additions & 15 deletions pkg/transport/entry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,6 @@ func TestNewEntry(t *testing.T) {
assert.NotNil(t, entryBA.ID)
}

func TestEntry_SetEdges(t *testing.T) {
pkA, _ := cipher.GenerateKeyPair()
pkB, _ := cipher.GenerateKeyPair()

entryAB, entryBA := transport.Entry{}, transport.Entry{}

entryAB.SetEdges(pkA, pkB)
entryBA.SetEdges(pkA, pkB)

assert.True(t, entryAB.Edges == entryBA.Edges)
assert.True(t, entryAB.ID == entryBA.ID)
assert.NotNil(t, entryAB.ID)
assert.NotNil(t, entryBA.ID)
}

func ExampleSignedEntry_Sign() {
pkA, skA := cipher.GenerateKeyPair()
pkB, skB := cipher.GenerateKeyPair()
Expand Down
49 changes: 37 additions & 12 deletions pkg/transport/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,21 @@ import (
"github.com/skycoin/skywire/pkg/snet"
)

func makeEntryFromTpConn(conn *snet.Conn) Entry {
return MakeEntry(conn.LocalPK(), conn.RemotePK(), conn.Network(), true, LabelUser)
type hsResponse byte

const (
responseFailure hsResponse = iota
responseOK
responseSignatureErr
responseInvalidEntry
)

func makeEntryFromTpConn(conn *snet.Conn, isInitiator bool) Entry {
initiator, target := conn.LocalPK(), conn.RemotePK()
if !isInitiator {
initiator, target = target, initiator
}
return MakeEntry(initiator, target, conn.Network(), true, LabelUser)
}

func compareEntries(expected, received *Entry) error {
Expand Down Expand Up @@ -86,7 +99,7 @@ func (hs SettlementHS) Do(ctx context.Context, dc DiscoveryClient, conn *snet.Co
func MakeSettlementHS(init bool) SettlementHS {
// initiating logic.
initHS := func(ctx context.Context, dc DiscoveryClient, conn *snet.Conn, sk cipher.SecKey) (err error) {
entry := makeEntryFromTpConn(conn)
entry := makeEntryFromTpConn(conn, true)

// TODO(evanlinjin): Probably not needed as this is called in mTp already. Need to double check.
//defer func() {
Expand All @@ -110,23 +123,33 @@ func MakeSettlementHS(init bool) SettlementHS {
if _, err := io.ReadFull(conn, accepted); err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
if accepted[0] == 0 {
switch hsResponse(accepted[0]) {
case responseOK:
return nil
case responseFailure:
return fmt.Errorf("transport settlement rejected by remote")
case responseInvalidEntry:
return fmt.Errorf("invalid entry")
case responseSignatureErr:
return fmt.Errorf("signature error")
default:
return fmt.Errorf("invalid remote response")
}
return nil
}

// responding logic.
respHS := func(ctx context.Context, dc DiscoveryClient, conn *snet.Conn, sk cipher.SecKey) error {
entry := makeEntryFromTpConn(conn)
entry := makeEntryFromTpConn(conn, false)

// receive, verify and sign entry.
recvSE, err := receiveAndVerifyEntry(conn, &entry, conn.RemotePK())
if err != nil {
writeHsResponse(conn, responseInvalidEntry) //nolint:errcheck, gosec
return err
}

if err := recvSE.Sign(conn.LocalPK(), sk); err != nil {
writeHsResponse(conn, responseSignatureErr) //nolint:errcheck, gosec
return fmt.Errorf("failed to sign received entry: %w", err)
}

Expand All @@ -141,16 +164,18 @@ func MakeSettlementHS(init bool) SettlementHS {
log.WithError(err).Error("Failed to register transport.")
}
}

// inform initiating visor.
if _, err := conn.Write([]byte{1}); err != nil {
return fmt.Errorf("failed to accept transport settlement: write failed: %w", err)
}
return nil
return writeHsResponse(conn, responseOK)
}

if init {
return initHS
}
return respHS
}

func writeHsResponse(w io.Writer, response hsResponse) error {
if _, err := w.Write([]byte{byte(response)}); err != nil {
return fmt.Errorf("failed to accept transport settlement: write failed: %w", err)
}
return nil
}
30 changes: 21 additions & 9 deletions pkg/transport/managed_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ var (

// Constants associated with transport redial loop.
const (
tpInitBO = time.Millisecond * 500
tpMaxBO = time.Minute
tpTries = 0
tpFactor = 2
tpInitBO = time.Millisecond * 500
tpMaxBO = time.Minute
tpTries = 0
tpFactor = 2
tpTimeout = time.Second * 3 // timeout for a single try
)

// ManagedTransportConfig is a configuration for managed transport.
Expand Down Expand Up @@ -92,15 +93,19 @@ type ManagedTransport struct {
}

// NewManagedTransport creates a new ManagedTransport.
func NewManagedTransport(conf ManagedTransportConfig) *ManagedTransport {
func NewManagedTransport(conf ManagedTransportConfig, isInitiator bool) *ManagedTransport {
initiator, target := conf.Net.LocalPK(), conf.RemotePK
if !isInitiator {
initiator, target = target, initiator
}
mt := &ManagedTransport{
log: logging.MustGetLogger(fmt.Sprintf("tp:%s", conf.RemotePK.String()[:6])),
rPK: conf.RemotePK,
netName: conf.NetName,
n: conf.Net,
dc: conf.DC,
ls: conf.LS,
Entry: MakeEntry(conf.Net.LocalPK(), conf.RemotePK, conf.NetName, true, conf.TransportLabel),
Entry: MakeEntry(initiator, target, conf.NetName, true, conf.TransportLabel),
LogEntry: new(LogEntry),
connCh: make(chan struct{}, 1),
done: make(chan struct{}),
Expand Down Expand Up @@ -204,8 +209,8 @@ func (mt *ManagedTransport) Serve(readCh chan<- routing.Packet) {
continue
}

// Only least significant edge is responsible for redialing.
if !mt.isLeastSignificantEdge() {
// Only initiator is responsible for redialing.
if !mt.isInitiator() {
continue
}

Expand Down Expand Up @@ -371,16 +376,23 @@ func (mt *ManagedTransport) redialLoop(ctx context.Context) error {

// Only redial when there is no underlying conn.
return retry.Do(ctx, func() (err error) {
tryCtx, cancel := context.WithTimeout(ctx, tpTimeout)
defer cancel()
mt.connMx.Lock()
if mt.conn == nil {
err = mt.redial(ctx)
err = mt.redial(tryCtx)
}
mt.connMx.Unlock()
return err
})
}

func (mt *ManagedTransport) isLeastSignificantEdge() bool {
sorted := SortEdges(mt.Entry.Edges[0], mt.Entry.Edges[1])
return sorted[0] == mt.n.LocalPK()
}

func (mt *ManagedTransport) isInitiator() bool {
return mt.Entry.EdgeIndex(mt.n.LocalPK()) == 0
}

Expand Down
11 changes: 6 additions & 5 deletions pkg/transport/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ func (tm *Manager) initTransports(ctx context.Context) {
remote = entry.Entry.RemoteEdge(tm.Conf.PubKey)
tpID = entry.Entry.ID
)
if _, err := tm.saveTransport(remote, tpType, entry.Entry.Label); err != nil {
isInitiator := tm.n.LocalPK() == entry.Entry.Edges[0]
if _, err := tm.saveTransport(remote, isInitiator, tpType, entry.Entry.Label); err != nil {
tm.Logger.Warnf("INIT: failed to init tp: type(%s) remote(%s) tpID(%s)", tpType, remote, tpID)
} else {
tm.Logger.Debugf("Successfully initialized TP %v", *entry.Entry)
Expand Down Expand Up @@ -230,7 +231,7 @@ func (tm *Manager) acceptTransport(ctx context.Context, lis *snet.Listener) erro
NetName: lis.Network(),
AfterClosed: tm.afterTPClosed,
TransportLabel: LabelUser,
})
}, false)

go func() {
mTp.Serve(tm.readCh)
Expand Down Expand Up @@ -303,7 +304,7 @@ func (tm *Manager) SaveTransport(ctx context.Context, remote cipher.PubKey, tpTy
}

for {
mTp, err := tm.saveTransport(remote, tpType, label)
mTp, err := tm.saveTransport(remote, true, tpType, label)
if err != nil {
return nil, fmt.Errorf("save transport: %w", err)
}
Expand Down Expand Up @@ -346,7 +347,7 @@ func isSTCPTableError(remotePK cipher.PubKey, err error) bool {
return err.Error() == fmt.Sprintf("pk table: entry of %s does not exist", remotePK.String())
}

func (tm *Manager) saveTransport(remote cipher.PubKey, netName string, label Label) (*ManagedTransport, error) {
func (tm *Manager) saveTransport(remote cipher.PubKey, initiator bool, netName string, label Label) (*ManagedTransport, error) {
tm.mx.Lock()
defer tm.mx.Unlock()
if !snet.IsKnownNetwork(netName) {
Expand All @@ -372,7 +373,7 @@ func (tm *Manager) saveTransport(remote cipher.PubKey, netName string, label Lab
NetName: netName,
AfterClosed: afterTPClosed,
TransportLabel: label,
})
}, initiator)

if mTp.netName == tptypes.STCPR {
ar := mTp.n.Conf().ARClient
Expand Down
3 changes: 2 additions & 1 deletion pkg/transport/tpdclient/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ func newTestEntry() *transport.Entry {
Type: "dmsg",
Public: true,
}
entry.SetEdges(pk1, testPubKey)
entry.Edges[0] = pk1
entry.Edges[1] = testPubKey

return entry
}
Expand Down
15 changes: 15 additions & 0 deletions pkg/util/netutil/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,18 @@ func DefaultNetworkInterfaceIPs() ([]net.IP, error) {
}
return localIPs, nil
}

// HasPublicIP returns true if this machine has at least one
// publically available IP address
func HasPublicIP() (bool, error) {
localIPs, err := LocalNetworkInterfaceIPs()
if err != nil {
return false, err
}
for _, IP := range localIPs {
if IsPublicIP(IP) {
return true, nil
}
}
return false, nil
}

0 comments on commit e23328e

Please sign in to comment.