Skip to content

Commit

Permalink
Protect Libp2p Connections (#229)
Browse files Browse the repository at this point in the history
* feat(requestmanager): add connection protection

* refactor(testutil): extract TestConnManager

* feat(responsemanager): add connection holding

also uncovered a bug in early cancellations, resolved by using state pattern from requestmanager

* refactor(graphsync): change string to unique tag

make tag for request IDs unique to graphsync
  • Loading branch information
hannahhoward authored Sep 29, 2021
1 parent 5c5f1e8 commit ce3951d
Show file tree
Hide file tree
Showing 14 changed files with 246 additions and 77 deletions.
5 changes: 5 additions & 0 deletions benchmarks/testnet/virtual.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

delay "github.com/ipfs/go-ipfs-delay"
mockrouting "github.com/ipfs/go-ipfs-routing/mock"
"github.com/libp2p/go-libp2p-core/connmgr"
"github.com/libp2p/go-libp2p-core/peer"
tnet "github.com/libp2p/go-libp2p-testing/net"
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
Expand Down Expand Up @@ -255,6 +256,10 @@ func (nc *networkClient) DisconnectFrom(_ context.Context, p peer.ID) error {
return nil
}

func (nc *networkClient) ConnectionManager() gsnet.ConnManager {
return &connmgr.NullConnMgr{}
}

func (rq *receiverQueue) enqueue(m *message) {
rq.lk.Lock()
defer rq.lk.Unlock()
Expand Down
5 changes: 5 additions & 0 deletions graphsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ import (
// RequestID is a unique identifier for a GraphSync request.
type RequestID int32

// Tag returns an easy way to identify this request id as a graphsync request (for libp2p connections)
func (r RequestID) Tag() string {
return fmt.Sprintf("graphsync-request-%d", r)
}

// Priority a priority for a GraphSync request.
type Priority int32

Expand Down
4 changes: 2 additions & 2 deletions impl/graphsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,11 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork,

asyncLoader := asyncloader.New(ctx, linkSystem, requestAllocator)
requestQueue := taskqueue.NewTaskQueue(ctx)
requestManager := requestmanager.New(ctx, asyncLoader, linkSystem, outgoingRequestHooks, incomingResponseHooks, networkErrorListeners, requestQueue)
requestManager := requestmanager.New(ctx, asyncLoader, linkSystem, outgoingRequestHooks, incomingResponseHooks, networkErrorListeners, requestQueue, network.ConnectionManager())
requestExecutor := executor.NewExecutor(requestManager, incomingBlockHooks, asyncLoader.AsyncLoad)
responseAssembler := responseassembler.New(ctx, peerManager)
peerTaskQueue := peertaskqueue.New()
responseManager := responsemanager.New(ctx, linkSystem, responseAssembler, peerTaskQueue, requestQueuedHooks, incomingRequestHooks, outgoingBlockHooks, requestUpdatedHooks, completedResponseListeners, requestorCancelledListeners, blockSentListeners, networkErrorListeners, gsConfig.maxInProgressIncomingRequests)
responseManager := responsemanager.New(ctx, linkSystem, responseAssembler, peerTaskQueue, requestQueuedHooks, incomingRequestHooks, outgoingBlockHooks, requestUpdatedHooks, completedResponseListeners, requestorCancelledListeners, blockSentListeners, networkErrorListeners, gsConfig.maxInProgressIncomingRequests, network.ConnectionManager())
graphSync := &GraphSync{
network: network,
linkSystem: linkSystem,
Expand Down
8 changes: 8 additions & 0 deletions network/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ type GraphSyncNetwork interface {
ConnectTo(context.Context, peer.ID) error

NewMessageSender(context.Context, peer.ID) (MessageSender, error)

ConnectionManager() ConnManager
}

// ConnManager provides the methods needed to protect and unprotect connections
type ConnManager interface {
Protect(peer.ID, string)
Unprotect(peer.ID, string) bool
}

// MessageSender is an interface to send messages to a peer
Expand Down
4 changes: 4 additions & 0 deletions network/libp2p_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ func (gsnet *libp2pGraphSyncNetwork) handleNewStream(s network.Stream) {
}
}

func (gsnet *libp2pGraphSyncNetwork) ConnectionManager() ConnManager {
return gsnet.host.ConnManager()
}

type libp2pGraphSyncNotifee libp2pGraphSyncNetwork

func (nn *libp2pGraphSyncNotifee) libp2pGraphSyncNetwork() *libp2pGraphSyncNetwork {
Expand Down
4 changes: 4 additions & 0 deletions requestmanager/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
gsmsg "github.com/ipfs/go-graphsync/message"
"github.com/ipfs/go-graphsync/messagequeue"
"github.com/ipfs/go-graphsync/metadata"
"github.com/ipfs/go-graphsync/network"
"github.com/ipfs/go-graphsync/notifications"
"github.com/ipfs/go-graphsync/requestmanager/executor"
"github.com/ipfs/go-graphsync/requestmanager/hooks"
Expand Down Expand Up @@ -94,6 +95,7 @@ type RequestManager struct {
asyncLoader AsyncLoader
disconnectNotif *pubsub.PubSub
linkSystem ipld.LinkSystem
connManager network.ConnManager

// dont touch out side of run loop
nextRequestID graphsync.RequestID
Expand Down Expand Up @@ -126,6 +128,7 @@ func New(ctx context.Context,
responseHooks ResponseHooks,
networkErrorListeners *listeners.NetworkErrorListeners,
requestQueue taskqueue.TaskQueue,
connManager network.ConnManager,
) *RequestManager {
ctx, cancel := context.WithCancel(ctx)
return &RequestManager{
Expand All @@ -141,6 +144,7 @@ func New(ctx context.Context,
responseHooks: responseHooks,
networkErrorListeners: networkErrorListeners,
requestQueue: requestQueue,
connManager: connManager,
}
}

Expand Down
150 changes: 87 additions & 63 deletions requestmanager/requestmanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,68 +29,6 @@ import (
"github.com/ipfs/go-graphsync/testutil"
)

type requestRecord struct {
gsr gsmsg.GraphSyncRequest
p peer.ID
}

type fakePeerHandler struct {
requestRecordChan chan requestRecord
}

func (fph *fakePeerHandler) AllocateAndBuildMessage(p peer.ID, blkSize uint64,
requestBuilder func(b *gsmsg.Builder), notifees []notifications.Notifee) {
builder := gsmsg.NewBuilder(gsmsg.Topic(0))
requestBuilder(builder)
message, err := builder.Build()
if err != nil {
panic(err)
}
fph.requestRecordChan <- requestRecord{
gsr: message.Requests()[0],
p: p,
}
}

func readNNetworkRequests(ctx context.Context,
t *testing.T,
requestRecordChan <-chan requestRecord,
count int) []requestRecord {
requestRecords := make([]requestRecord, 0, count)
for i := 0; i < count; i++ {
var rr requestRecord
testutil.AssertReceive(ctx, t, requestRecordChan, &rr, fmt.Sprintf("did not receive request %d", i))
requestRecords = append(requestRecords, rr)
}
// because of the simultaneous request queues it's possible for the requests to go to the network layer out of order
// if the requests are queued at a near identical time
sort.Slice(requestRecords, func(i, j int) bool {
return requestRecords[i].gsr.ID() < requestRecords[j].gsr.ID()
})
return requestRecords
}

func metadataForBlocks(blks []blocks.Block, present bool) metadata.Metadata {
md := make(metadata.Metadata, 0, len(blks))
for _, block := range blks {
md = append(md, metadata.Item{
Link: block.Cid(),
BlockPresent: present,
})
}
return md
}

func encodedMetadataForBlocks(t *testing.T, blks []blocks.Block, present bool) graphsync.ExtensionData {
md := metadataForBlocks(blks, present)
metadataEncoded, err := metadata.EncodeMetadata(md)
require.NoError(t, err, "did not encode metadata")
return graphsync.ExtensionData{
Name: graphsync.ExtensionMetadata,
Data: metadataEncoded,
}
}

func TestNormalSimultaneousFetch(t *testing.T) {
ctx := context.Background()
td := newTestData(ctx, t)
Expand All @@ -106,6 +44,8 @@ func TestNormalSimultaneousFetch(t *testing.T) {

requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 2)

td.tcm.AssertProtected(t, peers[0])
td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().Tag(), requestRecords[1].gsr.ID().Tag())
require.Equal(t, peers[0], requestRecords[0].p)
require.Equal(t, peers[0], requestRecords[1].p)
require.False(t, requestRecords[0].gsr.IsCancel())
Expand Down Expand Up @@ -148,6 +88,10 @@ func TestNormalSimultaneousFetch(t *testing.T) {
td.blockChain.VerifyWholeChain(requestCtx, returnedResponseChan1)
blockChain2.VerifyResponseRange(requestCtx, returnedResponseChan2, 0, 3)

td.tcm.AssertProtected(t, peers[0])
td.tcm.RefuteProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().Tag())
td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[1].gsr.ID().Tag())

moreBlocks := blockChain2.RemainderBlocks(3)
moreMetadata := metadataForBlocks(moreBlocks, true)
moreMetadataEncoded, err := metadata.EncodeMetadata(moreMetadata)
Expand All @@ -170,6 +114,8 @@ func TestNormalSimultaneousFetch(t *testing.T) {
blockChain2.VerifyRemainder(requestCtx, returnedResponseChan2, 3)
testutil.VerifyEmptyErrors(requestCtx, t, returnedErrorChan1)
testutil.VerifyEmptyErrors(requestCtx, t, returnedErrorChan2)

td.tcm.RefuteProtected(t, peers[0])
}

func TestCancelRequestInProgress(t *testing.T) {
Expand All @@ -187,6 +133,9 @@ func TestCancelRequestInProgress(t *testing.T) {

requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 2)

td.tcm.AssertProtected(t, peers[0])
td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().Tag(), requestRecords[1].gsr.ID().Tag())

firstBlocks := td.blockChain.Blocks(0, 3)
firstMetadata := encodedMetadataForBlocks(t, firstBlocks, true)
firstResponses := []gsmsg.GraphSyncResponse{
Expand Down Expand Up @@ -224,6 +173,8 @@ func TestCancelRequestInProgress(t *testing.T) {
require.Len(t, errors, 1)
_, ok := errors[0].(graphsync.RequestClientCancelledErr)
require.True(t, ok)

td.tcm.RefuteProtected(t, peers[0])
}
func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) {
ctx := context.Background()
Expand All @@ -246,6 +197,9 @@ func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) {

requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)

td.tcm.AssertProtected(t, peers[0])
td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().Tag())

go func() {
firstBlocks := td.blockChain.Blocks(0, 3)
firstMetadata := encodedMetadataForBlocks(t, firstBlocks, true)
Expand All @@ -267,6 +221,8 @@ func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) {
require.True(t, rr.gsr.IsCancel())
require.Equal(t, requestRecords[0].gsr.ID(), rr.gsr.ID())

td.tcm.RefuteProtected(t, peers[0])

errors := testutil.CollectErrors(requestCtx, t, returnedErrorChan1)
require.Len(t, errors, 1)
_, ok := errors[0].(graphsync.RequestClientCancelledErr)
Expand Down Expand Up @@ -321,13 +277,17 @@ func TestFailedRequest(t *testing.T) {
returnedResponseChan, returnedErrorChan := td.requestManager.NewRequest(requestCtx, peers[0], td.blockChain.TipLink, td.blockChain.Selector())

rr := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0]
td.tcm.AssertProtected(t, peers[0])
td.tcm.AssertProtectedWithTags(t, peers[0], rr.gsr.ID().Tag())

failedResponses := []gsmsg.GraphSyncResponse{
gsmsg.NewResponse(rr.gsr.ID(), graphsync.RequestFailedContentNotFound),
}
td.requestManager.ProcessResponses(peers[0], failedResponses, nil)

testutil.VerifySingleTerminalError(requestCtx, t, returnedErrorChan)
testutil.VerifyEmptyResponse(requestCtx, t, returnedResponseChan)
td.tcm.RefuteProtected(t, peers[0])
}

func TestLocallyFulfilledFirstRequestFailsLater(t *testing.T) {
Expand Down Expand Up @@ -962,10 +922,73 @@ func TestPauseResumeExternal(t *testing.T) {
testutil.VerifyEmptyErrors(ctx, t, returnedErrorChan)
}

type requestRecord struct {
gsr gsmsg.GraphSyncRequest
p peer.ID
}

type fakePeerHandler struct {
requestRecordChan chan requestRecord
}

func (fph *fakePeerHandler) AllocateAndBuildMessage(p peer.ID, blkSize uint64,
requestBuilder func(b *gsmsg.Builder), notifees []notifications.Notifee) {
builder := gsmsg.NewBuilder(gsmsg.Topic(0))
requestBuilder(builder)
message, err := builder.Build()
if err != nil {
panic(err)
}
fph.requestRecordChan <- requestRecord{
gsr: message.Requests()[0],
p: p,
}
}

func readNNetworkRequests(ctx context.Context,
t *testing.T,
requestRecordChan <-chan requestRecord,
count int) []requestRecord {
requestRecords := make([]requestRecord, 0, count)
for i := 0; i < count; i++ {
var rr requestRecord
testutil.AssertReceive(ctx, t, requestRecordChan, &rr, fmt.Sprintf("did not receive request %d", i))
requestRecords = append(requestRecords, rr)
}
// because of the simultaneous request queues it's possible for the requests to go to the network layer out of order
// if the requests are queued at a near identical time
sort.Slice(requestRecords, func(i, j int) bool {
return requestRecords[i].gsr.ID() < requestRecords[j].gsr.ID()
})
return requestRecords
}

func metadataForBlocks(blks []blocks.Block, present bool) metadata.Metadata {
md := make(metadata.Metadata, 0, len(blks))
for _, block := range blks {
md = append(md, metadata.Item{
Link: block.Cid(),
BlockPresent: present,
})
}
return md
}

func encodedMetadataForBlocks(t *testing.T, blks []blocks.Block, present bool) graphsync.ExtensionData {
md := metadataForBlocks(blks, present)
metadataEncoded, err := metadata.EncodeMetadata(md)
require.NoError(t, err, "did not encode metadata")
return graphsync.ExtensionData{
Name: graphsync.ExtensionMetadata,
Data: metadataEncoded,
}
}

type testData struct {
requestRecordChan chan requestRecord
fph *fakePeerHandler
fal *testloader.FakeAsyncLoader
tcm *testutil.TestConnManager
requestHooks *hooks.OutgoingRequestHooks
responseHooks *hooks.IncomingResponseHooks
blockHooks *hooks.IncomingBlockHooks
Expand All @@ -989,13 +1012,14 @@ func newTestData(ctx context.Context, t *testing.T) *testData {
td.requestRecordChan = make(chan requestRecord, 3)
td.fph = &fakePeerHandler{td.requestRecordChan}
td.fal = testloader.NewFakeAsyncLoader()
td.tcm = testutil.NewTestConnManager()
td.requestHooks = hooks.NewRequestHooks()
td.responseHooks = hooks.NewResponseHooks()
td.blockHooks = hooks.NewBlockHooks()
td.networkErrorListeners = listeners.NewNetworkErrorListeners()
td.taskqueue = taskqueue.NewTaskQueue(ctx)
lsys := cidlink.DefaultLinkSystem()
td.requestManager = New(ctx, td.fal, lsys, td.requestHooks, td.responseHooks, td.networkErrorListeners, td.taskqueue)
td.requestManager = New(ctx, td.fal, lsys, td.requestHooks, td.responseHooks, td.networkErrorListeners, td.taskqueue, td.tcm)
td.executor = executor.NewExecutor(td.requestManager, td.blockHooks, td.fal.AsyncLoad)
td.requestManager.SetDelegate(td.fph)
td.requestManager.Startup()
Expand Down
2 changes: 2 additions & 0 deletions requestmanager/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ func (rm *RequestManager) newRequest(p peer.ID, root ipld.Link, selector ipld.No
requestStatus.lastResponse.Store(gsmsg.NewResponse(request.ID(), graphsync.RequestAcknowledged))
rm.inProgressRequestStatuses[request.ID()] = requestStatus

rm.connManager.Protect(p, requestID.Tag())
rm.requestQueue.PushTask(p, peertask.Task{Topic: requestID, Priority: math.MaxInt32, Work: 1})
return request, requestStatus.inProgressChan, requestStatus.inProgressErr
}
Expand Down Expand Up @@ -151,6 +152,7 @@ func (rm *RequestManager) terminateRequest(requestID graphsync.RequestID, ipr *i
case <-rm.ctx.Done():
}
}
rm.connManager.Unprotect(ipr.p, requestID.Tag())
delete(rm.inProgressRequestStatuses, requestID)
ipr.cancelFn()
rm.asyncLoader.CleanupRequest(requestID)
Expand Down
Loading

0 comments on commit ce3951d

Please sign in to comment.