Skip to content

Commit

Permalink
feat(requestmanager): add connection protection
Browse files Browse the repository at this point in the history
  • Loading branch information
hannahhoward committed Sep 29, 2021
1 parent 5c5f1e8 commit cae08eb
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 64 deletions.
5 changes: 5 additions & 0 deletions benchmarks/testnet/virtual.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

delay "github.com/ipfs/go-ipfs-delay"
mockrouting "github.com/ipfs/go-ipfs-routing/mock"
"github.com/libp2p/go-libp2p-core/connmgr"
"github.com/libp2p/go-libp2p-core/peer"
tnet "github.com/libp2p/go-libp2p-testing/net"
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
Expand Down Expand Up @@ -255,6 +256,10 @@ func (nc *networkClient) DisconnectFrom(_ context.Context, p peer.ID) error {
return nil
}

func (nc *networkClient) ConnectionManager() gsnet.ConnManager {
return &connmgr.NullConnMgr{}
}

func (rq *receiverQueue) enqueue(m *message) {
rq.lk.Lock()
defer rq.lk.Unlock()
Expand Down
5 changes: 5 additions & 0 deletions graphsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"strconv"

"github.com/ipfs/go-cid"
"github.com/ipld/go-ipld-prime"
Expand All @@ -14,6 +15,10 @@ import (
// RequestID is a unique identifier for a GraphSync request.
type RequestID int32

func (r RequestID) String() string {
return strconv.Itoa(int(r))
}

// Priority a priority for a GraphSync request.
type Priority int32

Expand Down
2 changes: 1 addition & 1 deletion impl/graphsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork,

asyncLoader := asyncloader.New(ctx, linkSystem, requestAllocator)
requestQueue := taskqueue.NewTaskQueue(ctx)
requestManager := requestmanager.New(ctx, asyncLoader, linkSystem, outgoingRequestHooks, incomingResponseHooks, networkErrorListeners, requestQueue)
requestManager := requestmanager.New(ctx, asyncLoader, linkSystem, outgoingRequestHooks, incomingResponseHooks, networkErrorListeners, requestQueue, network.ConnectionManager())
requestExecutor := executor.NewExecutor(requestManager, incomingBlockHooks, asyncLoader.AsyncLoad)
responseAssembler := responseassembler.New(ctx, peerManager)
peerTaskQueue := peertaskqueue.New()
Expand Down
8 changes: 8 additions & 0 deletions network/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ type GraphSyncNetwork interface {
ConnectTo(context.Context, peer.ID) error

NewMessageSender(context.Context, peer.ID) (MessageSender, error)

ConnectionManager() ConnManager
}

// ConnManager provides the methods needed to protect and unprotect connections
type ConnManager interface {
Protect(peer.ID, string)
Unprotect(peer.ID, string) bool
}

// MessageSender is an interface to send messages to a peer
Expand Down
4 changes: 4 additions & 0 deletions network/libp2p_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ func (gsnet *libp2pGraphSyncNetwork) handleNewStream(s network.Stream) {
}
}

func (gsnet *libp2pGraphSyncNetwork) ConnectionManager() ConnManager {
return gsnet.host.ConnManager()
}

type libp2pGraphSyncNotifee libp2pGraphSyncNetwork

func (nn *libp2pGraphSyncNotifee) libp2pGraphSyncNetwork() *libp2pGraphSyncNetwork {
Expand Down
4 changes: 4 additions & 0 deletions requestmanager/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
gsmsg "github.com/ipfs/go-graphsync/message"
"github.com/ipfs/go-graphsync/messagequeue"
"github.com/ipfs/go-graphsync/metadata"
"github.com/ipfs/go-graphsync/network"
"github.com/ipfs/go-graphsync/notifications"
"github.com/ipfs/go-graphsync/requestmanager/executor"
"github.com/ipfs/go-graphsync/requestmanager/hooks"
Expand Down Expand Up @@ -94,6 +95,7 @@ type RequestManager struct {
asyncLoader AsyncLoader
disconnectNotif *pubsub.PubSub
linkSystem ipld.LinkSystem
connManager network.ConnManager

// dont touch out side of run loop
nextRequestID graphsync.RequestID
Expand Down Expand Up @@ -126,6 +128,7 @@ func New(ctx context.Context,
responseHooks ResponseHooks,
networkErrorListeners *listeners.NetworkErrorListeners,
requestQueue taskqueue.TaskQueue,
connManager network.ConnManager,
) *RequestManager {
ctx, cancel := context.WithCancel(ctx)
return &RequestManager{
Expand All @@ -141,6 +144,7 @@ func New(ctx context.Context,
responseHooks: responseHooks,
networkErrorListeners: networkErrorListeners,
requestQueue: requestQueue,
connManager: connManager,
}
}

Expand Down
193 changes: 130 additions & 63 deletions requestmanager/requestmanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"sort"
"sync"
"testing"
"time"

Expand All @@ -29,68 +30,6 @@ import (
"github.com/ipfs/go-graphsync/testutil"
)

type requestRecord struct {
gsr gsmsg.GraphSyncRequest
p peer.ID
}

type fakePeerHandler struct {
requestRecordChan chan requestRecord
}

func (fph *fakePeerHandler) AllocateAndBuildMessage(p peer.ID, blkSize uint64,
requestBuilder func(b *gsmsg.Builder), notifees []notifications.Notifee) {
builder := gsmsg.NewBuilder(gsmsg.Topic(0))
requestBuilder(builder)
message, err := builder.Build()
if err != nil {
panic(err)
}
fph.requestRecordChan <- requestRecord{
gsr: message.Requests()[0],
p: p,
}
}

func readNNetworkRequests(ctx context.Context,
t *testing.T,
requestRecordChan <-chan requestRecord,
count int) []requestRecord {
requestRecords := make([]requestRecord, 0, count)
for i := 0; i < count; i++ {
var rr requestRecord
testutil.AssertReceive(ctx, t, requestRecordChan, &rr, fmt.Sprintf("did not receive request %d", i))
requestRecords = append(requestRecords, rr)
}
// because of the simultaneous request queues it's possible for the requests to go to the network layer out of order
// if the requests are queued at a near identical time
sort.Slice(requestRecords, func(i, j int) bool {
return requestRecords[i].gsr.ID() < requestRecords[j].gsr.ID()
})
return requestRecords
}

func metadataForBlocks(blks []blocks.Block, present bool) metadata.Metadata {
md := make(metadata.Metadata, 0, len(blks))
for _, block := range blks {
md = append(md, metadata.Item{
Link: block.Cid(),
BlockPresent: present,
})
}
return md
}

func encodedMetadataForBlocks(t *testing.T, blks []blocks.Block, present bool) graphsync.ExtensionData {
md := metadataForBlocks(blks, present)
metadataEncoded, err := metadata.EncodeMetadata(md)
require.NoError(t, err, "did not encode metadata")
return graphsync.ExtensionData{
Name: graphsync.ExtensionMetadata,
Data: metadataEncoded,
}
}

func TestNormalSimultaneousFetch(t *testing.T) {
ctx := context.Background()
td := newTestData(ctx, t)
Expand All @@ -106,6 +45,9 @@ func TestNormalSimultaneousFetch(t *testing.T) {

requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 2)

require.True(t, td.fcm.IsProtected(peers[0]))
require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String())
require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[1].gsr.ID().String())
require.Equal(t, peers[0], requestRecords[0].p)
require.Equal(t, peers[0], requestRecords[1].p)
require.False(t, requestRecords[0].gsr.IsCancel())
Expand Down Expand Up @@ -148,6 +90,10 @@ func TestNormalSimultaneousFetch(t *testing.T) {
td.blockChain.VerifyWholeChain(requestCtx, returnedResponseChan1)
blockChain2.VerifyResponseRange(requestCtx, returnedResponseChan2, 0, 3)

require.True(t, td.fcm.IsProtected(peers[0]))
require.NotContains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String())
require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[1].gsr.ID().String())

moreBlocks := blockChain2.RemainderBlocks(3)
moreMetadata := metadataForBlocks(moreBlocks, true)
moreMetadataEncoded, err := metadata.EncodeMetadata(moreMetadata)
Expand All @@ -170,6 +116,8 @@ func TestNormalSimultaneousFetch(t *testing.T) {
blockChain2.VerifyRemainder(requestCtx, returnedResponseChan2, 3)
testutil.VerifyEmptyErrors(requestCtx, t, returnedErrorChan1)
testutil.VerifyEmptyErrors(requestCtx, t, returnedErrorChan2)

require.False(t, td.fcm.IsProtected(peers[0]))
}

func TestCancelRequestInProgress(t *testing.T) {
Expand All @@ -187,6 +135,10 @@ func TestCancelRequestInProgress(t *testing.T) {

requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 2)

require.True(t, td.fcm.IsProtected(peers[0]))
require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String())
require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[1].gsr.ID().String())

firstBlocks := td.blockChain.Blocks(0, 3)
firstMetadata := encodedMetadataForBlocks(t, firstBlocks, true)
firstResponses := []gsmsg.GraphSyncResponse{
Expand Down Expand Up @@ -224,6 +176,8 @@ func TestCancelRequestInProgress(t *testing.T) {
require.Len(t, errors, 1)
_, ok := errors[0].(graphsync.RequestClientCancelledErr)
require.True(t, ok)

require.False(t, td.fcm.IsProtected(peers[0]))
}
func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) {
ctx := context.Background()
Expand All @@ -246,6 +200,9 @@ func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) {

requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)

require.True(t, td.fcm.IsProtected(peers[0]))
require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String())

go func() {
firstBlocks := td.blockChain.Blocks(0, 3)
firstMetadata := encodedMetadataForBlocks(t, firstBlocks, true)
Expand All @@ -267,6 +224,8 @@ func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) {
require.True(t, rr.gsr.IsCancel())
require.Equal(t, requestRecords[0].gsr.ID(), rr.gsr.ID())

require.False(t, td.fcm.IsProtected(peers[0]))

errors := testutil.CollectErrors(requestCtx, t, returnedErrorChan1)
require.Len(t, errors, 1)
_, ok := errors[0].(graphsync.RequestClientCancelledErr)
Expand Down Expand Up @@ -321,13 +280,17 @@ func TestFailedRequest(t *testing.T) {
returnedResponseChan, returnedErrorChan := td.requestManager.NewRequest(requestCtx, peers[0], td.blockChain.TipLink, td.blockChain.Selector())

rr := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0]
require.True(t, td.fcm.IsProtected(peers[0]))
require.Contains(t, td.fcm.Protections(peers[0]), rr.gsr.ID().String())

failedResponses := []gsmsg.GraphSyncResponse{
gsmsg.NewResponse(rr.gsr.ID(), graphsync.RequestFailedContentNotFound),
}
td.requestManager.ProcessResponses(peers[0], failedResponses, nil)

testutil.VerifySingleTerminalError(requestCtx, t, returnedErrorChan)
testutil.VerifyEmptyResponse(requestCtx, t, returnedResponseChan)
require.False(t, td.fcm.IsProtected(peers[0]))
}

func TestLocallyFulfilledFirstRequestFailsLater(t *testing.T) {
Expand Down Expand Up @@ -962,10 +925,113 @@ func TestPauseResumeExternal(t *testing.T) {
testutil.VerifyEmptyErrors(ctx, t, returnedErrorChan)
}

type requestRecord struct {
gsr gsmsg.GraphSyncRequest
p peer.ID
}

type fakePeerHandler struct {
requestRecordChan chan requestRecord
}

func (fph *fakePeerHandler) AllocateAndBuildMessage(p peer.ID, blkSize uint64,
requestBuilder func(b *gsmsg.Builder), notifees []notifications.Notifee) {
builder := gsmsg.NewBuilder(gsmsg.Topic(0))
requestBuilder(builder)
message, err := builder.Build()
if err != nil {
panic(err)
}
fph.requestRecordChan <- requestRecord{
gsr: message.Requests()[0],
p: p,
}
}

func readNNetworkRequests(ctx context.Context,
t *testing.T,
requestRecordChan <-chan requestRecord,
count int) []requestRecord {
requestRecords := make([]requestRecord, 0, count)
for i := 0; i < count; i++ {
var rr requestRecord
testutil.AssertReceive(ctx, t, requestRecordChan, &rr, fmt.Sprintf("did not receive request %d", i))
requestRecords = append(requestRecords, rr)
}
// because of the simultaneous request queues it's possible for the requests to go to the network layer out of order
// if the requests are queued at a near identical time
sort.Slice(requestRecords, func(i, j int) bool {
return requestRecords[i].gsr.ID() < requestRecords[j].gsr.ID()
})
return requestRecords
}

func metadataForBlocks(blks []blocks.Block, present bool) metadata.Metadata {
md := make(metadata.Metadata, 0, len(blks))
for _, block := range blks {
md = append(md, metadata.Item{
Link: block.Cid(),
BlockPresent: present,
})
}
return md
}

func encodedMetadataForBlocks(t *testing.T, blks []blocks.Block, present bool) graphsync.ExtensionData {
md := metadataForBlocks(blks, present)
metadataEncoded, err := metadata.EncodeMetadata(md)
require.NoError(t, err, "did not encode metadata")
return graphsync.ExtensionData{
Name: graphsync.ExtensionMetadata,
Data: metadataEncoded,
}
}

type fakeConnManager struct {
protectedConnsLk sync.RWMutex
protectedConns map[peer.ID][]string
}

func (fcm *fakeConnManager) Protect(p peer.ID, tag string) {
fcm.protectedConnsLk.Lock()
defer fcm.protectedConnsLk.Unlock()
for _, tagCmp := range fcm.protectedConns[p] {
if tag == tagCmp {
return
}
}
fcm.protectedConns[p] = append(fcm.protectedConns[p], tag)
}

func (fcm *fakeConnManager) Unprotect(p peer.ID, tag string) bool {
fcm.protectedConnsLk.Lock()
defer fcm.protectedConnsLk.Unlock()
for i, tagCmp := range fcm.protectedConns[p] {
if tag == tagCmp {
fcm.protectedConns[p] = append(fcm.protectedConns[p][:i], fcm.protectedConns[p][i+1:]...)
break
}
}
return len(fcm.protectedConns[p]) > 0
}

func (fcm *fakeConnManager) IsProtected(p peer.ID) bool {
fcm.protectedConnsLk.RLock()
defer fcm.protectedConnsLk.RUnlock()
return len(fcm.protectedConns[p]) > 0
}

func (fcm *fakeConnManager) Protections(p peer.ID) []string {
fcm.protectedConnsLk.RLock()
defer fcm.protectedConnsLk.RUnlock()
return fcm.protectedConns[p]
}

type testData struct {
requestRecordChan chan requestRecord
fph *fakePeerHandler
fal *testloader.FakeAsyncLoader
fcm *fakeConnManager
requestHooks *hooks.OutgoingRequestHooks
responseHooks *hooks.IncomingResponseHooks
blockHooks *hooks.IncomingBlockHooks
Expand All @@ -989,13 +1055,14 @@ func newTestData(ctx context.Context, t *testing.T) *testData {
td.requestRecordChan = make(chan requestRecord, 3)
td.fph = &fakePeerHandler{td.requestRecordChan}
td.fal = testloader.NewFakeAsyncLoader()
td.fcm = &fakeConnManager{protectedConns: make(map[peer.ID][]string)}
td.requestHooks = hooks.NewRequestHooks()
td.responseHooks = hooks.NewResponseHooks()
td.blockHooks = hooks.NewBlockHooks()
td.networkErrorListeners = listeners.NewNetworkErrorListeners()
td.taskqueue = taskqueue.NewTaskQueue(ctx)
lsys := cidlink.DefaultLinkSystem()
td.requestManager = New(ctx, td.fal, lsys, td.requestHooks, td.responseHooks, td.networkErrorListeners, td.taskqueue)
td.requestManager = New(ctx, td.fal, lsys, td.requestHooks, td.responseHooks, td.networkErrorListeners, td.taskqueue, td.fcm)
td.executor = executor.NewExecutor(td.requestManager, td.blockHooks, td.fal.AsyncLoad)
td.requestManager.SetDelegate(td.fph)
td.requestManager.Startup()
Expand Down
Loading

0 comments on commit cae08eb

Please sign in to comment.