diff --git a/impl/graphsync.go b/impl/graphsync.go index 8f852a0a..9a6e9292 100644 --- a/impl/graphsync.go +++ b/impl/graphsync.go @@ -183,7 +183,7 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork, requestExecutor := executor.NewExecutor(requestManager, incomingBlockHooks, asyncLoader.AsyncLoad) responseAssembler := responseassembler.New(ctx, peerManager) peerTaskQueue := peertaskqueue.New() - responseManager := responsemanager.New(ctx, linkSystem, responseAssembler, peerTaskQueue, requestQueuedHooks, incomingRequestHooks, outgoingBlockHooks, requestUpdatedHooks, completedResponseListeners, requestorCancelledListeners, blockSentListeners, networkErrorListeners, gsConfig.maxInProgressIncomingRequests) + responseManager := responsemanager.New(ctx, linkSystem, responseAssembler, peerTaskQueue, requestQueuedHooks, incomingRequestHooks, outgoingBlockHooks, requestUpdatedHooks, completedResponseListeners, requestorCancelledListeners, blockSentListeners, networkErrorListeners, gsConfig.maxInProgressIncomingRequests, network.ConnectionManager()) graphSync := &GraphSync{ network: network, linkSystem: linkSystem, diff --git a/responsemanager/client.go b/responsemanager/client.go index f93834bf..d0a20dba 100644 --- a/responsemanager/client.go +++ b/responsemanager/client.go @@ -13,6 +13,7 @@ import ( "github.com/ipfs/go-graphsync" "github.com/ipfs/go-graphsync/ipldutil" gsmsg "github.com/ipfs/go-graphsync/message" + "github.com/ipfs/go-graphsync/network" "github.com/ipfs/go-graphsync/notifications" "github.com/ipfs/go-graphsync/responsemanager/hooks" "github.com/ipfs/go-graphsync/responsemanager/responseassembler" @@ -28,6 +29,14 @@ const ( thawSpeed = time.Millisecond * 100 ) +type state uint64 + +const ( + queued state = iota + running + paused +) + type inProgressResponseStatus struct { ctx context.Context cancelFn func() @@ -36,7 +45,7 @@ type inProgressResponseStatus struct { traverser ipldutil.Traverser signals ResponseSignals updates []gsmsg.GraphSyncRequest - isPaused bool + state state subscriber *notifications.TopicDataSubscriber } @@ -144,6 +153,7 @@ type ResponseManager struct { qe *queryExecutor inProgressResponses map[responseKey]*inProgressResponseStatus maxInProcessRequests uint64 + connManager network.ConnManager } // New creates a new response manager for responding to requests @@ -160,6 +170,7 @@ func New(ctx context.Context, blockSentListeners BlockSentListeners, networkErrorListeners NetworkErrorListeners, maxInProcessRequests uint64, + connManager network.ConnManager, ) *ResponseManager { ctx, cancelFn := context.WithCancel(ctx) messages := make(chan responseManagerMessage, 16) @@ -181,6 +192,7 @@ func New(ctx context.Context, workSignal: workSignal, inProgressResponses: make(map[responseKey]*inProgressResponseStatus), maxInProcessRequests: maxInProcessRequests, + connManager: connManager, } rm.qe = &queryExecutor{ blockHooks: blockHooks, @@ -192,6 +204,7 @@ func New(ctx context.Context, ctx: ctx, workSignal: workSignal, ticker: time.NewTicker(thawSpeed), + connManager: connManager, } return rm } diff --git a/responsemanager/queryexecutor.go b/responsemanager/queryexecutor.go index 14c2d9ea..18ae4f3a 100644 --- a/responsemanager/queryexecutor.go +++ b/responsemanager/queryexecutor.go @@ -12,6 +12,7 @@ import ( "github.com/ipfs/go-graphsync" "github.com/ipfs/go-graphsync/ipldutil" gsmsg "github.com/ipfs/go-graphsync/message" + "github.com/ipfs/go-graphsync/network" "github.com/ipfs/go-graphsync/notifications" "github.com/ipfs/go-graphsync/responsemanager/hooks" "github.com/ipfs/go-graphsync/responsemanager/responseassembler" @@ -39,6 +40,7 @@ type queryExecutor struct { ctx context.Context workSignal chan struct{} ticker *time.Ticker + connManager network.ConnManager } func (qe *queryExecutor) processQueriesWorker() { @@ -73,6 +75,7 @@ func (qe *queryExecutor) processQueriesWorker() { _, err := qe.executeQuery(pid, taskData.Request, taskData.Loader, taskData.Traverser, taskData.Signals, taskData.Subscriber) isCancelled := err != nil && isContextErr(err) if isCancelled { + qe.connManager.Unprotect(pid, taskData.Request.ID().String()) qe.cancelledListeners.NotifyCancelledListeners(pid, taskData.Request) } qe.manager.FinishTask(task, err) diff --git a/responsemanager/responsemanager_test.go b/responsemanager/responsemanager_test.go index a9b29fa8..8c59e821 100644 --- a/responsemanager/responsemanager_test.go +++ b/responsemanager/responsemanager_test.go @@ -45,6 +45,7 @@ func TestIncomingQuery(t *testing.T) { qhc := make(chan *queuedHook, 1) td.requestQueuedHooks.Register(func(p peer.ID, request graphsync.RequestData) { + td.connManager.AssertProtectedWithTags(t, p, request.ID().String()) qhc <- &queuedHook{ p: p, request: request, @@ -54,15 +55,16 @@ func TestIncomingQuery(t *testing.T) { responseManager.Startup() responseManager.ProcessRequests(td.ctx, td.p, td.requests) - testutil.AssertDoesReceive(td.ctx, t, td.completedRequestChan, "Should have completed request but didn't") for i := 0; i < len(blks); i++ { td.assertSendBlock() } + td.assertCompleteRequestWith(graphsync.RequestCompletedFull) // ensure request queued hook fires. out := <-qhc require.Equal(t, td.p, out.p) require.Equal(t, out.request.ID(), td.requestID) + td.connManager.RefuteProtected(t, td.p) } func TestCancellationQueryInProgress(t *testing.T) { @@ -72,6 +74,7 @@ func TestCancellationQueryInProgress(t *testing.T) { td.requestHooks.Register(selectorvalidator.SelectorValidator(100)) cancelledListenerCalled := make(chan struct{}, 1) td.cancelledListeners.Register(func(p peer.ID, request graphsync.RequestData) { + td.connManager.RefuteProtected(t, td.p) cancelledListenerCalled <- struct{}{} }) responseManager.Startup() @@ -108,6 +111,7 @@ func TestCancellationViaCommand(t *testing.T) { require.NoError(t, err) td.assertCompleteRequestWith(graphsync.RequestCancelled) + td.connManager.RefuteProtected(t, td.p) } func TestEarlyCancellation(t *testing.T) { @@ -118,6 +122,9 @@ func TestEarlyCancellation(t *testing.T) { td.requestHooks.Register(selectorvalidator.SelectorValidator(100)) responseManager.Startup() responseManager.ProcessRequests(td.ctx, td.p, td.requests) + responseManager.synchronize() + + td.connManager.AssertProtectedWithTags(t, td.p, td.requests[0].ID().String()) // send a cancellation cancelRequests := []gsmsg.GraphSyncRequest{ @@ -131,6 +138,7 @@ func TestEarlyCancellation(t *testing.T) { td.queryQueue.popWait.Done() td.assertNoResponses() + td.connManager.RefuteProtected(t, td.p) } func TestMissingContent(t *testing.T) { t.Run("missing root block", func(t *testing.T) { @@ -174,6 +182,7 @@ func TestValidationAndExtensions(t *testing.T) { responseManager.Startup() responseManager.ProcessRequests(td.ctx, td.p, td.requests) td.assertCompleteRequestWith(graphsync.RequestRejected) + td.connManager.RefuteProtected(t, td.p) }) t.Run("if non validating hook succeeds, does not pass validation", func(t *testing.T) { @@ -182,11 +191,13 @@ func TestValidationAndExtensions(t *testing.T) { responseManager := td.newResponseManager() responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + td.connManager.AssertProtectedWithTags(t, td.p, td.requests[0].ID().String()) hookActions.SendExtensionData(td.extensionResponse) }) responseManager.ProcessRequests(td.ctx, td.p, td.requests) td.assertCompleteRequestWith(graphsync.RequestRejected) td.assertReceiveExtensionResponse() + td.connManager.RefuteProtected(t, td.p) }) t.Run("if validating hook succeeds, should pass validation", func(t *testing.T) { @@ -195,12 +206,14 @@ func TestValidationAndExtensions(t *testing.T) { responseManager := td.newResponseManager() responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + td.connManager.AssertProtectedWithTags(t, td.p, td.requests[0].ID().String()) hookActions.ValidateRequest() hookActions.SendExtensionData(td.extensionResponse) }) responseManager.ProcessRequests(td.ctx, td.p, td.requests) td.assertCompleteRequestWith(graphsync.RequestCompletedFull) td.assertReceiveExtensionResponse() + td.connManager.RefuteProtected(t, td.p) }) t.Run("if any hook fails, should fail", func(t *testing.T) { @@ -962,6 +975,7 @@ type testData struct { completedResponseStatuses chan graphsync.ResponseStatusCode networkErrorChan chan error allBlocks []blocks.Block + connManager *testutil.TestConnManager } func newTestData(t *testing.T) testData { @@ -1049,17 +1063,18 @@ func newTestData(t *testing.T) testData { default: } }) + td.connManager = testutil.NewTestConnManager() return td } func (td *testData) newResponseManager() *ResponseManager { - return New(td.ctx, td.persistence, td.responseAssembler, td.queryQueue, td.requestQueuedHooks, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners, td.blockSentListeners, td.networkErrorListeners, 6) + return New(td.ctx, td.persistence, td.responseAssembler, td.queryQueue, td.requestQueuedHooks, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners, td.blockSentListeners, td.networkErrorListeners, 6, td.connManager) } func (td *testData) alternateLoaderResponseManager() *ResponseManager { obs := make(map[ipld.Link][]byte) persistence := testutil.NewTestStore(obs) - return New(td.ctx, persistence, td.responseAssembler, td.queryQueue, td.requestQueuedHooks, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners, td.blockSentListeners, td.networkErrorListeners, 6) + return New(td.ctx, persistence, td.responseAssembler, td.queryQueue, td.requestQueuedHooks, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners, td.blockSentListeners, td.networkErrorListeners, 6, td.connManager) } func (td *testData) assertPausedRequest() { diff --git a/responsemanager/server.go b/responsemanager/server.go index c06f3f3e..9ebc307a 100644 --- a/responsemanager/server.go +++ b/responsemanager/server.go @@ -46,7 +46,7 @@ func (rm *ResponseManager) processUpdate(key responseKey, update gsmsg.GraphSync log.Warnf("received update for non existent request, peer %s, request ID %d", key.p.Pretty(), key.requestID) return } - if !response.isPaused { + if response.state != paused { response.updates = append(response.updates, update) select { case response.signals.UpdateSignal <- struct{}{}: @@ -88,10 +88,10 @@ func (rm *ResponseManager) unpauseRequest(p peer.ID, requestID graphsync.Request if !ok { return errors.New("could not find request") } - if !inProgressResponse.isPaused { + if inProgressResponse.state != paused { return errors.New("request is not paused") } - inProgressResponse.isPaused = false + inProgressResponse.state = queued if len(extensions) > 0 { _ = rm.responseAssembler.Transaction(p, requestID, func(rb responseassembler.ResponseBuilder) error { for _, extension := range extensions { @@ -116,10 +116,10 @@ func (rm *ResponseManager) abortRequest(p peer.ID, requestID graphsync.RequestID return errors.New("could not find request") } - if response.isPaused { + if response.state != running { _ = rm.responseAssembler.Transaction(p, requestID, func(rb responseassembler.ResponseBuilder) error { if isContextErr(err) { - + rm.connManager.Unprotect(p, requestID.String()) rm.cancelledListeners.NotifyCancelledListeners(p, response.request) rb.ClearRequest() } else if err == errNetworkError { @@ -152,6 +152,7 @@ func (rm *ResponseManager) processRequests(p peer.ID, requests []gsmsg.GraphSync rm.processUpdate(key, request) continue } + rm.connManager.Protect(p, request.ID().String()) rm.requestQueuedHooks.ProcessRequestQueuedHooks(p, request) ctx, cancelFn := context.WithCancel(rm.ctx) sub := notifications.NewTopicDataSubscriber(&subscriber{ @@ -162,6 +163,7 @@ func (rm *ResponseManager) processRequests(p peer.ID, requests []gsmsg.GraphSync blockSentListeners: rm.blockSentListeners, completedListeners: rm.completedListeners, networkErrorListeners: rm.networkErrorListeners, + connManager: rm.connManager, }) rm.inProgressResponses[key] = @@ -175,6 +177,7 @@ func (rm *ResponseManager) processRequests(p peer.ID, requests []gsmsg.GraphSync UpdateSignal: make(chan struct{}, 1), ErrSignal: make(chan error, 1), }, + state: queued, } // TODO: Use a better work estimation metric. @@ -202,10 +205,11 @@ func (rm *ResponseManager) taskDataForKey(key responseKey) ResponseTaskData { response.loader = loader response.traverser = traverser if isPaused { - response.isPaused = true + response.state = paused return ResponseTaskData{Empty: true} } } + response.state = running return ResponseTaskData{false, response.subscriber, response.ctx, response.request, response.loader, response.traverser, response.signals} } @@ -226,7 +230,7 @@ func (rm *ResponseManager) finishTask(task *peertask.Task, err error) { return } if _, ok := err.(hooks.ErrPaused); ok { - response.isPaused = true + response.state = paused return } if err != nil { @@ -252,7 +256,7 @@ func (rm *ResponseManager) pauseRequest(p peer.ID, requestID graphsync.RequestID if !ok { return errors.New("could not find request") } - if inProgressResponse.isPaused { + if inProgressResponse.state == paused { return errors.New("request is already paused") } select { diff --git a/responsemanager/subscriber.go b/responsemanager/subscriber.go index 1bed9f85..b5f3a1f3 100644 --- a/responsemanager/subscriber.go +++ b/responsemanager/subscriber.go @@ -9,6 +9,7 @@ import ( "github.com/ipfs/go-graphsync" gsmsg "github.com/ipfs/go-graphsync/message" "github.com/ipfs/go-graphsync/messagequeue" + "github.com/ipfs/go-graphsync/network" "github.com/ipfs/go-graphsync/notifications" ) @@ -22,6 +23,7 @@ type subscriber struct { blockSentListeners BlockSentListeners networkErrorListeners NetworkErrorListeners completedListeners CompletedListeners + connManager network.ConnManager } func (s *subscriber) OnNext(topic notifications.Topic, event notifications.Event) { @@ -45,6 +47,7 @@ func (s *subscriber) OnNext(topic notifications.Topic, event notifications.Event } status, isStatus := topic.(graphsync.ResponseStatusCode) if isStatus { + s.connManager.Unprotect(s.p, s.request.ID().String()) switch responseEvent.Name { case messagequeue.Error: s.networkErrorListeners.NotifyNetworkErrorListeners(s.p, s.request, responseEvent.Err)