diff --git a/p2p/protocol/identify/id.go b/p2p/protocol/identify/id.go index 83749778d2..66930a3b46 100644 --- a/p2p/protocol/identify/id.go +++ b/p2p/protocol/identify/id.go @@ -50,6 +50,8 @@ type IDService struct { currid map[inet.Conn]chan struct{} currmu sync.RWMutex + addrMu sync.Mutex + // our own observed addresses. // TODO: instead of expiring, remove these when we disconnect observedAddrs ObservedAddrSet @@ -63,6 +65,7 @@ func NewIDService(h host.Host) *IDService { currid: make(map[inet.Conn]chan struct{}), } h.SetStreamHandler(ID, s.RequestHandler) + h.Network().Notify((*netNotifiee)(s)) return s } @@ -220,9 +223,17 @@ func (ids *IDService) consumeMessage(mes *pb.Identify, c inet.Conn) { lmaddrs = append(lmaddrs, c.RemoteMultiaddr()) } - // update our peerstore with the addresses. here, we SET the addresses, clearing old ones. - // We are receiving from the peer itself. this is current address ground truth. - ids.Host.Peerstore().SetAddrs(p, lmaddrs, pstore.ConnectedAddrTTL) + // Extend the TTLs on the known (probably) good addresses. + // Taking the lock ensures that we don't concurrently process a disconnect. + ids.addrMu.Lock() + switch ids.Host.Network().Connectedness(p) { + case inet.Connected: + ids.Host.Peerstore().AddAddrs(p, lmaddrs, pstore.ConnectedAddrTTL) + default: + ids.Host.Peerstore().AddAddrs(p, lmaddrs, pstore.RecentlyConnectedAddrTTL) + } + ids.addrMu.Unlock() + log.Debugf("%s received listen addrs for %s: %s", c.LocalPeer(), c.RemotePeer(), lmaddrs) // get protocol versions @@ -449,9 +460,14 @@ func (nn *netNotifiee) Connected(n inet.Network, v inet.Conn) { func (nn *netNotifiee) Disconnected(n inet.Network, v inet.Conn) { // undo the setting of addresses to peer.ConnectedAddrTTL we did ids := nn.IDService() - ps := ids.Host.Peerstore() - addrs := ps.Addrs(v.RemotePeer()) - ps.SetAddrs(v.RemotePeer(), addrs, pstore.RecentlyConnectedAddrTTL) + ids.addrMu.Lock() + defer ids.addrMu.Unlock() + + if ids.Host.Network().Connectedness(v.RemotePeer()) != inet.Connected { + // Last disconnect. + ps := ids.Host.Peerstore() + ps.UpdateAddrs(v.RemotePeer(), pstore.ConnectedAddrTTL, pstore.RecentlyConnectedAddrTTL) + } } func (nn *netNotifiee) OpenedStream(n inet.Network, v inet.Stream) {} diff --git a/p2p/protocol/identify/id_test.go b/p2p/protocol/identify/id_test.go index d46340bf10..ed2a1beced 100644 --- a/p2p/protocol/identify/id_test.go +++ b/p2p/protocol/identify/id_test.go @@ -8,6 +8,7 @@ import ( ic "github.com/libp2p/go-libp2p-crypto" testutil "github.com/libp2p/go-libp2p-netutil" peer "github.com/libp2p/go-libp2p-peer" + pstore "github.com/libp2p/go-libp2p-peerstore" identify "github.com/libp2p/go-libp2p/p2p/protocol/identify" blhost "github.com/libp2p/go-libp2p-blankhost" @@ -15,9 +16,10 @@ import ( ma "github.com/multiformats/go-multiaddr" ) -func subtestIDService(t *testing.T, postDialWait time.Duration) { +func subtestIDService(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - ctx := context.Background() h1 := blhost.NewBlankHost(testutil.GenSwarmNetwork(t, ctx)) h2 := blhost.NewBlankHost(testutil.GenSwarmNetwork(t, ctx)) @@ -30,6 +32,11 @@ func subtestIDService(t *testing.T, postDialWait time.Duration) { testKnowsAddrs(t, h1, h2p, []ma.Multiaddr{}) // nothing testKnowsAddrs(t, h2, h1p, []ma.Multiaddr{}) // nothing + forgetMe, _ := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/1234") + + h2.Peerstore().AddAddr(h1p, forgetMe, pstore.RecentlyConnectedAddrTTL) + time.Sleep(50 * time.Millisecond) + h2pi := h2.Peerstore().PeerInfo(h2p) if err := h1.Connect(ctx, h2pi); err != nil { t.Fatal(err) @@ -58,16 +65,39 @@ func subtestIDService(t *testing.T, postDialWait time.Duration) { ids2.IdentifyConn(c[0]) addrs := h1.Peerstore().Addrs(h1p) - addrs = append(addrs, c[0].RemoteMultiaddr()) + addrs = append(addrs, c[0].RemoteMultiaddr(), forgetMe) // and the protocol versions. t.Log("test peer2 has peer1 addrs correctly") testKnowsAddrs(t, h2, h1p, addrs) // has them testHasProtocolVersions(t, h2, h1p) testHasPublicKey(t, h2, h1p, h1.Peerstore().PubKey(h1p)) // h1 should have h2's public key + + // Need both sides to actually notice that the connection has been closed. + h1.Network().ClosePeer(h2p) + h2.Network().ClosePeer(h1p) + if len(h2.Network().ConnsToPeer(h1.ID())) != 0 || len(h1.Network().ConnsToPeer(h2.ID())) != 0 { + t.Fatal("should have no connections") + } + + testKnowsAddrs(t, h2, h1p, addrs) + testKnowsAddrs(t, h1, h2p, h2.Peerstore().Addrs(h2p)) + + time.Sleep(50 * time.Millisecond) + + // Forget the first one. + testKnowsAddrs(t, h2, h1p, addrs[:len(addrs)-1]) + + time.Sleep(50 * time.Millisecond) + + // Forget the rest. + testKnowsAddrs(t, h1, h2p, []ma.Multiaddr{}) + testKnowsAddrs(t, h2, h1p, []ma.Multiaddr{}) } func testKnowsAddrs(t *testing.T, h host.Host, p peer.ID, expected []ma.Multiaddr) { + t.Helper() + actual := h.Peerstore().Addrs(p) if len(actual) != len(expected) { @@ -125,17 +155,16 @@ func testHasPublicKey(t *testing.T, h host.Host, p peer.ID, shouldBe ic.PubKey) // TestIDServiceWait gives the ID service 100ms to finish after dialing // this is becasue it used to be concurrent. Now, Dial wait till the // id service is done. -func TestIDServiceWait(t *testing.T) { - N := 3 - for i := 0; i < N; i++ { - subtestIDService(t, 100*time.Millisecond) - } -} +func TestIDService(t *testing.T) { + oldTTL := pstore.RecentlyConnectedAddrTTL + pstore.RecentlyConnectedAddrTTL = 100 * time.Millisecond + defer func() { + pstore.RecentlyConnectedAddrTTL = oldTTL + }() -func TestIDServiceNoWait(t *testing.T) { N := 3 for i := 0; i < N; i++ { - subtestIDService(t, 0) + subtestIDService(t) } }