diff --git a/shared_testutil/test_network_types.go b/shared_testutil/test_network_types.go index 02eaf46a..2ba244ea 100644 --- a/shared_testutil/test_network_types.go +++ b/shared_testutil/test_network_types.go @@ -406,7 +406,6 @@ type TestStorageDealStream struct { proposalWriter StorageDealProposalWriter responseReader StorageDealResponseReader responseWriter StorageDealResponseWriter - tags map[string]struct{} CloseCount int CloseError error @@ -431,7 +430,6 @@ func NewTestStorageDealStream(params TestStorageDealStreamParams) *TestStorageDe proposalWriter: TrivialStorageDealProposalWriter, responseReader: TrivialStorageDealResponseReader, responseWriter: TrivialStorageDealResponseWriter, - tags: make(map[string]struct{}), } if params.ProposalReader != nil { stream.proposalReader = params.ProposalReader @@ -477,23 +475,6 @@ func (tsds *TestStorageDealStream) Close() error { return tsds.CloseError } -// TagProtectedConnection preserves this connection as higher priority than others -func (tsds TestStorageDealStream) TagProtectedConnection(identifier string) { - tsds.tags[identifier] = struct{}{} -} - -// UntagProtectedConnection removes the given tag on this connection, increasing -// the likelyhood it will be cleaned up -func (tsds TestStorageDealStream) UntagProtectedConnection(identifier string) { - delete(tsds.tags, identifier) -} - -// AssertConnectionTagged verifies a connection was tagged with the given identifier -func (tsds TestStorageDealStream) AssertConnectionTagged(t *testing.T, identifier string) { - _, ok := tsds.tags[identifier] - require.True(t, ok) -} - // TrivialStorageDealProposalReader succeeds trivially, returning an empty proposal. func TrivialStorageDealProposalReader() (smnet.Proposal, error) { return smnet.Proposal{}, nil @@ -559,3 +540,22 @@ func (tpr TestPeerResolver) GetPeers(cid.Cid) ([]rm.RetrievalPeer, error) { } var _ rm.PeerResolver = &TestPeerResolver{} + +type TestPeerTagger struct { + TagCalls []peer.ID + UntagCalls []peer.ID +} + +func NewTestPeerTagger() *TestPeerTagger { + return &TestPeerTagger{} +} + +func (pt *TestPeerTagger) TagPeer(id peer.ID, _ string) { + pt.TagCalls = append(pt.TagCalls, id) +} + +func (pt *TestPeerTagger) UntagPeer(id peer.ID, _ string) { + pt.UntagCalls = append(pt.UntagCalls, id) +} + +var _ smnet.PeerTagger = &TestPeerTagger{} diff --git a/storagemarket/impl/client.go b/storagemarket/impl/client.go index a5306e35..0bf624cd 100644 --- a/storagemarket/impl/client.go +++ b/storagemarket/impl/client.go @@ -610,6 +610,14 @@ func (csg *clientStoreGetter) Get(proposalCid cid.Cid) (*multistore.Store, error return csg.c.multiStore.Get(*deal.StoreID) } +func (c *clientDealEnvironment) TagPeer(peer peer.ID, tag string) { + c.c.net.TagPeer(peer, tag) +} + +func (c *clientDealEnvironment) UntagPeer(peer peer.ID, tag string) { + c.c.net.UntagPeer(peer, tag) +} + // ClientFSMParameterSpec is a valid set of parameters for a client deal FSM - used in doc generation var ClientFSMParameterSpec = fsm.Parameters{ Environment: &clientDealEnvironment{}, diff --git a/storagemarket/impl/clientstates/client_states.go b/storagemarket/impl/clientstates/client_states.go index da070124..02f34288 100644 --- a/storagemarket/impl/clientstates/client_states.go +++ b/storagemarket/impl/clientstates/client_states.go @@ -34,6 +34,7 @@ type ClientDealEnvironment interface { GetProviderDealState(ctx context.Context, proposalCid cid.Cid) (*storagemarket.ProviderDealState, error) PollingInterval() time.Duration DealFunds() funds.DealFunds + network.PeerTagger } // ClientStateEntryFunc is the type for all state entry functions on a storage client @@ -103,6 +104,8 @@ func ProposeDeal(ctx fsm.Context, environment ClientDealEnvironment, deal storag return ctx.Trigger(storagemarket.ClientEventWriteProposalFailed, err) } + environment.TagPeer(deal.Miner, deal.ProposalCid.String()) + if err := s.WriteDealProposal(proposal); err != nil { return ctx.Trigger(storagemarket.ClientEventWriteProposalFailed, err) } @@ -160,6 +163,7 @@ func InitiateDataTransfer(ctx fsm.Context, environment ClientDealEnvironment, de // CheckForDealAcceptance is run until the deal is sealed and published by the provider, or errors func CheckForDealAcceptance(ctx fsm.Context, environment ClientDealEnvironment, deal storagemarket.ClientDeal) error { + dealState, err := environment.GetProviderDealState(ctx.Context(), deal.ProposalCid) if err != nil { log.Warnf("error when querying provider deal state: %w", err) // TODO: at what point do we fail the deal? @@ -214,6 +218,9 @@ func ValidateDealPublished(ctx fsm.Context, environment ClientDealEnvironment, d _ = ctx.Trigger(storagemarket.ClientEventFundsReleased, deal.FundsReserved) } + // at this point data transfer is complete, so unprotect peer connection + environment.UntagPeer(deal.Miner, deal.ProposalCid.String()) + return ctx.Trigger(storagemarket.ClientEventDealPublished, dealID) } @@ -277,6 +284,8 @@ func FailDeal(ctx fsm.Context, environment ClientDealEnvironment, deal storagema // TODO: store in some sort of audit log log.Errorf("deal %s failed: %s", deal.ProposalCid, deal.Message) + environment.UntagPeer(deal.Miner, deal.ProposalCid.String()) + return ctx.Trigger(storagemarket.ClientEventFailed) } diff --git a/storagemarket/impl/clientstates/client_states_test.go b/storagemarket/impl/clientstates/client_states_test.go index 6fd035a9..4c69f67d 100644 --- a/storagemarket/impl/clientstates/client_states_test.go +++ b/storagemarket/impl/clientstates/client_states_test.go @@ -102,7 +102,7 @@ func TestWaitForFunding(t *testing.T) { } func TestProposeDeal(t *testing.T) { - t.Run("succeeds and closes stream", func(t *testing.T) { + t.Run("succeeds, closes stream, and tags connection", func(t *testing.T) { ds := tut.NewTestStorageDealStream(tut.TestStorageDealStreamParams{ ResponseReader: testResponseReader(t, responseParams{ state: storagemarket.StorageDealWaitingForData, @@ -115,6 +115,8 @@ func TestProposeDeal(t *testing.T) { inspector: func(deal storagemarket.ClientDeal, env *fakeEnvironment) { tut.AssertDealState(t, storagemarket.StorageDealStartDataTransfer, deal.State) assert.Equal(t, 1, env.dealStream.CloseCount) + assert.Len(t, env.peerTagger.TagCalls, 1) + assert.Equal(t, deal.Miner, env.peerTagger.TagCalls[0]) }, }) }) @@ -141,7 +143,6 @@ func TestProposeDeal(t *testing.T) { }, }) }) - t.Run("write proposal fails fails", func(t *testing.T) { ds := tut.NewTestStorageDealStream(tut.TestStorageDealStreamParams{ ProposalWriter: tut.FailStorageProposalWriter, @@ -369,6 +370,8 @@ func TestValidateDealPublished(t *testing.T) { assert.Equal(t, abi.DealID(5), deal.DealID) assert.Equal(t, env.dealFunds.ReleaseCalls[0], deal.Proposal.ClientBalanceRequirement()) assert.True(t, deal.FundsReserved.Nil() || deal.FundsReserved.IsZero()) + assert.Len(t, env.peerTagger.UntagCalls, 1) + assert.Equal(t, deal.Miner, env.peerTagger.UntagCalls[0]) }, }) }) @@ -379,6 +382,8 @@ func TestValidateDealPublished(t *testing.T) { tut.AssertDealState(t, storagemarket.StorageDealSealing, deal.State) assert.Equal(t, abi.DealID(5), deal.DealID) assert.Len(t, env.dealFunds.ReleaseCalls, 0) + assert.Len(t, env.peerTagger.UntagCalls, 1) + assert.Equal(t, deal.Miner, env.peerTagger.UntagCalls[0]) }, }) }) @@ -546,6 +551,7 @@ func makeExecutor(ctx context.Context, getDealStatusErr: envParams.getDealStatusErr, pollingInterval: envParams.pollingInterval, dealFunds: tut.NewTestDealFunds(), + peerTagger: tut.NewTestPeerTagger(), } if environment.pollingInterval == 0 { @@ -617,6 +623,7 @@ type fakeEnvironment struct { getDealStatusErr error pollingInterval time.Duration dealFunds *tut.TestDealFunds + peerTagger *tut.TestPeerTagger } type dataTransferParams struct { @@ -663,6 +670,14 @@ func (fe *fakeEnvironment) DealFunds() funds.DealFunds { return fe.dealFunds } +func (fe *fakeEnvironment) TagPeer(id peer.ID, ident string) { + fe.peerTagger.TagPeer(id, ident) +} + +func (fe *fakeEnvironment) UntagPeer(id peer.ID, ident string) { + fe.peerTagger.UntagPeer(id, ident) +} + var _ clientstates.ClientDealEnvironment = &fakeEnvironment{} type responseParams struct { diff --git a/storagemarket/impl/provider.go b/storagemarket/impl/provider.go index c8cfafad..592e2d08 100644 --- a/storagemarket/impl/provider.go +++ b/storagemarket/impl/provider.go @@ -9,6 +9,7 @@ import ( "github.com/ipfs/go-cid" "github.com/ipfs/go-datastore" "github.com/ipld/go-ipld-prime" + "github.com/libp2p/go-libp2p-core/peer" "golang.org/x/xerrors" "github.com/filecoin-project/go-address" @@ -668,6 +669,14 @@ func (p *providerDealEnvironment) DealFunds() funds.DealFunds { return p.p.dealFunds } +func (p *providerDealEnvironment) TagPeer(id peer.ID, s string) { + p.p.net.TagPeer(id, s) +} + +func (p *providerDealEnvironment) UntagPeer(id peer.ID, s string) { + p.p.net.UntagPeer(id, s) +} + var _ providerstates.ProviderDealEnvironment = &providerDealEnvironment{} type providerStoreGetter struct { diff --git a/storagemarket/impl/providerstates/provider_states.go b/storagemarket/impl/providerstates/provider_states.go index f2474edf..233b519e 100644 --- a/storagemarket/impl/providerstates/provider_states.go +++ b/storagemarket/impl/providerstates/provider_states.go @@ -44,6 +44,7 @@ type ProviderDealEnvironment interface { PieceStore() piecestore.PieceStore RunCustomDecisionLogic(context.Context, storagemarket.MinerDeal) (bool, string, error) DealFunds() funds.DealFunds + network.PeerTagger } // ProviderStateEntryFunc is the signature for a StateEntryFunc in the provider FSM @@ -51,6 +52,8 @@ type ProviderStateEntryFunc func(ctx fsm.Context, environment ProviderDealEnviro // ValidateDealProposal validates a proposed deal against the provider criteria func ValidateDealProposal(ctx fsm.Context, environment ProviderDealEnvironment, deal storagemarket.MinerDeal) error { + environment.TagPeer(deal.Client, deal.ProposalCid.String()) + tok, _, err := environment.Node().GetChainHead(ctx.Context()) if err != nil { return ctx.Trigger(storagemarket.ProviderEventDealRejected, xerrors.Errorf("node error getting most recent state id: %w", err)) @@ -385,6 +388,9 @@ func VerifyDealActivated(ctx fsm.Context, environment ProviderDealEnvironment, d // WaitForDealCompletion waits for the deal to be slashed or to expire func WaitForDealCompletion(ctx fsm.Context, environment ProviderDealEnvironment, deal storagemarket.MinerDeal) error { + // At this point we have all the data so we can unprotect the connection + environment.UntagPeer(deal.Client, deal.ProposalCid.String()) + node := environment.Node() // Called when the deal expires @@ -433,9 +439,10 @@ func RejectDeal(ctx fsm.Context, environment ProviderDealEnvironment, deal stora // FailDeal cleans up before terminating a deal func FailDeal(ctx fsm.Context, environment ProviderDealEnvironment, deal storagemarket.MinerDeal) error { - log.Warnf("deal %s failed: %s", deal.ProposalCid, deal.Message) + environment.UntagPeer(deal.Client, deal.ProposalCid.String()) + if deal.PiecePath != filestore.Path("") { err := environment.FileStore().Delete(deal.PiecePath) if err != nil { diff --git a/storagemarket/impl/providerstates/provider_states_test.go b/storagemarket/impl/providerstates/provider_states_test.go index a424254b..b6f11341 100644 --- a/storagemarket/impl/providerstates/provider_states_test.go +++ b/storagemarket/impl/providerstates/provider_states_test.go @@ -10,6 +10,7 @@ import ( "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime" + "github.com/libp2p/go-libp2p-core/peer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/xerrors" @@ -56,6 +57,8 @@ func TestValidateDealProposal(t *testing.T) { "succeeds": { dealInspector: func(t *testing.T, deal storagemarket.MinerDeal, env *fakeEnvironment) { tut.AssertDealState(t, storagemarket.StorageDealAcceptWait, deal.State) + require.Len(t, env.peerTagger.TagCalls, 1) + require.Equal(t, deal.Client, env.peerTagger.TagCalls[0]) }, }, "verify signature fails": { @@ -255,6 +258,7 @@ func TestVerifyData(t *testing.T) { tut.AssertDealState(t, storagemarket.StorageDealEnsureProviderFunds, deal.State) require.Equal(t, expPath, deal.PiecePath) require.Equal(t, expMetaPath, deal.MetadataPath) + }, }, "generate piece CID fails": { @@ -741,6 +745,8 @@ func TestWaitForDealCompletion(t *testing.T) { dealInspector: func(t *testing.T, deal storagemarket.MinerDeal, env *fakeEnvironment) { tut.AssertDealState(t, storagemarket.StorageDealSlashed, deal.State) require.Equal(t, abi.ChainEpoch(5), deal.SlashEpoch) + require.Len(t, env.peerTagger.UntagCalls, 1) + require.Equal(t, deal.Client, env.peerTagger.UntagCalls[0]) }, }, "expiration succeeds": { @@ -748,6 +754,8 @@ func TestWaitForDealCompletion(t *testing.T) { nodeParams: nodeParams{OnDealSlashedEpoch: abi.ChainEpoch(0)}, dealInspector: func(t *testing.T, deal storagemarket.MinerDeal, env *fakeEnvironment) { tut.AssertDealState(t, storagemarket.StorageDealExpired, deal.State) + require.Len(t, env.peerTagger.UntagCalls, 1) + require.Equal(t, deal.Client, env.peerTagger.UntagCalls[0]) }, }, "slashing fails": { @@ -1131,6 +1139,7 @@ func makeExecutor(ctx context.Context, fs: fs, pieceStore: pieceStore, dealFunds: tut.NewTestDealFunds(), + peerTagger: tut.NewTestPeerTagger(), } if environment.pieceCid == cid.Undef { environment.pieceCid = defaultPieceCid @@ -1181,6 +1190,7 @@ type fakeEnvironment struct { expectedTags map[string]struct{} receivedTags map[string]struct{} dealFunds *tut.TestDealFunds + peerTagger *tut.TestPeerTagger } func (fe *fakeEnvironment) Address() address.Address { @@ -1231,3 +1241,13 @@ func (fe *fakeEnvironment) RunCustomDecisionLogic(context.Context, storagemarket func (fe *fakeEnvironment) DealFunds() funds.DealFunds { return fe.dealFunds } + +func (fe *fakeEnvironment) TagPeer(id peer.ID, s string) { + fe.peerTagger.TagPeer(id, s) +} + +func (fe *fakeEnvironment) UntagPeer(id peer.ID, s string) { + fe.peerTagger.UntagPeer(id, s) +} + +var _ providerstates.ProviderDealEnvironment = &fakeEnvironment{} diff --git a/storagemarket/network/deal_stream.go b/storagemarket/network/deal_stream.go index ddba8fff..efabeb9b 100644 --- a/storagemarket/network/deal_stream.go +++ b/storagemarket/network/deal_stream.go @@ -56,11 +56,3 @@ func (d *dealStream) Close() error { func (d *dealStream) RemotePeer() peer.ID { return d.p } - -func (d *dealStream) TagProtectedConnection(identifier string) { - d.host.ConnManager().TagPeer(d.p, identifier, TagPriority) -} - -func (d *dealStream) UntagProtectedConnection(identifier string) { - d.host.ConnManager().UntagPeer(d.p, identifier) -} diff --git a/storagemarket/network/libp2p_impl.go b/storagemarket/network/libp2p_impl.go index 8bb5b1fa..f30fdb44 100644 --- a/storagemarket/network/libp2p_impl.go +++ b/storagemarket/network/libp2p_impl.go @@ -114,3 +114,11 @@ func (impl *libp2pStorageMarketNetwork) ID() peer.ID { func (impl *libp2pStorageMarketNetwork) AddAddrs(p peer.ID, addrs []ma.Multiaddr) { impl.host.Peerstore().AddAddrs(p, addrs, 8*time.Hour) } + +func (impl *libp2pStorageMarketNetwork) TagPeer(p peer.ID, id string) { + impl.host.ConnManager().TagPeer(p, id, TagPriority) +} + +func (impl *libp2pStorageMarketNetwork) UntagPeer(p peer.ID, id string) { + impl.host.ConnManager().UntagPeer(p, id) +} diff --git a/storagemarket/network/network.go b/storagemarket/network/network.go index 384e0680..8b8b94aa 100644 --- a/storagemarket/network/network.go +++ b/storagemarket/network/network.go @@ -28,8 +28,6 @@ type StorageDealStream interface { ReadDealResponse() (SignedResponse, error) WriteDealResponse(SignedResponse) error RemotePeer() peer.ID - TagProtectedConnection(identifier string) - UntagProtectedConnection(identifier string) Close() error } @@ -60,4 +58,12 @@ type StorageMarketNetwork interface { StopHandlingRequests() error ID() peer.ID AddAddrs(peer.ID, []ma.Multiaddr) + + PeerTagger +} + +// PeerTagger implements arbitrary tagging of peers +type PeerTagger interface { + TagPeer(peer.ID, string) + UntagPeer(peer.ID, string) }