Skip to content

Commit

Permalink
refactor(testutil): extract TestConnManager
Browse files Browse the repository at this point in the history
  • Loading branch information
hannahhoward committed Sep 29, 2021
1 parent cae08eb commit 94eda95
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 61 deletions.
79 changes: 18 additions & 61 deletions requestmanager/requestmanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"sort"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -45,9 +44,8 @@ func TestNormalSimultaneousFetch(t *testing.T) {

requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 2)

require.True(t, td.fcm.IsProtected(peers[0]))
require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String())
require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[1].gsr.ID().String())
td.tcm.AssertProtected(t, peers[0])
td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().String(), requestRecords[1].gsr.ID().String())
require.Equal(t, peers[0], requestRecords[0].p)
require.Equal(t, peers[0], requestRecords[1].p)
require.False(t, requestRecords[0].gsr.IsCancel())
Expand Down Expand Up @@ -90,9 +88,9 @@ func TestNormalSimultaneousFetch(t *testing.T) {
td.blockChain.VerifyWholeChain(requestCtx, returnedResponseChan1)
blockChain2.VerifyResponseRange(requestCtx, returnedResponseChan2, 0, 3)

require.True(t, td.fcm.IsProtected(peers[0]))
require.NotContains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String())
require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[1].gsr.ID().String())
td.tcm.AssertProtected(t, peers[0])
td.tcm.RefuteProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().String())
td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[1].gsr.ID().String())

moreBlocks := blockChain2.RemainderBlocks(3)
moreMetadata := metadataForBlocks(moreBlocks, true)
Expand All @@ -117,7 +115,7 @@ func TestNormalSimultaneousFetch(t *testing.T) {
testutil.VerifyEmptyErrors(requestCtx, t, returnedErrorChan1)
testutil.VerifyEmptyErrors(requestCtx, t, returnedErrorChan2)

require.False(t, td.fcm.IsProtected(peers[0]))
td.tcm.RefuteProtected(t, peers[0])
}

func TestCancelRequestInProgress(t *testing.T) {
Expand All @@ -135,9 +133,8 @@ func TestCancelRequestInProgress(t *testing.T) {

requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 2)

require.True(t, td.fcm.IsProtected(peers[0]))
require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String())
require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[1].gsr.ID().String())
td.tcm.AssertProtected(t, peers[0])
td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().String(), requestRecords[1].gsr.ID().String())

firstBlocks := td.blockChain.Blocks(0, 3)
firstMetadata := encodedMetadataForBlocks(t, firstBlocks, true)
Expand Down Expand Up @@ -177,7 +174,7 @@ func TestCancelRequestInProgress(t *testing.T) {
_, ok := errors[0].(graphsync.RequestClientCancelledErr)
require.True(t, ok)

require.False(t, td.fcm.IsProtected(peers[0]))
td.tcm.RefuteProtected(t, peers[0])
}
func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) {
ctx := context.Background()
Expand All @@ -200,8 +197,8 @@ func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) {

requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)

require.True(t, td.fcm.IsProtected(peers[0]))
require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String())
td.tcm.AssertProtected(t, peers[0])
td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().String())

go func() {
firstBlocks := td.blockChain.Blocks(0, 3)
Expand All @@ -224,7 +221,7 @@ func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) {
require.True(t, rr.gsr.IsCancel())
require.Equal(t, requestRecords[0].gsr.ID(), rr.gsr.ID())

require.False(t, td.fcm.IsProtected(peers[0]))
td.tcm.RefuteProtected(t, peers[0])

errors := testutil.CollectErrors(requestCtx, t, returnedErrorChan1)
require.Len(t, errors, 1)
Expand Down Expand Up @@ -280,8 +277,8 @@ func TestFailedRequest(t *testing.T) {
returnedResponseChan, returnedErrorChan := td.requestManager.NewRequest(requestCtx, peers[0], td.blockChain.TipLink, td.blockChain.Selector())

rr := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0]
require.True(t, td.fcm.IsProtected(peers[0]))
require.Contains(t, td.fcm.Protections(peers[0]), rr.gsr.ID().String())
td.tcm.AssertProtected(t, peers[0])
td.tcm.AssertProtectedWithTags(t, peers[0], rr.gsr.ID().String())

failedResponses := []gsmsg.GraphSyncResponse{
gsmsg.NewResponse(rr.gsr.ID(), graphsync.RequestFailedContentNotFound),
Expand All @@ -290,7 +287,7 @@ func TestFailedRequest(t *testing.T) {

testutil.VerifySingleTerminalError(requestCtx, t, returnedErrorChan)
testutil.VerifyEmptyResponse(requestCtx, t, returnedResponseChan)
require.False(t, td.fcm.IsProtected(peers[0]))
td.tcm.RefuteProtected(t, peers[0])
}

func TestLocallyFulfilledFirstRequestFailsLater(t *testing.T) {
Expand Down Expand Up @@ -987,51 +984,11 @@ func encodedMetadataForBlocks(t *testing.T, blks []blocks.Block, present bool) g
}
}

type fakeConnManager struct {
protectedConnsLk sync.RWMutex
protectedConns map[peer.ID][]string
}

func (fcm *fakeConnManager) Protect(p peer.ID, tag string) {
fcm.protectedConnsLk.Lock()
defer fcm.protectedConnsLk.Unlock()
for _, tagCmp := range fcm.protectedConns[p] {
if tag == tagCmp {
return
}
}
fcm.protectedConns[p] = append(fcm.protectedConns[p], tag)
}

func (fcm *fakeConnManager) Unprotect(p peer.ID, tag string) bool {
fcm.protectedConnsLk.Lock()
defer fcm.protectedConnsLk.Unlock()
for i, tagCmp := range fcm.protectedConns[p] {
if tag == tagCmp {
fcm.protectedConns[p] = append(fcm.protectedConns[p][:i], fcm.protectedConns[p][i+1:]...)
break
}
}
return len(fcm.protectedConns[p]) > 0
}

func (fcm *fakeConnManager) IsProtected(p peer.ID) bool {
fcm.protectedConnsLk.RLock()
defer fcm.protectedConnsLk.RUnlock()
return len(fcm.protectedConns[p]) > 0
}

func (fcm *fakeConnManager) Protections(p peer.ID) []string {
fcm.protectedConnsLk.RLock()
defer fcm.protectedConnsLk.RUnlock()
return fcm.protectedConns[p]
}

type testData struct {
requestRecordChan chan requestRecord
fph *fakePeerHandler
fal *testloader.FakeAsyncLoader
fcm *fakeConnManager
tcm *testutil.TestConnManager
requestHooks *hooks.OutgoingRequestHooks
responseHooks *hooks.IncomingResponseHooks
blockHooks *hooks.IncomingBlockHooks
Expand All @@ -1055,14 +1012,14 @@ func newTestData(ctx context.Context, t *testing.T) *testData {
td.requestRecordChan = make(chan requestRecord, 3)
td.fph = &fakePeerHandler{td.requestRecordChan}
td.fal = testloader.NewFakeAsyncLoader()
td.fcm = &fakeConnManager{protectedConns: make(map[peer.ID][]string)}
td.tcm = testutil.NewTestConnManager()
td.requestHooks = hooks.NewRequestHooks()
td.responseHooks = hooks.NewResponseHooks()
td.blockHooks = hooks.NewBlockHooks()
td.networkErrorListeners = listeners.NewNetworkErrorListeners()
td.taskqueue = taskqueue.NewTaskQueue(ctx)
lsys := cidlink.DefaultLinkSystem()
td.requestManager = New(ctx, td.fal, lsys, td.requestHooks, td.responseHooks, td.networkErrorListeners, td.taskqueue, td.fcm)
td.requestManager = New(ctx, td.fal, lsys, td.requestHooks, td.responseHooks, td.networkErrorListeners, td.taskqueue, td.tcm)
td.executor = executor.NewExecutor(td.requestManager, td.blockHooks, td.fal.AsyncLoad)
td.requestManager.SetDelegate(td.fph)
td.requestManager.Startup()
Expand Down
79 changes: 79 additions & 0 deletions testutil/testconnmanager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package testutil

import (
"sync"

"github.com/libp2p/go-libp2p-core/peer"
"github.com/stretchr/testify/require"
)

// TestConnManager implements network.ConnManager and allows you to assert
// behavior
type TestConnManager struct {
protectedConnsLk sync.RWMutex
protectedConns map[peer.ID][]string
}

// NewTestConnManager returns a new TestConnManager
func NewTestConnManager() *TestConnManager {
return &TestConnManager{protectedConns: make(map[peer.ID][]string)}
}

// Protect simulates protecting a connection (just records occurence)
func (tcm *TestConnManager) Protect(p peer.ID, tag string) {
tcm.protectedConnsLk.Lock()
defer tcm.protectedConnsLk.Unlock()
for _, tagCmp := range tcm.protectedConns[p] {
if tag == tagCmp {
return
}
}
tcm.protectedConns[p] = append(tcm.protectedConns[p], tag)
}

// Unprotect simulates unprotecting a connection (just records occurence)
func (tcm *TestConnManager) Unprotect(p peer.ID, tag string) bool {
tcm.protectedConnsLk.Lock()
defer tcm.protectedConnsLk.Unlock()
for i, tagCmp := range tcm.protectedConns[p] {
if tag == tagCmp {
tcm.protectedConns[p] = append(tcm.protectedConns[p][:i], tcm.protectedConns[p][i+1:]...)
break
}
}
return len(tcm.protectedConns[p]) > 0
}

// AssertProtected asserts that the connection is protected by at least one tag
func (tcm *TestConnManager) AssertProtected(t TestingT, p peer.ID) {
tcm.protectedConnsLk.RLock()
defer tcm.protectedConnsLk.RUnlock()
require.True(t, len(tcm.protectedConns[p]) > 0)
}

// RefuteProtected refutes that a connection has been protect
func (tcm *TestConnManager) RefuteProtected(t TestingT, p peer.ID) {
tcm.protectedConnsLk.RLock()
defer tcm.protectedConnsLk.RUnlock()
require.False(t, len(tcm.protectedConns[p]) > 0)
}

// AssertProtectedWithTags verifies the connection is protected with the given
// tags at least
func (tcm *TestConnManager) AssertProtectedWithTags(t TestingT, p peer.ID, tags ...string) {
tcm.protectedConnsLk.RLock()
defer tcm.protectedConnsLk.RUnlock()
for _, tag := range tags {
require.Contains(t, tcm.protectedConns[p], tag)
}
}

// RefuteProtectedWithTags verifies the connection is not protected with any of the given
// tags
func (tcm *TestConnManager) RefuteProtectedWithTags(t TestingT, p peer.ID, tags ...string) {
tcm.protectedConnsLk.RLock()
defer tcm.protectedConnsLk.RUnlock()
for _, tag := range tags {
require.NotContains(t, tcm.protectedConns[p], tag)
}
}

0 comments on commit 94eda95

Please sign in to comment.