diff --git a/dot/network/notifications.go b/dot/network/notifications.go index 6ef0031311..1cb102d0c5 100644 --- a/dot/network/notifications.go +++ b/dot/network/notifications.go @@ -19,6 +19,7 @@ package network import ( "errors" "sync" + "time" "unsafe" libp2pnetwork "github.com/libp2p/go-libp2p-core/network" @@ -29,6 +30,7 @@ import ( var errCannotValidateHandshake = errors.New("failed to validate handshake") const maxHandshakeSize = unsafe.Sizeof(BlockAnnounceHandshake{}) //nolint +const handshakeTimeout = time.Second * 10 // Handshake is the interface all handshakes for notifications protocols must implement type Handshake interface { @@ -53,6 +55,11 @@ type ( NotificationsMessageHandler = func(peer peer.ID, msg NotificationsMessage) (propagate bool, err error) ) +type handshakeReader struct { + hs Handshake + err error +} + type notificationsProtocol struct { protocolID protocol.ID getHandshake HandshakeGetter @@ -63,16 +70,17 @@ type notificationsProtocol struct { } func (n *notificationsProtocol) getHandshakeData(pid peer.ID, inbound bool) (handshakeData, bool) { - if inbound { - data, has := n.inboundHandshakeData.Load(pid) - if !has { - return handshakeData{}, false - } + var ( + data interface{} + has bool + ) - return data.(handshakeData), true + if inbound { + data, has = n.inboundHandshakeData.Load(pid) + } else { + data, has = n.outboundHandshakeData.Load(pid) } - data, has := n.outboundHandshakeData.Load(pid) if !has { return handshakeData{}, false } @@ -174,7 +182,7 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, return nil } - logger.Debug("received message on notifications sub-protocol", "protocol", info.protocolID, + logger.Trace("received message on notifications sub-protocol", "protocol", info.protocolID, "message", msg, "peer", stream.Conn().RemotePeer(), ) @@ -231,14 +239,29 @@ func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtoc return } - hs, err := s.readHandshake(stream, decodeBlockAnnounceHandshake) - if err != nil { - logger.Trace("failed to read handshake", "protocol", info.protocolID, "peer", peer, "error", err) + hsTimer := time.NewTimer(handshakeTimeout) + + var hs Handshake + select { + case <-hsTimer.C: + logger.Trace("handshake timeout reached", "protocol", info.protocolID, "peer", peer) _ = stream.Close() + info.outboundHandshakeData.Delete(peer) return - } - hsData.received = true + case hsResponse := <-s.readHandshake(stream, decodeBlockAnnounceHandshake): + hsTimer.Stop() + if hsResponse.err != nil { + logger.Trace("failed to read handshake", "protocol", info.protocolID, "peer", peer, "error", err) + _ = stream.Close() + + info.outboundHandshakeData.Delete(peer) + return + } + + hs = hsResponse.hs + hsData.received = true + } err = info.handshakeValidator(peer, hs) if err != nil { @@ -299,19 +322,30 @@ func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer } } -func (s *Service) readHandshake(stream libp2pnetwork.Stream, decoder HandshakeDecoder) (Handshake, error) { - msgBytes := s.bufPool.get() - defer s.bufPool.put(&msgBytes) +func (s *Service) readHandshake(stream libp2pnetwork.Stream, decoder HandshakeDecoder) <-chan *handshakeReader { + hsC := make(chan *handshakeReader) - tot, err := readStream(stream, msgBytes[:]) - if err != nil { - return nil, err - } + go func() { + msgBytes := s.bufPool.get() + defer func() { + s.bufPool.put(&msgBytes) + close(hsC) + }() - hs, err := decoder(msgBytes[:tot]) - if err != nil { - return nil, err - } + tot, err := readStream(stream, msgBytes[:]) + if err != nil { + hsC <- &handshakeReader{hs: nil, err: err} + return + } + + hs, err := decoder(msgBytes[:tot]) + if err != nil { + hsC <- &handshakeReader{hs: nil, err: err} + return + } + + hsC <- &handshakeReader{hs: hs, err: nil} + }() - return hs, nil + return hsC } diff --git a/dot/network/notifications_test.go b/dot/network/notifications_test.go index 83926edaef..3662943614 100644 --- a/dot/network/notifications_test.go +++ b/dot/network/notifications_test.go @@ -17,6 +17,8 @@ package network import ( + "context" + "fmt" "math/big" "sync" "testing" @@ -25,7 +27,10 @@ import ( "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/utils" + ma "github.com/multiformats/go-multiaddr" + "github.com/libp2p/go-libp2p" + libp2pnetwork "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/require" ) @@ -240,3 +245,80 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T) require.True(t, data.received) require.True(t, data.validated) } + +func Test_HandshakeTimeout(t *testing.T) { + // create service A + config := &Config{ + BasePath: utils.NewTestBasePath(t, "nodeA"), + Port: 7001, + RandSeed: 1, + NoBootstrap: true, + NoMDNS: true, + } + ha := createTestService(t, config) + + // create info and handler + info := ¬ificationsProtocol{ + protocolID: ha.host.protocolID + blockAnnounceID, + getHandshake: ha.getBlockAnnounceHandshake, + handshakeValidator: ha.validateBlockAnnounceHandshake, + inboundHandshakeData: new(sync.Map), + outboundHandshakeData: new(sync.Map), + } + + // creating host b with will never respond to a handshake + addrB, err := ma.NewMultiaddr(fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", 7002)) + require.NoError(t, err) + + hb, err := libp2p.New( + context.Background(), libp2p.ListenAddrs(addrB), + ) + require.NoError(t, err) + + testHandshakeMsg := &BlockAnnounceHandshake{ + Roles: 4, + BestBlockNumber: 77, + BestBlockHash: common.Hash{1}, + GenesisHash: common.Hash{2}, + } + + hb.SetStreamHandler(info.protocolID, func(stream libp2pnetwork.Stream) { + fmt.Println("never respond a handshake message") + }) + + addrBInfo := peer.AddrInfo{ + ID: hb.ID(), + Addrs: hb.Addrs(), + } + + err = ha.host.connect(addrBInfo) + if failedToDial(err) { + time.Sleep(TestBackoffTimeout) + err = ha.host.connect(addrBInfo) + } + require.NoError(t, err) + + go ha.sendData(hb.ID(), testHandshakeMsg, info, nil) + + time.Sleep(handshakeTimeout / 2) + // peer should be stored in handshake data until timeout + _, ok := info.outboundHandshakeData.Load(hb.ID()) + require.True(t, ok) + + // a stream should be open until timeout + connAToB := ha.host.h.Network().ConnsToPeer(hb.ID()) + require.Len(t, connAToB, 1) + require.Len(t, connAToB[0].GetStreams(), 1) + + // after the timeout + time.Sleep(handshakeTimeout) + + // handshake data should be removed + _, ok = info.outboundHandshakeData.Load(hb.ID()) + require.False(t, ok) + + // stream should be closed + connAToB = ha.host.h.Network().ConnsToPeer(hb.ID()) + require.Len(t, connAToB, 1) + require.Len(t, connAToB[0].GetStreams(), 0) +} diff --git a/dot/network/utils.go b/dot/network/utils.go index 211f247a8f..67f60ddf5d 100644 --- a/dot/network/utils.go +++ b/dot/network/utils.go @@ -216,7 +216,7 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) { } if tot != int(length) { - return tot, fmt.Errorf("failed to read entire message: expected %d bytes", length) + return tot, fmt.Errorf("failed to read entire message: expected %d bytes, received %d bytes", length, tot) } return tot, nil