diff --git a/dot/network/host.go b/dot/network/host.go index 125e7f4781..84cbd8bcc2 100644 --- a/dot/network/host.go +++ b/dot/network/host.go @@ -362,6 +362,17 @@ func (h *host) peers() []peer.ID { return h.h.Network().Peers() } +// supportsProtocol checks if the protocol is supported by peerID +// returns an error if could not get peer protocols +func (h *host) supportsProtocol(peerID peer.ID, protocol protocol.ID) (bool, error) { + peerProtocols, err := h.h.Peerstore().SupportsProtocols(peerID, string(protocol)) + if err != nil { + return false, err + } + + return len(peerProtocols) > 0, nil +} + // peerCount returns the number of connected peers func (h *host) peerCount() int { peers := h.h.Network().Peers() diff --git a/dot/network/host_test.go b/dot/network/host_test.go index 1d0b17e621..7330f71441 100644 --- a/dot/network/host_test.go +++ b/dot/network/host_test.go @@ -23,6 +23,7 @@ import ( "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/utils" + "github.com/libp2p/go-libp2p-core/protocol" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" @@ -366,3 +367,71 @@ func TestStreamCloseMetadataCleanup(t *testing.T) { _, ok = info.getHandshakeData(nodeB.host.id(), true) require.False(t, ok) } + +func Test_PeerSupportsProtocol(t *testing.T) { + basePathA := utils.NewTestBasePath(t, "nodeA") + configA := &Config{ + BasePath: basePathA, + Port: 7001, + RandSeed: 1, + NoBootstrap: true, + NoMDNS: true, + } + + nodeA := createTestService(t, configA) + + basePathB := utils.NewTestBasePath(t, "nodeB") + configB := &Config{ + BasePath: basePathB, + Port: 7002, + RandSeed: 2, + NoBootstrap: true, + NoMDNS: true, + } + + nodeB := createTestService(t, configB) + nodeB.noGossip = true + + addrInfosB, err := nodeB.host.addrInfos() + require.NoError(t, err) + + err = nodeA.host.connect(*addrInfosB[0]) + // retry connect if "failed to dial" error + if failedToDial(err) { + time.Sleep(TestBackoffTimeout) + err = nodeA.host.connect(*addrInfosB[0]) + } + require.NoError(t, err) + + tests := []struct { + protocol protocol.ID + expect bool + }{ + { + protocol: protocol.ID("/gossamer/test/0/sync/2"), + expect: true, + }, + { + protocol: protocol.ID("/gossamer/test/0/light/2"), + expect: true, + }, + { + protocol: protocol.ID("/gossamer/test/0/block-announces/1"), + expect: true, + }, + { + protocol: protocol.ID("/gossamer/test/0/transactions/1"), + expect: true, + }, + { + protocol: protocol.ID("/gossamer/not_supported/protocol"), + expect: false, + }, + } + + for _, test := range tests { + output, err := nodeA.host.supportsProtocol(nodeB.host.id(), test.protocol) + require.NoError(t, err) + require.Equal(t, test.expect, output) + } +} diff --git a/dot/network/notifications.go b/dot/network/notifications.go index c711534d1d..1cb102d0c5 100644 --- a/dot/network/notifications.go +++ b/dot/network/notifications.go @@ -206,6 +206,11 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, } func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtocol, msg NotificationsMessage) { + if support, err := s.host.supportsProtocol(peer, info.protocolID); err != nil || !support { + logger.Debug("the peer does not supports the protocol", "protocol", info.protocolID, "peer", peer, "err", err) + return + } + hsData, has := info.getHandshakeData(peer, false) if has && !hsData.validated { // peer has sent us an invalid handshake in the past, ignore