Skip to content

Commit

Permalink
only start hole punching service after the host has a public address
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Nov 9, 2021
1 parent 7388d1f commit 4e776e5
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 23 deletions.
75 changes: 61 additions & 14 deletions p2p/protocol/holepunch/coordination.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"sync"
"time"

"github.com/libp2p/go-libp2p-core/event"
"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
Expand Down Expand Up @@ -56,6 +57,8 @@ type Service struct {
closed bool
refCount sync.WaitGroup

hasPublicAddrsChan chan struct{} // this chan is closed as soon as we have a public address

// active hole punches for deduplicating
activeMx sync.Mutex
active map[peer.ID]struct{}
Expand All @@ -71,11 +74,12 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service,

ctx, cancel := context.WithCancel(context.Background())
hs := &Service{
ctx: ctx,
ctxCancel: cancel,
host: h,
ids: ids,
active: make(map[peer.ID]struct{}),
ctx: ctx,
ctxCancel: cancel,
host: h,
ids: ids,
active: make(map[peer.ID]struct{}),
hasPublicAddrsChan: make(chan struct{}),
}

for _, opt := range opts {
Expand All @@ -85,11 +89,42 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service,
}
}

h.SetStreamHandler(Protocol, hs.handleNewStream)
sub, err := hs.host.EventBus().Subscribe(&event.EvtLocalAddressesUpdated{})
if err != nil {
return nil, err
}
hs.refCount.Add(1)
go hs.watchForPublicAddr(sub)

h.Network().Notify((*netNotifiee)(hs))
return hs, nil
}

func (hs *Service) watchForPublicAddr(sub event.Subscription) {
defer hs.refCount.Done()
defer sub.Close()

log.Debug("waiting until we have at least one public address", "peer", hs.host.ID())

for {
if containsPublicAddr(hs.ids.OwnObservedAddrs()) {
log.Debug("Host now has a public address. Starting holepunch protocol.")
hs.host.SetStreamHandler(Protocol, hs.handleNewStream)
close(hs.hasPublicAddrsChan)
return
}

select {
case <-hs.ctx.Done():
return
case _, ok := <-sub.Out():
if !ok {
return
}
}
}
}

// Close closes the Hole Punch Service.
func (hs *Service) Close() error {
hs.closeMx.Lock()
Expand Down Expand Up @@ -176,7 +211,6 @@ func (hs *Service) beginDirectConnect(p peer.ID) error {
// It first attempts a direct dial (if we have a public address of that peer), and then
// coordinates a hole punch over the given relay connection.
func (hs *Service) DirectConnect(p peer.ID) error {
log.Debugw("got inbound proxy conn", "peer", p)
if err := hs.beginDirectConnect(p); err != nil {
return err
}
Expand Down Expand Up @@ -221,8 +255,16 @@ func (hs *Service) directConnect(rp peer.ID) error {
}
}

if len(hs.ids.OwnObservedAddrs()) == 0 {
log.Debugw("got inbound proxy conn", "peer", rp)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
select {
case <-hs.ctx.Done():
return hs.ctx.Err()
case <-ctx.Done():
log.Debug("didn't find any public host address")
return errors.New("can't initiate hole punch, as we don't have any public addresses")
case <-hs.hasPublicAddrsChan:
}

// hole punch
Expand Down Expand Up @@ -250,7 +292,6 @@ func (hs *Service) directConnect(rp peer.ID) error {
dt := time.Since(start)
hs.tracer.EndHolePunch(rp, dt, err)
if err == nil {
log.Debugw("hole punching with successful", "peer", rp, "time", dt)
return nil
}
case <-hs.ctx.Done():
Expand Down Expand Up @@ -341,11 +382,6 @@ func (hs *Service) handleNewStream(s network.Stream) {
err = hs.holePunchConnect(pi, false)
dt := time.Since(start)
hs.tracer.EndHolePunch(rp, dt, err)
if err != nil {
log.Debugw("hole punching failed", "peer", rp, "time", dt, "error", err)
} else {
log.Debugw("hole punching succeeded", "peer", rp, "time", dt)
}
}

func (hs *Service) holePunchConnect(pi peer.AddrInfo, isClient bool) error {
Expand All @@ -363,6 +399,16 @@ func (hs *Service) holePunchConnect(pi peer.AddrInfo, isClient bool) error {
return nil
}

func containsPublicAddr(addrs []ma.Multiaddr) bool {
for _, addr := range addrs {
if isRelayAddress(addr) || !manet.IsPublicAddr(addr) {
continue
}
return true
}
return false
}

func removeRelayAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
result := make([]ma.Multiaddr, 0, len(addrs))
for _, addr := range addrs {
Expand Down Expand Up @@ -414,6 +460,7 @@ func (nn *netNotifiee) Connected(_ network.Network, conn network.Conn) {
// that we can dial to for a hole punch.
case <-hs.ids.IdentifyWait(conn):
case <-hs.ctx.Done():
return
}

_ = hs.DirectConnect(conn.RemotePeer())
Expand Down
52 changes: 43 additions & 9 deletions p2p/protocol/holepunch/coordination_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,22 @@ func (m *mockEventTracer) getEvents() []*holepunch.Event {

var _ holepunch.EventTracer = &mockEventTracer{}

type mockIDService struct {
identify.IDService
}

var _ identify.IDService = &mockIDService{}

func newMockIDService(t *testing.T, h host.Host) identify.IDService {
ids, err := identify.NewIDService(h)
require.NoError(t, err)
return &mockIDService{IDService: ids}
}

func (s *mockIDService) OwnObservedAddrs() []ma.Multiaddr {
return append(s.IDService.OwnObservedAddrs(), ma.StringCast("/ip4/1.1.1.1/tcp/1234"))
}

func TestNoHolePunchIfDirectConnExists(t *testing.T) {
tr := &mockEventTracer{}
h1, hps := mkHostWithHolePunchSvc(t, holepunch.WithTracer(tr))
Expand Down Expand Up @@ -180,6 +196,14 @@ func TestFailuresOnInitiator(t *testing.T) {
}
}

func addrsToBytes(as []ma.Multiaddr) [][]byte {
bzs := make([][]byte, 0, len(as))
for _, a := range as {
bzs = append(bzs, a.Bytes())
}
return bzs
}

func TestFailuresOnResponder(t *testing.T) {
tcs := map[string]struct {
initiator func(s network.Stream)
Expand All @@ -192,22 +216,36 @@ func TestFailuresOnResponder(t *testing.T) {
},
errMsg: "expected CONNECT message",
},
"initiator does NOT send a SYNC message after a Connect message": {
"initiator does NOT send a SYNC message after a CONNECT message": {
initiator: func(s network.Stream) {
w := protoio.NewDelimitedWriter(s)
w.WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()})
w.WriteMsg(&holepunch_pb.HolePunch{
Type: holepunch_pb.HolePunch_CONNECT.Enum(),
ObsAddrs: addrsToBytes([]ma.Multiaddr{ma.StringCast("/ip4/127.0.0.1/tcp/1234")}),
})
w.WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()})
},
errMsg: "expected SYNC message",
},
"initiator does NOT reply within hole punch deadline": {
holePunchTimeout: 10 * time.Millisecond,
initiator: func(s network.Stream) {
protoio.NewDelimitedWriter(s).WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()})
protoio.NewDelimitedWriter(s).WriteMsg(&holepunch_pb.HolePunch{
Type: holepunch_pb.HolePunch_CONNECT.Enum(),
ObsAddrs: addrsToBytes([]ma.Multiaddr{ma.StringCast("/ip4/127.0.0.1/tcp/1234")}),
})
time.Sleep(10 * time.Second)
},
errMsg: "i/o deadline reached",
},
"initiator does NOT send any addresses in CONNECT": {
holePunchTimeout: 10 * time.Millisecond,
initiator: func(s network.Stream) {
protoio.NewDelimitedWriter(s).WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()})
time.Sleep(10 * time.Second)
},
errMsg: "expected CONNECT message to contain at least one message",
},
}

for name, tc := range tcs {
Expand Down Expand Up @@ -361,9 +399,7 @@ func makeRelayedHosts(t *testing.T, h1Opt holepunch.Option, addHolePuncher bool)

func addHolePunchService(t *testing.T, h host.Host) *holepunch.Service {
t.Helper()
ids, err := identify.NewIDService(h)
require.NoError(t, err)
hps, err := holepunch.NewService(h, ids)
hps, err := holepunch.NewService(h, newMockIDService(t, h))
require.NoError(t, err)
return hps
}
Expand All @@ -372,9 +408,7 @@ func mkHostWithHolePunchSvc(t *testing.T, opts ...holepunch.Option) (host.Host,
t.Helper()
h, err := libp2p.New(libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0"), ma.StringCast("/ip6/::1/tcp/0")))
require.NoError(t, err)
ids, err := identify.NewIDService(h)
require.NoError(t, err)
hps, err := holepunch.NewService(h, ids, opts...)
hps, err := holepunch.NewService(h, newMockIDService(t, h), opts...)
require.NoError(t, err)
return h, hps
}

0 comments on commit 4e776e5

Please sign in to comment.