diff --git a/requestmanager/requestmanager_test.go b/requestmanager/requestmanager_test.go index d0bd6070..92d827a9 100644 --- a/requestmanager/requestmanager_test.go +++ b/requestmanager/requestmanager_test.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "sort" - "sync" "testing" "time" @@ -45,9 +44,8 @@ func TestNormalSimultaneousFetch(t *testing.T) { requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 2) - require.True(t, td.fcm.IsProtected(peers[0])) - require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String()) - require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[1].gsr.ID().String()) + td.tcm.AssertProtected(t, peers[0]) + td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().String(), requestRecords[1].gsr.ID().String()) require.Equal(t, peers[0], requestRecords[0].p) require.Equal(t, peers[0], requestRecords[1].p) require.False(t, requestRecords[0].gsr.IsCancel()) @@ -90,9 +88,9 @@ func TestNormalSimultaneousFetch(t *testing.T) { td.blockChain.VerifyWholeChain(requestCtx, returnedResponseChan1) blockChain2.VerifyResponseRange(requestCtx, returnedResponseChan2, 0, 3) - require.True(t, td.fcm.IsProtected(peers[0])) - require.NotContains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String()) - require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[1].gsr.ID().String()) + td.tcm.AssertProtected(t, peers[0]) + td.tcm.RefuteProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().String()) + td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[1].gsr.ID().String()) moreBlocks := blockChain2.RemainderBlocks(3) moreMetadata := metadataForBlocks(moreBlocks, true) @@ -117,7 +115,7 @@ func TestNormalSimultaneousFetch(t *testing.T) { testutil.VerifyEmptyErrors(requestCtx, t, returnedErrorChan1) testutil.VerifyEmptyErrors(requestCtx, t, returnedErrorChan2) - require.False(t, td.fcm.IsProtected(peers[0])) + td.tcm.RefuteProtected(t, peers[0]) } func TestCancelRequestInProgress(t *testing.T) { @@ -135,9 +133,8 @@ func TestCancelRequestInProgress(t *testing.T) { requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 2) - require.True(t, td.fcm.IsProtected(peers[0])) - require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String()) - require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[1].gsr.ID().String()) + td.tcm.AssertProtected(t, peers[0]) + td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().String(), requestRecords[1].gsr.ID().String()) firstBlocks := td.blockChain.Blocks(0, 3) firstMetadata := encodedMetadataForBlocks(t, firstBlocks, true) @@ -177,7 +174,7 @@ func TestCancelRequestInProgress(t *testing.T) { _, ok := errors[0].(graphsync.RequestClientCancelledErr) require.True(t, ok) - require.False(t, td.fcm.IsProtected(peers[0])) + td.tcm.RefuteProtected(t, peers[0]) } func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) { ctx := context.Background() @@ -200,8 +197,8 @@ func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) { requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1) - require.True(t, td.fcm.IsProtected(peers[0])) - require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String()) + td.tcm.AssertProtected(t, peers[0]) + td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().String()) go func() { firstBlocks := td.blockChain.Blocks(0, 3) @@ -224,7 +221,7 @@ func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) { require.True(t, rr.gsr.IsCancel()) require.Equal(t, requestRecords[0].gsr.ID(), rr.gsr.ID()) - require.False(t, td.fcm.IsProtected(peers[0])) + td.tcm.RefuteProtected(t, peers[0]) errors := testutil.CollectErrors(requestCtx, t, returnedErrorChan1) require.Len(t, errors, 1) @@ -280,8 +277,8 @@ func TestFailedRequest(t *testing.T) { returnedResponseChan, returnedErrorChan := td.requestManager.NewRequest(requestCtx, peers[0], td.blockChain.TipLink, td.blockChain.Selector()) rr := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] - require.True(t, td.fcm.IsProtected(peers[0])) - require.Contains(t, td.fcm.Protections(peers[0]), rr.gsr.ID().String()) + td.tcm.AssertProtected(t, peers[0]) + td.tcm.AssertProtectedWithTags(t, peers[0], rr.gsr.ID().String()) failedResponses := []gsmsg.GraphSyncResponse{ gsmsg.NewResponse(rr.gsr.ID(), graphsync.RequestFailedContentNotFound), @@ -290,7 +287,7 @@ func TestFailedRequest(t *testing.T) { testutil.VerifySingleTerminalError(requestCtx, t, returnedErrorChan) testutil.VerifyEmptyResponse(requestCtx, t, returnedResponseChan) - require.False(t, td.fcm.IsProtected(peers[0])) + td.tcm.RefuteProtected(t, peers[0]) } func TestLocallyFulfilledFirstRequestFailsLater(t *testing.T) { @@ -987,51 +984,11 @@ func encodedMetadataForBlocks(t *testing.T, blks []blocks.Block, present bool) g } } -type fakeConnManager struct { - protectedConnsLk sync.RWMutex - protectedConns map[peer.ID][]string -} - -func (fcm *fakeConnManager) Protect(p peer.ID, tag string) { - fcm.protectedConnsLk.Lock() - defer fcm.protectedConnsLk.Unlock() - for _, tagCmp := range fcm.protectedConns[p] { - if tag == tagCmp { - return - } - } - fcm.protectedConns[p] = append(fcm.protectedConns[p], tag) -} - -func (fcm *fakeConnManager) Unprotect(p peer.ID, tag string) bool { - fcm.protectedConnsLk.Lock() - defer fcm.protectedConnsLk.Unlock() - for i, tagCmp := range fcm.protectedConns[p] { - if tag == tagCmp { - fcm.protectedConns[p] = append(fcm.protectedConns[p][:i], fcm.protectedConns[p][i+1:]...) - break - } - } - return len(fcm.protectedConns[p]) > 0 -} - -func (fcm *fakeConnManager) IsProtected(p peer.ID) bool { - fcm.protectedConnsLk.RLock() - defer fcm.protectedConnsLk.RUnlock() - return len(fcm.protectedConns[p]) > 0 -} - -func (fcm *fakeConnManager) Protections(p peer.ID) []string { - fcm.protectedConnsLk.RLock() - defer fcm.protectedConnsLk.RUnlock() - return fcm.protectedConns[p] -} - type testData struct { requestRecordChan chan requestRecord fph *fakePeerHandler fal *testloader.FakeAsyncLoader - fcm *fakeConnManager + tcm *testutil.TestConnManager requestHooks *hooks.OutgoingRequestHooks responseHooks *hooks.IncomingResponseHooks blockHooks *hooks.IncomingBlockHooks @@ -1055,14 +1012,14 @@ func newTestData(ctx context.Context, t *testing.T) *testData { td.requestRecordChan = make(chan requestRecord, 3) td.fph = &fakePeerHandler{td.requestRecordChan} td.fal = testloader.NewFakeAsyncLoader() - td.fcm = &fakeConnManager{protectedConns: make(map[peer.ID][]string)} + td.tcm = testutil.NewTestConnManager() td.requestHooks = hooks.NewRequestHooks() td.responseHooks = hooks.NewResponseHooks() td.blockHooks = hooks.NewBlockHooks() td.networkErrorListeners = listeners.NewNetworkErrorListeners() td.taskqueue = taskqueue.NewTaskQueue(ctx) lsys := cidlink.DefaultLinkSystem() - td.requestManager = New(ctx, td.fal, lsys, td.requestHooks, td.responseHooks, td.networkErrorListeners, td.taskqueue, td.fcm) + td.requestManager = New(ctx, td.fal, lsys, td.requestHooks, td.responseHooks, td.networkErrorListeners, td.taskqueue, td.tcm) td.executor = executor.NewExecutor(td.requestManager, td.blockHooks, td.fal.AsyncLoad) td.requestManager.SetDelegate(td.fph) td.requestManager.Startup() diff --git a/testutil/testconnmanager.go b/testutil/testconnmanager.go new file mode 100644 index 00000000..ad9f9c15 --- /dev/null +++ b/testutil/testconnmanager.go @@ -0,0 +1,79 @@ +package testutil + +import ( + "sync" + + "github.com/libp2p/go-libp2p-core/peer" + "github.com/stretchr/testify/require" +) + +// TestConnManager implements network.ConnManager and allows you to assert +// behavior +type TestConnManager struct { + protectedConnsLk sync.RWMutex + protectedConns map[peer.ID][]string +} + +// NewTestConnManager returns a new TestConnManager +func NewTestConnManager() *TestConnManager { + return &TestConnManager{protectedConns: make(map[peer.ID][]string)} +} + +// Protect simulates protecting a connection (just records occurence) +func (tcm *TestConnManager) Protect(p peer.ID, tag string) { + tcm.protectedConnsLk.Lock() + defer tcm.protectedConnsLk.Unlock() + for _, tagCmp := range tcm.protectedConns[p] { + if tag == tagCmp { + return + } + } + tcm.protectedConns[p] = append(tcm.protectedConns[p], tag) +} + +// Unprotect simulates unprotecting a connection (just records occurence) +func (tcm *TestConnManager) Unprotect(p peer.ID, tag string) bool { + tcm.protectedConnsLk.Lock() + defer tcm.protectedConnsLk.Unlock() + for i, tagCmp := range tcm.protectedConns[p] { + if tag == tagCmp { + tcm.protectedConns[p] = append(tcm.protectedConns[p][:i], tcm.protectedConns[p][i+1:]...) + break + } + } + return len(tcm.protectedConns[p]) > 0 +} + +// AssertProtected asserts that the connection is protected by at least one tag +func (tcm *TestConnManager) AssertProtected(t TestingT, p peer.ID) { + tcm.protectedConnsLk.RLock() + defer tcm.protectedConnsLk.RUnlock() + require.True(t, len(tcm.protectedConns[p]) > 0) +} + +// RefuteProtected refutes that a connection has been protect +func (tcm *TestConnManager) RefuteProtected(t TestingT, p peer.ID) { + tcm.protectedConnsLk.RLock() + defer tcm.protectedConnsLk.RUnlock() + require.False(t, len(tcm.protectedConns[p]) > 0) +} + +// AssertProtectedWithTags verifies the connection is protected with the given +// tags at least +func (tcm *TestConnManager) AssertProtectedWithTags(t TestingT, p peer.ID, tags ...string) { + tcm.protectedConnsLk.RLock() + defer tcm.protectedConnsLk.RUnlock() + for _, tag := range tags { + require.Contains(t, tcm.protectedConns[p], tag) + } +} + +// RefuteProtectedWithTags verifies the connection is not protected with any of the given +// tags +func (tcm *TestConnManager) RefuteProtectedWithTags(t TestingT, p peer.ID, tags ...string) { + tcm.protectedConnsLk.RLock() + defer tcm.protectedConnsLk.RUnlock() + for _, tag := range tags { + require.NotContains(t, tcm.protectedConns[p], tag) + } +}