Skip to content

Commit

Permalink
feat: fire network error when network disconnects during request (#164)
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkmc authored Apr 15, 2021
1 parent 611735c commit 6e60b85
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 8 deletions.
51 changes: 45 additions & 6 deletions requestmanager/requestmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"errors"
"fmt"
"github.com/hannahhoward/go-pubsub"
"golang.org/x/xerrors"
"sync/atomic"

blocks "github.com/ipfs/go-block-format"
Expand Down Expand Up @@ -70,6 +72,7 @@ type RequestManager struct {
peerHandler PeerHandler
rc *responseCollector
asyncLoader AsyncLoader
disconnectNotif *pubsub.PubSub
// dont touch out side of run loop
nextRequestID graphsync.RequestID
inProgressRequestStatuses map[graphsync.RequestID]*inProgressRequestStatus
Expand Down Expand Up @@ -111,6 +114,7 @@ func New(ctx context.Context,
ctx: ctx,
cancel: cancel,
asyncLoader: asyncLoader,
disconnectNotif: pubsub.New(disconnectDispatcher),
rc: newResponseCollector(ctx),
messages: make(chan requestManagerMessage, 16),
inProgressRequestStatuses: make(map[graphsync.RequestID]*inProgressRequestStatus),
Expand All @@ -128,6 +132,7 @@ func (rm *RequestManager) SetDelegate(peerHandler PeerHandler) {

type inProgressRequest struct {
requestID graphsync.RequestID
request gsmsg.GraphSyncRequest
incoming chan graphsync.ResponseProgress
incomingError chan error
}
Expand Down Expand Up @@ -166,14 +171,46 @@ func (rm *RequestManager) SendRequest(ctx context.Context,
case receivedInProgressRequest = <-inProgressRequestChan:
}

// If the connection to the peer is disconnected, fire an error
unsub := rm.listenForDisconnect(p, func(neterr error) {
rm.networkErrorListeners.NotifyNetworkErrorListeners(p, receivedInProgressRequest.request, neterr)
})

return rm.rc.collectResponses(ctx,
receivedInProgressRequest.incoming,
receivedInProgressRequest.incomingError,
func() {
rm.cancelRequest(receivedInProgressRequest.requestID,
receivedInProgressRequest.incoming,
receivedInProgressRequest.incomingError)
})
},
// Once the request has completed, stop listening for disconnect events
unsub,
)
}

// Dispatch the Disconnect event to subscribers
func disconnectDispatcher(p pubsub.Event, subscriberFn pubsub.SubscriberFn) error {
listener := subscriberFn.(func(peer.ID))
listener(p.(peer.ID))
return nil
}

// Listen for the Disconnect event for the given peer
func (rm *RequestManager) listenForDisconnect(p peer.ID, onDisconnect func(neterr error)) func() {
// Subscribe to Disconnect notifications
return rm.disconnectNotif.Subscribe(func(evtPeer peer.ID) {
// If the peer is the one we're interested in, call the listener
if evtPeer == p {
onDisconnect(xerrors.Errorf("disconnected from peer %s", p))
}
})
}

// Disconnected is called when a peer disconnects
func (rm *RequestManager) Disconnected(p peer.ID) {
// Notify any listeners that a peer has disconnected
rm.disconnectNotif.Publish(p)
}

func (rm *RequestManager) emptyResponse() (chan graphsync.ResponseProgress, chan error) {
Expand Down Expand Up @@ -311,17 +348,19 @@ type terminateRequestMessage struct {
requestID graphsync.RequestID
}

func (nrm *newRequestMessage) setupRequest(requestID graphsync.RequestID, rm *RequestManager) (chan graphsync.ResponseProgress, chan error) {
func (nrm *newRequestMessage) setupRequest(requestID graphsync.RequestID, rm *RequestManager) (gsmsg.GraphSyncRequest, chan graphsync.ResponseProgress, chan error) {
request, hooksResult, err := rm.validateRequest(requestID, nrm.p, nrm.root, nrm.selector, nrm.extensions)
if err != nil {
return rm.singleErrorResponse(err)
rp, err := rm.singleErrorResponse(err)
return request, rp, err
}
doNotSendCidsData, has := request.Extension(graphsync.ExtensionDoNotSendCIDs)
var doNotSendCids *cid.Set
if has {
doNotSendCids, err = cidset.DecodeCidSet(doNotSendCidsData)
if err != nil {
return rm.singleErrorResponse(err)
rp, err := rm.singleErrorResponse(err)
return request, rp, err
}
} else {
doNotSendCids = cid.NewSet()
Expand Down Expand Up @@ -355,14 +394,14 @@ func (nrm *newRequestMessage) setupRequest(requestID graphsync.RequestID, rm *Re
ResumeMessages: resumeMessages,
PauseMessages: pauseMessages,
})
return incoming, incomingError
return request, incoming, incomingError
}

func (nrm *newRequestMessage) handle(rm *RequestManager) {
var ipr inProgressRequest
ipr.requestID = rm.nextRequestID
rm.nextRequestID++
ipr.incoming, ipr.incomingError = nrm.setupRequest(ipr.requestID, rm)
ipr.request, ipr.incoming, ipr.incomingError = nrm.setupRequest(ipr.requestID, rm)

select {
case nrm.inProgressRequestChan <- ipr:
Expand Down
36 changes: 36 additions & 0 deletions requestmanager/requestmanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,42 @@ func TestRequestReturnsMissingBlocks(t *testing.T) {
require.NotEqual(t, len(errs), 0, "did not send errors")
}

func TestDisconnectNotification(t *testing.T) {
ctx := context.Background()
td := newTestData(ctx, t)
requestCtx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
peers := testutil.GeneratePeers(2)

// Listen for network errors
networkErrors := make(chan peer.ID, 1)
td.networkErrorListeners.Register(func(p peer.ID, request graphsync.RequestData, err error) {
networkErrors <- p
})

// Send a request to the target peer
targetPeer := peers[0]
td.requestManager.SendRequest(requestCtx, targetPeer, td.blockChain.TipLink, td.blockChain.Selector())

// Disconnect a random peer, should not fire any events
randomPeer := peers[1]
td.requestManager.Disconnected(randomPeer)
select {
case <-networkErrors:
t.Fatal("should not fire network error when unrelated peer disconnects")
default:
}

// Disconnect the target peer, should fire a network error
td.requestManager.Disconnected(targetPeer)
select {
case p:= <-networkErrors:
require.Equal(t, p, targetPeer)
default:
t.Fatal("should fire network error when peer disconnects")
}
}

func TestEncodingExtensions(t *testing.T) {
ctx := context.Background()
td := newTestData(ctx, t)
Expand Down
5 changes: 4 additions & 1 deletion requestmanager/responsecollector.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@ func (rc *responseCollector) collectResponses(
requestCtx context.Context,
incomingResponses <-chan graphsync.ResponseProgress,
incomingErrors <-chan error,
cancelRequest func()) (<-chan graphsync.ResponseProgress, <-chan error) {
cancelRequest func(),
onComplete func(),
) (<-chan graphsync.ResponseProgress, <-chan error) {

returnedResponses := make(chan graphsync.ResponseProgress)
returnedErrors := make(chan error)

go func() {
var receivedResponses []graphsync.ResponseProgress
defer close(returnedResponses)
defer onComplete()
outgoingResponses := func() chan<- graphsync.ResponseProgress {
if len(receivedResponses) == 0 {
return nil
Expand Down
2 changes: 1 addition & 1 deletion requestmanager/responsecollector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func TestBufferingResponseProgress(t *testing.T) {
cancelRequest := func() {}

outgoingResponses, outgoingErrors := rc.collectResponses(
requestCtx, incomingResponses, incomingErrors, cancelRequest)
requestCtx, incomingResponses, incomingErrors, cancelRequest, func(){})

blockStore := make(map[ipld.Link][]byte)
loader, storer := testutil.NewTestStore(blockStore)
Expand Down

0 comments on commit 6e60b85

Please sign in to comment.