diff --git a/dot/network/block_announce.go b/dot/network/block_announce.go index 6d22b42b59..c18be68b2a 100644 --- a/dot/network/block_announce.go +++ b/dot/network/block_announce.go @@ -222,7 +222,7 @@ func (s *Service) validateBlockAnnounceHandshake(peer peer.ID, hs Handshake) err // `createNotificationsMessageHandler` which locks the map beforehand. data, ok := np.getHandshakeData(peer) if !ok { - np.handshakeData.Store(peer, &handshakeData{ + np.handshakeData.Store(peer, handshakeData{ received: true, validated: true, }) @@ -230,6 +230,7 @@ func (s *Service) validateBlockAnnounceHandshake(peer peer.ID, hs Handshake) err } data.handshake = hs + np.handshakeData.Store(peer, data) // if peer has higher best block than us, begin syncing latestHeader, err := s.blockState.BestBlockHeader() diff --git a/dot/network/block_announce_test.go b/dot/network/block_announce_test.go index ea381dad6d..79c01f2b4c 100644 --- a/dot/network/block_announce_test.go +++ b/dot/network/block_announce_test.go @@ -120,7 +120,7 @@ func TestValidateBlockAnnounceHandshake(t *testing.T) { handshakeData: new(sync.Map), } testPeerID := peer.ID("noot") - nodeA.notificationsProtocols[BlockAnnounceMsgType].handshakeData.Store(testPeerID, &handshakeData{}) + nodeA.notificationsProtocols[BlockAnnounceMsgType].handshakeData.Store(testPeerID, handshakeData{}) err := nodeA.validateBlockAnnounceHandshake(testPeerID, &BlockAnnounceHandshake{ BestBlockNumber: 100, diff --git a/dot/network/gossip_test.go b/dot/network/gossip_test.go index 73fe5f55c0..94c7b0669e 100644 --- a/dot/network/gossip_test.go +++ b/dot/network/gossip_test.go @@ -101,7 +101,7 @@ func TestGossip(t *testing.T) { } require.NoError(t, err) - err = nodeA.host.send(addrInfosB[0].ID, "", testBlockAnnounceMessage) + _, err = nodeA.host.send(addrInfosB[0].ID, "", testBlockAnnounceMessage) require.NoError(t, err) time.Sleep(TestMessageTimeout) diff --git a/dot/network/host.go b/dot/network/host.go index 89728ff907..d3920f7049 100644 --- a/dot/network/host.go +++ b/dot/network/host.go @@ -262,32 +262,26 @@ func (h *host) bootstrap() { } } -// send writes the given message to the outbound message stream for the given -// peer (gets the already opened outbound message stream or opens a new one). -func (h *host) send(p peer.ID, pid protocol.ID, msg Message) (err error) { - // get outbound stream for given peer - s := h.getOutboundStream(p, pid) - - // check if stream needs to be opened - if s == nil { - // open outbound stream with host protocol id - s, err = h.h.NewStream(h.ctx, p, pid) - if err != nil { - logger.Trace("failed to open new stream with peer", "peer", p, "protocol", pid, "error", err) - return err - } - - logger.Trace( - "Opened stream", - "host", h.id(), - "peer", p, - "protocol", pid, - ) +// send creates a new outbound stream with the given peer and writes the message. It also returns +// the newly created stream. +func (h *host) send(p peer.ID, pid protocol.ID, msg Message) (libp2pnetwork.Stream, error) { + // open outbound stream with host protocol id + stream, err := h.h.NewStream(h.ctx, p, pid) + if err != nil { + logger.Trace("failed to open new stream with peer", "peer", p, "protocol", pid, "error", err) + return nil, err } - err = h.writeToStream(s, msg) + logger.Trace( + "Opened stream", + "host", h.id(), + "peer", p, + "protocol", pid, + ) + + err = h.writeToStream(stream, msg) if err != nil { - return err + return nil, err } logger.Trace( @@ -298,7 +292,7 @@ func (h *host) send(p peer.ID, pid protocol.ID, msg Message) (err error) { "message", msg.String(), ) - return nil + return stream, nil } func (h *host) writeToStream(s libp2pnetwork.Stream, msg Message) error { diff --git a/dot/network/host_test.go b/dot/network/host_test.go index 1bd3aec938..fd32789817 100644 --- a/dot/network/host_test.go +++ b/dot/network/host_test.go @@ -218,7 +218,7 @@ func TestSend(t *testing.T) { } require.NoError(t, err) - err = nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage) + _, err = nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage) require.NoError(t, err) time.Sleep(TestMessageTimeout) @@ -273,44 +273,29 @@ func TestExistingStream(t *testing.T) { } require.NoError(t, err) - stream := nodeA.host.getOutboundStream(nodeB.host.id(), nodeB.host.protocolID) - require.Nil(t, stream, "node A should not have an outbound stream") - // node A opens the stream to send the first message - err = nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage) + stream, err := nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage) require.NoError(t, err) time.Sleep(TestMessageTimeout) require.NotNil(t, handlerB.messages[nodeA.host.id()], "node B timeout waiting for message from node A") - stream = nodeA.host.getOutboundStream(nodeB.host.id(), nodeB.host.protocolID) - require.NotNil(t, stream, "node A should have an outbound stream") - // node A uses the stream to send a second message - err = nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage) + err = nodeA.host.writeToStream(stream, testBlockRequestMessage) require.NoError(t, err) require.NotNil(t, handlerB.messages[nodeA.host.id()], "node B timeout waiting for message from node A") - stream = nodeA.host.getOutboundStream(nodeB.host.id(), nodeB.host.protocolID) - require.NotNil(t, stream, "node B should have an outbound stream") - // node B opens the stream to send the first message - err = nodeB.host.send(addrInfosA[0].ID, nodeB.host.protocolID, testBlockRequestMessage) + stream, err = nodeB.host.send(addrInfosA[0].ID, nodeB.host.protocolID, testBlockRequestMessage) require.NoError(t, err) time.Sleep(TestMessageTimeout) require.NotNil(t, handlerA.messages[nodeB.host.id()], "node A timeout waiting for message from node B") - stream = nodeB.host.getOutboundStream(nodeA.host.id(), nodeB.host.protocolID) - require.NotNil(t, stream, "node B should have an outbound stream") - // node B uses the stream to send a second message - err = nodeB.host.send(addrInfosA[0].ID, nodeB.host.protocolID, testBlockRequestMessage) + err = nodeB.host.writeToStream(stream, testBlockRequestMessage) require.NoError(t, err) require.NotNil(t, handlerA.messages[nodeB.host.id()], "node A timeout waiting for message from node B") - - stream = nodeB.host.getOutboundStream(nodeA.host.id(), nodeB.host.protocolID) - require.NotNil(t, stream, "node B should have an outbound stream") } func TestStreamCloseMetadataCleanup(t *testing.T) { @@ -361,13 +346,13 @@ func TestStreamCloseMetadataCleanup(t *testing.T) { } // node A opens the stream to send the first message - err = nodeA.host.send(nodeB.host.id(), nodeB.host.protocolID+blockAnnounceID, testHandshake) + _, err = nodeA.host.send(nodeB.host.id(), nodeB.host.protocolID+blockAnnounceID, testHandshake) require.NoError(t, err) info := nodeA.notificationsProtocols[BlockAnnounceMsgType] // Set handshake data to received - info.handshakeData.Store(nodeB.host.id(), &handshakeData{ + info.handshakeData.Store(nodeB.host.id(), handshakeData{ received: true, validated: true, }) diff --git a/dot/network/notifications.go b/dot/network/notifications.go index 44ff1eddcb..1ed7654b83 100644 --- a/dot/network/notifications.go +++ b/dot/network/notifications.go @@ -49,22 +49,25 @@ type ( // NotificationsMessageHandler is called when a (non-handshake) message is received over a notifications stream. NotificationsMessageHandler = func(peer peer.ID, msg NotificationsMessage) error + + streamHandler = func(libp2pnetwork.Stream, peer.ID) ) type notificationsProtocol struct { protocolID protocol.ID getHandshake HandshakeGetter handshakeData *sync.Map //map[peer.ID]*handshakeData + streamHandler streamHandler mapMu sync.RWMutex } -func (n *notificationsProtocol) getHandshakeData(pid peer.ID) (*handshakeData, bool) { +func (n *notificationsProtocol) getHandshakeData(pid peer.ID) (handshakeData, bool) { data, has := n.handshakeData.Load(pid) if !has { - return nil, false + return handshakeData{}, false } - return data.(*handshakeData), true + return data.(handshakeData), true } type handshakeData struct { @@ -72,6 +75,7 @@ type handshakeData struct { validated bool handshake Handshake outboundMsg NotificationsMessage + stream libp2pnetwork.Stream } func createDecoder(info *notificationsProtocol, handshakeDecoder HandshakeDecoder, messageDecoder MessageDecoder) messageDecoder { @@ -123,10 +127,12 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, // if we are the receiver and haven't received the handshake already, validate it if _, has := info.getHandshakeData(peer); !has { logger.Trace("receiver: validating handshake", "protocol", info.protocolID) - info.handshakeData.Store(peer, &handshakeData{ + hsData := handshakeData{ validated: false, received: true, - }) + stream: stream, + } + info.handshakeData.Store(peer, hsData) err := handshakeValidator(peer, hs) if err != nil { @@ -134,8 +140,8 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, return errCannotValidateHandshake } - data, _ := info.getHandshakeData(peer) - data.validated = true + hsData.validated = true + info.handshakeData.Store(peer, hsData) // once validated, send back a handshake resp, err := info.getHandshake() @@ -144,7 +150,7 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, return err } - err = s.host.writeToStream(stream, resp) + err = s.host.writeToStream(hsData.stream, resp) if err != nil { logger.Trace("failed to send handshake", "protocol", info.protocolID, "peer", peer, "error", err) return err @@ -160,20 +166,21 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, if err != nil { logger.Trace("failed to validate handshake", "protocol", info.protocolID, "peer", peer, "error", err) hsData.validated = false + info.handshakeData.Store(peer, hsData) return errCannotValidateHandshake } hsData.validated = true hsData.received = true + info.handshakeData.Store(peer, hsData) + logger.Trace("sender: validated handshake", "protocol", info.protocolID, "peer", peer) - } else if hsData.received { - return nil } // if we are the initiator, send the message if hsData, has := info.getHandshakeData(peer); has && hsData.validated && hsData.received && hsData.outboundMsg != nil { logger.Trace("sender: sending message", "protocol", info.protocolID) - err := s.host.writeToStream(stream, hsData.outboundMsg) + err := s.host.writeToStream(hsData.stream, hsData.outboundMsg) if err != nil { logger.Debug("failed to send message", "protocol", info.protocolID, "peer", peer, "error", err) return err @@ -209,6 +216,61 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, } } +func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtocol, msg NotificationsMessage) { + hsData, has := info.getHandshakeData(peer) + if !has || !hsData.received { + hsData = handshakeData{ + validated: false, + received: false, + outboundMsg: msg, + } + + info.handshakeData.Store(peer, hsData) + logger.Trace("sending handshake", "protocol", info.protocolID, "peer", peer, "message", hs) + + stream, err := s.host.send(peer, info.protocolID, hs) + if err != nil { + logger.Trace("failed to send message to peer", "peer", peer, "error", err) + return + } + + hsData.stream = stream + info.handshakeData.Store(peer, hsData) + + if info.streamHandler == nil { + return + } + + go info.streamHandler(stream, peer) + return + } + + if s.host.messageCache != nil { + added, err := s.host.messageCache.put(peer, msg) + if err != nil { + logger.Error("failed to add message to cache", "peer", peer, "error", err) + return + } + + if !added { + return + } + } + + if hsData.stream == nil { + logger.Error("trying to send data through empty stream", "protocol", info.protocolID, "peer", peer, "message", msg) + return + } + + // we've already completed the handshake with the peer, send message directly + logger.Trace("sending message", "protocol", info.protocolID, "peer", peer, "message", msg) + + err := s.host.writeToStream(hsData.stream, msg) + if err != nil { + logger.Trace("failed to send message to peer", "peer", peer, "error", err) + } +} + // gossipExcluding sends a message to each connected peer except the given peer // Used for notifications sub-protocols to gossip a message func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer.ID, msg NotificationsMessage) { @@ -234,35 +296,6 @@ func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer continue } - if hsData, has := info.getHandshakeData(peer); !has || !hsData.received { - info.handshakeData.Store(peer, &handshakeData{ - validated: false, - outboundMsg: msg, - }) - - logger.Trace("sending handshake", "protocol", info.protocolID, "peer", peer, "message", hs) - err = s.host.send(peer, info.protocolID, hs) - } else { - if s.host.messageCache != nil { - var added bool - added, err = s.host.messageCache.put(peer, msg) - if err != nil { - logger.Error("failed to add message to cache", "peer", peer, "error", err) - continue - } - - if !added { - continue - } - } - - // we've already completed the handshake with the peer, send message directly - logger.Trace("sending message", "protocol", info.protocolID, "peer", peer, "message", msg) - err = s.host.send(peer, info.protocolID, msg) - } - - if err != nil { - logger.Debug("failed to send message to peer", "peer", peer, "error", err) - } + go s.sendData(peer, hs, info, msg) } } diff --git a/dot/network/notifications_test.go b/dot/network/notifications_test.go index f0f65f793e..0a4b8c1dc4 100644 --- a/dot/network/notifications_test.go +++ b/dot/network/notifications_test.go @@ -53,7 +53,7 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) { // haven't received handshake from peer testPeerID := peer.ID("QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ") - info.handshakeData.Store(testPeerID, &handshakeData{ + info.handshakeData.Store(testPeerID, handshakeData{ received: false, }) @@ -85,6 +85,7 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) { // set handshake data to received hsData, _ := info.getHandshakeData(testPeerID) hsData.received = true + info.handshakeData.Store(testPeerID, hsData) msg, err = decoder(enc, testPeerID) require.NoError(t, err) require.Equal(t, testBlockAnnounce, msg) @@ -139,7 +140,7 @@ func TestCreateNotificationsMessageHandler_BlockAnnounce(t *testing.T) { handler := s.createNotificationsMessageHandler(info, s.validateBlockAnnounceHandshake, s.handleBlockAnnounceMessage) // set handshake data to received - info.handshakeData.Store(testPeerID, &handshakeData{ + info.handshakeData.Store(testPeerID, handshakeData{ received: true, validated: true, }) diff --git a/dot/network/service.go b/dot/network/service.go index cbcb4623de..747193bace 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -398,6 +398,10 @@ func (s *Service) RegisterNotificationsProtocol(sub protocol.ID, decoder := createDecoder(info, handshakeDecoder, messageDecoder) handlerWithValidate := s.createNotificationsMessageHandler(info, handshakeValidator, messageHandler) + streamHandler := func(stream libp2pnetwork.Stream, peerID peer.ID) { + s.readStream(stream, peerID, decoder, handlerWithValidate) + } + np.streamHandler = streamHandler s.host.registerStreamHandlerWithOverwrite(sub, overwriteProtocol, func(stream libp2pnetwork.Stream) { logger.Trace("received stream", "sub-protocol", sub) @@ -408,7 +412,7 @@ func (s *Service) RegisterNotificationsProtocol(sub protocol.ID, } p := conn.RemotePeer() - s.readStream(stream, p, decoder, handlerWithValidate) + streamHandler(stream, p) }) logger.Info("registered notifications sub-protocol", "protocol", protocolID) @@ -586,7 +590,7 @@ func (s *Service) NetworkState() common.NetworkState { // Peers returns information about connected peers needed for the rpc server func (s *Service) Peers() []common.PeerInfo { - peers := []common.PeerInfo{} + var peers []common.PeerInfo s.notificationsMu.RLock() np := s.notificationsProtocols[BlockAnnounceMsgType] diff --git a/dot/network/service_test.go b/dot/network/service_test.go index ba86435927..31e88f37a1 100644 --- a/dot/network/service_test.go +++ b/dot/network/service_test.go @@ -17,6 +17,7 @@ package network import ( + "context" "fmt" "os" "strings" @@ -196,12 +197,6 @@ func TestBroadcastDuplicateMessage(t *testing.T) { addrInfosB, err := nodeB.host.addrInfos() require.NoError(t, err) - protocol := nodeA.notificationsProtocols[BlockAnnounceMsgType] - protocol.handshakeData.Store(nodeB.host.id(), &handshakeData{ - received: true, - validated: true, - }) - err = nodeA.host.connect(*addrInfosB[0]) // retry connect if "failed to dial" error if failedToDial(err) { @@ -210,9 +205,21 @@ func TestBroadcastDuplicateMessage(t *testing.T) { } require.NoError(t, err) + stream, err := nodeA.host.h.NewStream(context.Background(), nodeB.host.id(), nodeB.host.protocolID+blockAnnounceID) + require.NoError(t, err) + require.NotNil(t, stream) + + protocol := nodeA.notificationsProtocols[BlockAnnounceMsgType] + protocol.handshakeData.Store(nodeB.host.id(), handshakeData{ + received: true, + validated: true, + stream: stream, + }) + // Only one message will be sent. for i := 0; i < 5; i++ { nodeA.SendMessage(testBlockAnnounceMessage) + time.Sleep(time.Millisecond * 10) } time.Sleep(time.Millisecond * 200) @@ -223,10 +230,8 @@ func TestBroadcastDuplicateMessage(t *testing.T) { // All 5 message will be sent since cache is disabled. for i := 0; i < 5; i++ { nodeA.SendMessage(testBlockAnnounceMessage) - require.NoError(t, err) + time.Sleep(time.Millisecond * 10) } - - time.Sleep(time.Millisecond * 200) require.Equal(t, 6, len(handler.messages[nodeA.host.id()])) } diff --git a/dot/network/test_helpers.go b/dot/network/test_helpers.go index 06c97ef3aa..d185ab7dd8 100644 --- a/dot/network/test_helpers.go +++ b/dot/network/test_helpers.go @@ -99,7 +99,21 @@ func (s *testStreamHandler) handleStream(stream libp2pnetwork.Stream) { func (s *testStreamHandler) handleMessage(stream libp2pnetwork.Stream, msg Message) error { msgs := s.messages[stream.Conn().RemotePeer()] s.messages[stream.Conn().RemotePeer()] = append(msgs, msg) - return nil + return s.writeToStream(stream, testBlockAnnounceHandshake) +} + +func (s *testStreamHandler) writeToStream(stream libp2pnetwork.Stream, msg Message) error { + encMsg, err := msg.Encode() + if err != nil { + return err + } + + msgLen := uint64(len(encMsg)) + lenBytes := uint64ToLEB128(msgLen) + encMsg = append(lenBytes, encMsg...) + + _, err = stream.Write(encMsg) + return err } func (s *testStreamHandler) readStream(stream libp2pnetwork.Stream, peer peer.ID, decoder messageDecoder, handler messageHandler) { @@ -155,6 +169,10 @@ var testBlockAnnounceMessage = &BlockAnnounceMessage{ Number: big.NewInt(128 * 7), } +var testBlockAnnounceHandshake = &BlockAnnounceHandshake{ + BestBlockNumber: 0, +} + func testBlockAnnounceMessageDecoder(in []byte, _ peer.ID) (Message, error) { msg := new(BlockAnnounceMessage) err := msg.Decode(in) diff --git a/dot/network/utils.go b/dot/network/utils.go index f3bf385e7c..74935e3e47 100644 --- a/dot/network/utils.go +++ b/dot/network/utils.go @@ -136,7 +136,7 @@ func saveKey(priv crypto.PrivKey, fp string) (err error) { } func uint64ToLEB128(in uint64) []byte { - out := []byte{} + var out []byte for { b := uint8(in & 0x7f) in >>= 7 diff --git a/tests/stress/stress_test.go b/tests/stress/stress_test.go index 151a1ca616..2bf9be9d55 100644 --- a/tests/stress/stress_test.go +++ b/tests/stress/stress_test.go @@ -371,6 +371,18 @@ func TestPendingExtrinsic(t *testing.T) { node, err := utils.RunGossamer(t, numNodes-1, utils.TestDir(t, utils.KeyList[numNodes-1]), utils.GenesisDefault, utils.ConfigBABEMaxThreshold, false) require.NoError(t, err) + // Start rest of nodes + nodes, err := utils.InitializeAndStartNodes(t, numNodes-1, utils.GenesisDefault, utils.ConfigNoBABE) + require.NoError(t, err) + nodes = append(nodes, node) + + defer func() { + t.Log("going to tear down gossamer...") + os.Remove(utils.ConfigBABEMaxThreshold) + errList := utils.StopNodes(t, nodes) + require.Len(t, errList, 0) + }() + // send tx to non-authority node api, err := gsrpc.NewSubstrateAPI(fmt.Sprintf("http://localhost:%s", node.RPCPort)) require.NoError(t, err) @@ -426,18 +438,6 @@ func TestPendingExtrinsic(t *testing.T) { require.NoError(t, err) require.NotEqual(t, hash, common.Hash{}) - // Start rest of nodes - nodes, err := utils.InitializeAndStartNodes(t, numNodes-1, utils.GenesisDefault, utils.ConfigNoBABE) - require.NoError(t, err) - nodes = append(nodes, node) - - defer func() { - t.Log("going to tear down gossamer...") - os.Remove(utils.ConfigBABEMaxThreshold) - errList := utils.StopNodes(t, nodes) - require.Len(t, errList, 0) - }() - time.Sleep(time.Second * 10) // wait until there's no more pending extrinsics