Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dot/network): Fix notification handshake and reuse stream. #1545

Merged
merged 8 commits into from
Apr 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dot/network/block_announce.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,15 @@ func (s *Service) validateBlockAnnounceHandshake(peer peer.ID, hs Handshake) err
// `createNotificationsMessageHandler` which locks the map beforehand.
data, ok := np.getHandshakeData(peer)
if !ok {
np.handshakeData.Store(peer, &handshakeData{
np.handshakeData.Store(peer, handshakeData{
received: true,
validated: true,
})
data, _ = np.getHandshakeData(peer)
}

data.handshake = hs
np.handshakeData.Store(peer, data)

// if peer has higher best block than us, begin syncing
latestHeader, err := s.blockState.BestBlockHeader()
Expand Down
2 changes: 1 addition & 1 deletion dot/network/block_announce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func TestValidateBlockAnnounceHandshake(t *testing.T) {
handshakeData: new(sync.Map),
}
testPeerID := peer.ID("noot")
nodeA.notificationsProtocols[BlockAnnounceMsgType].handshakeData.Store(testPeerID, &handshakeData{})
nodeA.notificationsProtocols[BlockAnnounceMsgType].handshakeData.Store(testPeerID, handshakeData{})

err := nodeA.validateBlockAnnounceHandshake(testPeerID, &BlockAnnounceHandshake{
BestBlockNumber: 100,
Expand Down
2 changes: 1 addition & 1 deletion dot/network/gossip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func TestGossip(t *testing.T) {
}
require.NoError(t, err)

err = nodeA.host.send(addrInfosB[0].ID, "", testBlockAnnounceMessage)
_, err = nodeA.host.send(addrInfosB[0].ID, "", testBlockAnnounceMessage)
require.NoError(t, err)

time.Sleep(TestMessageTimeout)
Expand Down
42 changes: 18 additions & 24 deletions dot/network/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,32 +262,26 @@ func (h *host) bootstrap() {
}
}

// send writes the given message to the outbound message stream for the given
// peer (gets the already opened outbound message stream or opens a new one).
func (h *host) send(p peer.ID, pid protocol.ID, msg Message) (err error) {
// get outbound stream for given peer
s := h.getOutboundStream(p, pid)
noot marked this conversation as resolved.
Show resolved Hide resolved

// check if stream needs to be opened
if s == nil {
// open outbound stream with host protocol id
s, err = h.h.NewStream(h.ctx, p, pid)
if err != nil {
logger.Trace("failed to open new stream with peer", "peer", p, "protocol", pid, "error", err)
return err
}

logger.Trace(
"Opened stream",
"host", h.id(),
"peer", p,
"protocol", pid,
)
// send creates a new outbound stream with the given peer and writes the message. It also returns
// the newly created stream.
func (h *host) send(p peer.ID, pid protocol.ID, msg Message) (libp2pnetwork.Stream, error) {
// open outbound stream with host protocol id
stream, err := h.h.NewStream(h.ctx, p, pid)
if err != nil {
logger.Trace("failed to open new stream with peer", "peer", p, "protocol", pid, "error", err)
return nil, err
}

err = h.writeToStream(s, msg)
logger.Trace(
"Opened stream",
"host", h.id(),
"peer", p,
"protocol", pid,
)

err = h.writeToStream(stream, msg)
if err != nil {
return err
return nil, err
}

logger.Trace(
Expand All @@ -298,7 +292,7 @@ func (h *host) send(p peer.ID, pid protocol.ID, msg Message) (err error) {
"message", msg.String(),
)

return nil
return stream, nil
}

func (h *host) writeToStream(s libp2pnetwork.Stream, msg Message) error {
Expand Down
29 changes: 7 additions & 22 deletions dot/network/host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func TestSend(t *testing.T) {
}
require.NoError(t, err)

err = nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
_, err = nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
require.NoError(t, err)

time.Sleep(TestMessageTimeout)
Expand Down Expand Up @@ -273,44 +273,29 @@ func TestExistingStream(t *testing.T) {
}
require.NoError(t, err)

stream := nodeA.host.getOutboundStream(nodeB.host.id(), nodeB.host.protocolID)
require.Nil(t, stream, "node A should not have an outbound stream")

// node A opens the stream to send the first message
err = nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
stream, err := nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
require.NoError(t, err)

time.Sleep(TestMessageTimeout)
require.NotNil(t, handlerB.messages[nodeA.host.id()], "node B timeout waiting for message from node A")

stream = nodeA.host.getOutboundStream(nodeB.host.id(), nodeB.host.protocolID)
require.NotNil(t, stream, "node A should have an outbound stream")

// node A uses the stream to send a second message
err = nodeA.host.send(addrInfosB[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
err = nodeA.host.writeToStream(stream, testBlockRequestMessage)
require.NoError(t, err)
require.NotNil(t, handlerB.messages[nodeA.host.id()], "node B timeout waiting for message from node A")

stream = nodeA.host.getOutboundStream(nodeB.host.id(), nodeB.host.protocolID)
require.NotNil(t, stream, "node B should have an outbound stream")

// node B opens the stream to send the first message
err = nodeB.host.send(addrInfosA[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
stream, err = nodeB.host.send(addrInfosA[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
require.NoError(t, err)

time.Sleep(TestMessageTimeout)
require.NotNil(t, handlerA.messages[nodeB.host.id()], "node A timeout waiting for message from node B")

stream = nodeB.host.getOutboundStream(nodeA.host.id(), nodeB.host.protocolID)
require.NotNil(t, stream, "node B should have an outbound stream")

// node B uses the stream to send a second message
err = nodeB.host.send(addrInfosA[0].ID, nodeB.host.protocolID, testBlockRequestMessage)
err = nodeB.host.writeToStream(stream, testBlockRequestMessage)
require.NoError(t, err)
require.NotNil(t, handlerA.messages[nodeB.host.id()], "node A timeout waiting for message from node B")

stream = nodeB.host.getOutboundStream(nodeA.host.id(), nodeB.host.protocolID)
require.NotNil(t, stream, "node B should have an outbound stream")
}

func TestStreamCloseMetadataCleanup(t *testing.T) {
Expand Down Expand Up @@ -361,13 +346,13 @@ func TestStreamCloseMetadataCleanup(t *testing.T) {
}

// node A opens the stream to send the first message
err = nodeA.host.send(nodeB.host.id(), nodeB.host.protocolID+blockAnnounceID, testHandshake)
_, err = nodeA.host.send(nodeB.host.id(), nodeB.host.protocolID+blockAnnounceID, testHandshake)
require.NoError(t, err)

info := nodeA.notificationsProtocols[BlockAnnounceMsgType]

// Set handshake data to received
info.handshakeData.Store(nodeB.host.id(), &handshakeData{
info.handshakeData.Store(nodeB.host.id(), handshakeData{
received: true,
validated: true,
})
Expand Down
115 changes: 74 additions & 41 deletions dot/network/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,29 +49,33 @@ type (

// NotificationsMessageHandler is called when a (non-handshake) message is received over a notifications stream.
NotificationsMessageHandler = func(peer peer.ID, msg NotificationsMessage) error

streamHandler = func(libp2pnetwork.Stream, peer.ID)
)

type notificationsProtocol struct {
protocolID protocol.ID
getHandshake HandshakeGetter
handshakeData *sync.Map //map[peer.ID]*handshakeData
streamHandler streamHandler
mapMu sync.RWMutex
}

func (n *notificationsProtocol) getHandshakeData(pid peer.ID) (*handshakeData, bool) {
func (n *notificationsProtocol) getHandshakeData(pid peer.ID) (handshakeData, bool) {
data, has := n.handshakeData.Load(pid)
if !has {
return nil, false
return handshakeData{}, false
}

return data.(*handshakeData), true
return data.(handshakeData), true
}

type handshakeData struct {
received bool
validated bool
handshake Handshake
outboundMsg NotificationsMessage
stream libp2pnetwork.Stream
}

func createDecoder(info *notificationsProtocol, handshakeDecoder HandshakeDecoder, messageDecoder MessageDecoder) messageDecoder {
Expand Down Expand Up @@ -123,19 +127,21 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
// if we are the receiver and haven't received the handshake already, validate it
if _, has := info.getHandshakeData(peer); !has {
logger.Trace("receiver: validating handshake", "protocol", info.protocolID)
info.handshakeData.Store(peer, &handshakeData{
hsData := handshakeData{
validated: false,
received: true,
})
stream: stream,
}
info.handshakeData.Store(peer, hsData)

err := handshakeValidator(peer, hs)
if err != nil {
logger.Trace("failed to validate handshake", "protocol", info.protocolID, "peer", peer, "error", err)
return errCannotValidateHandshake
}

data, _ := info.getHandshakeData(peer)
data.validated = true
hsData.validated = true
info.handshakeData.Store(peer, hsData)

// once validated, send back a handshake
resp, err := info.getHandshake()
Expand All @@ -144,7 +150,7 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
return err
}

err = s.host.writeToStream(stream, resp)
err = s.host.writeToStream(hsData.stream, resp)
if err != nil {
logger.Trace("failed to send handshake", "protocol", info.protocolID, "peer", peer, "error", err)
return err
Expand All @@ -160,20 +166,21 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
if err != nil {
logger.Trace("failed to validate handshake", "protocol", info.protocolID, "peer", peer, "error", err)
hsData.validated = false
info.handshakeData.Store(peer, hsData)
return errCannotValidateHandshake
}

hsData.validated = true
hsData.received = true
info.handshakeData.Store(peer, hsData)

logger.Trace("sender: validated handshake", "protocol", info.protocolID, "peer", peer)
} else if hsData.received {
return nil
}

// if we are the initiator, send the message
if hsData, has := info.getHandshakeData(peer); has && hsData.validated && hsData.received && hsData.outboundMsg != nil {
logger.Trace("sender: sending message", "protocol", info.protocolID)
err := s.host.writeToStream(stream, hsData.outboundMsg)
err := s.host.writeToStream(hsData.stream, hsData.outboundMsg)
if err != nil {
logger.Debug("failed to send message", "protocol", info.protocolID, "peer", peer, "error", err)
return err
Expand Down Expand Up @@ -209,6 +216,61 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
}
}

func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtocol, msg NotificationsMessage) {
hsData, has := info.getHandshakeData(peer)
if !has || !hsData.received {
hsData = handshakeData{
validated: false,
received: false,
outboundMsg: msg,
}

info.handshakeData.Store(peer, hsData)
logger.Trace("sending handshake", "protocol", info.protocolID, "peer", peer, "message", hs)

stream, err := s.host.send(peer, info.protocolID, hs)
if err != nil {
logger.Trace("failed to send message to peer", "peer", peer, "error", err)
return
}

hsData.stream = stream
info.handshakeData.Store(peer, hsData)

if info.streamHandler == nil {
return
}

go info.streamHandler(stream, peer)
return
}

if s.host.messageCache != nil {
added, err := s.host.messageCache.put(peer, msg)
if err != nil {
logger.Error("failed to add message to cache", "peer", peer, "error", err)
return
}

if !added {
return
}
}

if hsData.stream == nil {
logger.Error("trying to send data through empty stream", "protocol", info.protocolID, "peer", peer, "message", msg)
return
}

// we've already completed the handshake with the peer, send message directly
logger.Trace("sending message", "protocol", info.protocolID, "peer", peer, "message", msg)

err := s.host.writeToStream(hsData.stream, msg)
if err != nil {
logger.Trace("failed to send message to peer", "peer", peer, "error", err)
}
}

// gossipExcluding sends a message to each connected peer except the given peer
// Used for notifications sub-protocols to gossip a message
func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer.ID, msg NotificationsMessage) {
Expand All @@ -234,35 +296,6 @@ func (s *Service) broadcastExcluding(info *notificationsProtocol, excluding peer
continue
}

if hsData, has := info.getHandshakeData(peer); !has || !hsData.received {
info.handshakeData.Store(peer, &handshakeData{
validated: false,
outboundMsg: msg,
})

logger.Trace("sending handshake", "protocol", info.protocolID, "peer", peer, "message", hs)
err = s.host.send(peer, info.protocolID, hs)
} else {
if s.host.messageCache != nil {
var added bool
added, err = s.host.messageCache.put(peer, msg)
if err != nil {
logger.Error("failed to add message to cache", "peer", peer, "error", err)
continue
}

if !added {
continue
}
}

// we've already completed the handshake with the peer, send message directly
logger.Trace("sending message", "protocol", info.protocolID, "peer", peer, "message", msg)
err = s.host.send(peer, info.protocolID, msg)
}

if err != nil {
logger.Debug("failed to send message to peer", "peer", peer, "error", err)
}
go s.sendData(peer, hs, info, msg)
}
}
5 changes: 3 additions & 2 deletions dot/network/notifications_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) {

// haven't received handshake from peer
testPeerID := peer.ID("QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ")
info.handshakeData.Store(testPeerID, &handshakeData{
info.handshakeData.Store(testPeerID, handshakeData{
received: false,
})

Expand Down Expand Up @@ -85,6 +85,7 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) {
// set handshake data to received
hsData, _ := info.getHandshakeData(testPeerID)
hsData.received = true
info.handshakeData.Store(testPeerID, hsData)
msg, err = decoder(enc, testPeerID)
require.NoError(t, err)
require.Equal(t, testBlockAnnounce, msg)
Expand Down Expand Up @@ -139,7 +140,7 @@ func TestCreateNotificationsMessageHandler_BlockAnnounce(t *testing.T) {
handler := s.createNotificationsMessageHandler(info, s.validateBlockAnnounceHandshake, s.handleBlockAnnounceMessage)

// set handshake data to received
info.handshakeData.Store(testPeerID, &handshakeData{
info.handshakeData.Store(testPeerID, handshakeData{
received: true,
validated: true,
})
Expand Down
Loading