diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index cd19e726ed..2e7e74ee51 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -329,12 +329,7 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, } c.streams.m = make(map[*Stream]struct{}) - if len(s.conns.m[p]) == 0 { // first connection - s.emitter.Emit(event.EvtPeerConnectednessChanged{ - Peer: p, - Connectedness: network.Connected, - }) - } + isFirstConnection := len(s.conns.m[p]) == 0 s.conns.m[p] = append(s.conns.m[p], c) // Add two swarm refs: @@ -347,6 +342,15 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, c.notifyLk.Lock() s.conns.Unlock() + // Emit event after releasing `s.conns` lock so that a consumer can still + // use swarm methods that need the `s.conns` lock. + if isFirstConnection { + s.emitter.Emit(event.EvtPeerConnectednessChanged{ + Peer: p, + Connectedness: network.Connected, + }) + } + s.notifyAll(func(f network.Notifiee) { f.Connected(s, c) }) @@ -626,25 +630,32 @@ func (s *Swarm) removeConn(c *Conn) { p := c.RemotePeer() s.conns.Lock() - defer s.conns.Unlock() cs := s.conns.m[p] + + if len(cs) == 1 { + delete(s.conns.m, p) + s.conns.Unlock() + + // Emit event after releasing `s.conns` lock so that a consumer can still + // use swarm methods that need the `s.conns` lock. + s.emitter.Emit(event.EvtPeerConnectednessChanged{ + Peer: p, + Connectedness: network.NotConnected, + }) + return + } + + defer s.conns.Unlock() + for i, ci := range cs { if ci == c { - if len(cs) == 1 { - delete(s.conns.m, p) - s.emitter.Emit(event.EvtPeerConnectednessChanged{ - Peer: p, - Connectedness: network.NotConnected, - }) - } else { - // NOTE: We're intentionally preserving order. - // This way, connections to a peer are always - // sorted oldest to newest. - copy(cs[i:], cs[i+1:]) - cs[len(cs)-1] = nil - s.conns.m[p] = cs[:len(cs)-1] - } + // NOTE: We're intentionally preserving order. + // This way, connections to a peer are always + // sorted oldest to newest. + copy(cs[i:], cs[i+1:]) + cs[len(cs)-1] = nil + s.conns.m[p] = cs[:len(cs)-1] break } } diff --git a/p2p/net/swarm/swarm_event_test.go b/p2p/net/swarm/swarm_event_test.go index 7d4fb6bd5d..86d698d611 100644 --- a/p2p/net/swarm/swarm_event_test.go +++ b/p2p/net/swarm/swarm_event_test.go @@ -64,3 +64,52 @@ func TestConnectednessEventsSingleConn(t *testing.T) { checkEvent(t, sub1, event.EvtPeerConnectednessChanged{Peer: s2.LocalPeer(), Connectedness: network.NotConnected}) checkEvent(t, sub2, event.EvtPeerConnectednessChanged{Peer: s1.LocalPeer(), Connectedness: network.NotConnected}) } + +func TestNoDeadlockWhenConsumingConnectednessEvents(t *testing.T) { + dialerEventBus := eventbus.NewBus() + dialer := swarmt.GenSwarm(t, swarmt.OptDialOnly, swarmt.EventBus(dialerEventBus)) + defer dialer.Close() + + listener := swarmt.GenSwarm(t, swarmt.OptDialOnly) + addrsToListen := []ma.Multiaddr{ + ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), + } + + if err := listener.Listen(addrsToListen...); err != nil { + t.Fatal(err) + } + listenedAddrs := listener.ListenAddresses() + + dialer.Peerstore().AddAddrs(listener.LocalPeer(), listenedAddrs, time.Hour) + + sub, err := dialerEventBus.Subscribe(new(event.EvtPeerConnectednessChanged)) + require.NoError(t, err) + + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // A slow consumer + go func() { + for { + select { + case <-ctx.Done(): + return + case <-sub.Out(): + time.Sleep(100 * time.Millisecond) + // Do something with the swarm that needs the conns lock + _ = dialer.ConnsToPeer(listener.LocalPeer()) + time.Sleep(100 * time.Millisecond) + } + } + }() + + for i := 0; i < 10; i++ { + // Connect and disconnect to trigger a bunch of events + _, err := dialer.DialPeer(context.Background(), listener.LocalPeer()) + require.NoError(t, err) + dialer.ClosePeer(listener.LocalPeer()) + } + + // The test should finish without deadlocking +} diff --git a/p2p/protocol/identify/id.go b/p2p/protocol/identify/id.go index b122697ec9..54645e5295 100644 --- a/p2p/protocol/identify/id.go +++ b/p2p/protocol/identify/id.go @@ -48,8 +48,7 @@ const ServiceName = "libp2p.identify" const maxPushConcurrency = 32 -// StreamReadTimeout is the read timeout on all incoming Identify family streams. -var StreamReadTimeout = 60 * time.Second +var Timeout = 60 * time.Second // timeout on all incoming Identify interactions const ( legacyIDSize = 2 * 1024 // 2k Bytes @@ -416,11 +415,14 @@ func (ids *idService) IdentifyWait(c network.Conn) <-chan struct{} { } func (ids *idService) identifyConn(c network.Conn) error { - s, err := c.NewStream(network.WithUseTransient(context.TODO(), "identify")) + ctx, cancel := context.WithTimeout(context.Background(), Timeout) + defer cancel() + s, err := c.NewStream(network.WithUseTransient(ctx, "identify")) if err != nil { log.Debugw("error opening identify stream", "peer", c.RemotePeer(), "error", err) return err } + s.SetDeadline(time.Now().Add(Timeout)) if err := s.SetProtocol(ID); err != nil { log.Warnf("error setting identify protocol for stream: %s", err) @@ -439,6 +441,7 @@ func (ids *idService) identifyConn(c network.Conn) error { // handlePush handles incoming identify push streams func (ids *idService) handlePush(s network.Stream) { + s.SetDeadline(time.Now().Add(Timeout)) ids.handleIdentifyResponse(s, true) } @@ -500,8 +503,6 @@ func (ids *idService) handleIdentifyResponse(s network.Stream, isPush bool) erro } defer s.Scope().ReleaseMemory(signedIDSize) - _ = s.SetReadDeadline(time.Now().Add(StreamReadTimeout)) - c := s.Conn() r := pbio.NewDelimitedReader(s, signedIDSize) diff --git a/p2p/protocol/identify/id_test.go b/p2p/protocol/identify/id_test.go index 0bdaae033e..feb9c36db8 100644 --- a/p2p/protocol/identify/id_test.go +++ b/p2p/protocol/identify/id_test.go @@ -804,10 +804,10 @@ func TestLargePushMessage(t *testing.T) { } func TestIdentifyResponseReadTimeout(t *testing.T) { - timeout := identify.StreamReadTimeout - identify.StreamReadTimeout = 100 * time.Millisecond + timeout := identify.Timeout + identify.Timeout = 100 * time.Millisecond defer func() { - identify.StreamReadTimeout = timeout + identify.Timeout = timeout }() ctx, cancel := context.WithCancel(context.Background()) @@ -850,10 +850,10 @@ func TestIdentifyResponseReadTimeout(t *testing.T) { } func TestIncomingIDStreamsTimeout(t *testing.T) { - timeout := identify.StreamReadTimeout - identify.StreamReadTimeout = 100 * time.Millisecond + timeout := identify.Timeout + identify.Timeout = 100 * time.Millisecond defer func() { - identify.StreamReadTimeout = timeout + identify.Timeout = timeout }() ctx, cancel := context.WithCancel(context.Background()) diff --git a/version.json b/version.json index 83fde75bba..d5ae9b9c33 100644 --- a/version.json +++ b/version.json @@ -1,3 +1,3 @@ { - "version": "v0.27.6" + "version": "v0.27.7" }