diff --git a/p2p/protocol/holepunch/coordination.go b/p2p/protocol/holepunch/coordination.go index 9b3f328b44..57b1aec2e6 100644 --- a/p2p/protocol/holepunch/coordination.go +++ b/p2p/protocol/holepunch/coordination.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "github.com/libp2p/go-libp2p-core/event" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" @@ -56,6 +57,8 @@ type Service struct { closed bool refCount sync.WaitGroup + hasPublicAddrsChan chan struct{} // this chan is closed as soon as we have a public address + // active hole punches for deduplicating activeMx sync.Mutex active map[peer.ID]struct{} @@ -71,11 +74,12 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service, ctx, cancel := context.WithCancel(context.Background()) hs := &Service{ - ctx: ctx, - ctxCancel: cancel, - host: h, - ids: ids, - active: make(map[peer.ID]struct{}), + ctx: ctx, + ctxCancel: cancel, + host: h, + ids: ids, + active: make(map[peer.ID]struct{}), + hasPublicAddrsChan: make(chan struct{}), } for _, opt := range opts { @@ -85,11 +89,42 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service, } } - h.SetStreamHandler(Protocol, hs.handleNewStream) + sub, err := hs.host.EventBus().Subscribe(&event.EvtLocalAddressesUpdated{}) + if err != nil { + return nil, err + } + hs.refCount.Add(1) + go hs.watchForPublicAddr(sub) + h.Network().Notify((*netNotifiee)(hs)) return hs, nil } +func (hs *Service) watchForPublicAddr(sub event.Subscription) { + defer hs.refCount.Done() + defer sub.Close() + + log.Debug("waiting until we have at least one public address", "peer", hs.host.ID()) + + for { + if containsPublicAddr(hs.ids.OwnObservedAddrs()) { + log.Debug("Host now has a public address. Starting holepunch protocol.") + hs.host.SetStreamHandler(Protocol, hs.handleNewStream) + close(hs.hasPublicAddrsChan) + return + } + + select { + case <-hs.ctx.Done(): + return + case _, ok := <-sub.Out(): + if !ok { + return + } + } + } +} + // Close closes the Hole Punch Service. func (hs *Service) Close() error { hs.closeMx.Lock() @@ -176,7 +211,6 @@ func (hs *Service) beginDirectConnect(p peer.ID) error { // It first attempts a direct dial (if we have a public address of that peer), and then // coordinates a hole punch over the given relay connection. func (hs *Service) DirectConnect(p peer.ID) error { - log.Debugw("got inbound proxy conn", "peer", p) if err := hs.beginDirectConnect(p); err != nil { return err } @@ -221,8 +255,16 @@ func (hs *Service) directConnect(rp peer.ID) error { } } - if len(hs.ids.OwnObservedAddrs()) == 0 { + log.Debugw("got inbound proxy conn", "peer", rp) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + select { + case <-hs.ctx.Done(): + return hs.ctx.Err() + case <-ctx.Done(): + log.Debug("didn't find any public host address") return errors.New("can't initiate hole punch, as we don't have any public addresses") + case <-hs.hasPublicAddrsChan: } // hole punch @@ -250,7 +292,6 @@ func (hs *Service) directConnect(rp peer.ID) error { dt := time.Since(start) hs.tracer.EndHolePunch(rp, dt, err) if err == nil { - log.Debugw("hole punching with successful", "peer", rp, "time", dt) return nil } case <-hs.ctx.Done(): @@ -341,11 +382,6 @@ func (hs *Service) handleNewStream(s network.Stream) { err = hs.holePunchConnect(pi, false) dt := time.Since(start) hs.tracer.EndHolePunch(rp, dt, err) - if err != nil { - log.Debugw("hole punching failed", "peer", rp, "time", dt, "error", err) - } else { - log.Debugw("hole punching succeeded", "peer", rp, "time", dt) - } } func (hs *Service) holePunchConnect(pi peer.AddrInfo, isClient bool) error { @@ -363,6 +399,16 @@ func (hs *Service) holePunchConnect(pi peer.AddrInfo, isClient bool) error { return nil } +func containsPublicAddr(addrs []ma.Multiaddr) bool { + for _, addr := range addrs { + if isRelayAddress(addr) || !manet.IsPublicAddr(addr) { + continue + } + return true + } + return false +} + func removeRelayAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { result := make([]ma.Multiaddr, 0, len(addrs)) for _, addr := range addrs { @@ -414,6 +460,7 @@ func (nn *netNotifiee) Connected(_ network.Network, conn network.Conn) { // that we can dial to for a hole punch. case <-hs.ids.IdentifyWait(conn): case <-hs.ctx.Done(): + return } _ = hs.DirectConnect(conn.RemotePeer()) diff --git a/p2p/protocol/holepunch/coordination_test.go b/p2p/protocol/holepunch/coordination_test.go index 4786e3416d..da40c75751 100644 --- a/p2p/protocol/holepunch/coordination_test.go +++ b/p2p/protocol/holepunch/coordination_test.go @@ -47,6 +47,22 @@ func (m *mockEventTracer) getEvents() []*holepunch.Event { var _ holepunch.EventTracer = &mockEventTracer{} +type mockIDService struct { + identify.IDService +} + +var _ identify.IDService = &mockIDService{} + +func newMockIDService(t *testing.T, h host.Host) identify.IDService { + ids, err := identify.NewIDService(h) + require.NoError(t, err) + return &mockIDService{IDService: ids} +} + +func (s *mockIDService) OwnObservedAddrs() []ma.Multiaddr { + return append(s.IDService.OwnObservedAddrs(), ma.StringCast("/ip4/1.1.1.1/tcp/1234")) +} + func TestNoHolePunchIfDirectConnExists(t *testing.T) { tr := &mockEventTracer{} h1, hps := mkHostWithHolePunchSvc(t, holepunch.WithTracer(tr)) @@ -180,6 +196,14 @@ func TestFailuresOnInitiator(t *testing.T) { } } +func addrsToBytes(as []ma.Multiaddr) [][]byte { + bzs := make([][]byte, 0, len(as)) + for _, a := range as { + bzs = append(bzs, a.Bytes()) + } + return bzs +} + func TestFailuresOnResponder(t *testing.T) { tcs := map[string]struct { initiator func(s network.Stream) @@ -192,10 +216,13 @@ func TestFailuresOnResponder(t *testing.T) { }, errMsg: "expected CONNECT message", }, - "initiator does NOT send a SYNC message after a Connect message": { + "initiator does NOT send a SYNC message after a CONNECT message": { initiator: func(s network.Stream) { w := protoio.NewDelimitedWriter(s) - w.WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()}) + w.WriteMsg(&holepunch_pb.HolePunch{ + Type: holepunch_pb.HolePunch_CONNECT.Enum(), + ObsAddrs: addrsToBytes([]ma.Multiaddr{ma.StringCast("/ip4/127.0.0.1/tcp/1234")}), + }) w.WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()}) }, errMsg: "expected SYNC message", @@ -203,11 +230,22 @@ func TestFailuresOnResponder(t *testing.T) { "initiator does NOT reply within hole punch deadline": { holePunchTimeout: 10 * time.Millisecond, initiator: func(s network.Stream) { - protoio.NewDelimitedWriter(s).WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()}) + protoio.NewDelimitedWriter(s).WriteMsg(&holepunch_pb.HolePunch{ + Type: holepunch_pb.HolePunch_CONNECT.Enum(), + ObsAddrs: addrsToBytes([]ma.Multiaddr{ma.StringCast("/ip4/127.0.0.1/tcp/1234")}), + }) time.Sleep(10 * time.Second) }, errMsg: "i/o deadline reached", }, + "initiator does NOT send any addresses in CONNECT": { + holePunchTimeout: 10 * time.Millisecond, + initiator: func(s network.Stream) { + protoio.NewDelimitedWriter(s).WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()}) + time.Sleep(10 * time.Second) + }, + errMsg: "expected CONNECT message to contain at least one message", + }, } for name, tc := range tcs { @@ -361,9 +399,7 @@ func makeRelayedHosts(t *testing.T, h1Opt holepunch.Option, addHolePuncher bool) func addHolePunchService(t *testing.T, h host.Host) *holepunch.Service { t.Helper() - ids, err := identify.NewIDService(h) - require.NoError(t, err) - hps, err := holepunch.NewService(h, ids) + hps, err := holepunch.NewService(h, newMockIDService(t, h)) require.NoError(t, err) return hps } @@ -372,9 +408,7 @@ func mkHostWithHolePunchSvc(t *testing.T, opts ...holepunch.Option) (host.Host, t.Helper() h, err := libp2p.New(libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0"), ma.StringCast("/ip6/::1/tcp/0"))) require.NoError(t, err) - ids, err := identify.NewIDService(h) - require.NoError(t, err) - hps, err := holepunch.NewService(h, ids, opts...) + hps, err := holepunch.NewService(h, newMockIDService(t, h), opts...) require.NoError(t, err) return h, hps }