diff --git a/channels/channels.go b/channels/channels.go index 64faff5a..e11ed00b 100644 --- a/channels/channels.go +++ b/channels/channels.go @@ -226,18 +226,21 @@ func (c *Channels) CompleteCleanupOnRestart(chid datatransfer.ChannelID) error { return c.send(chid, datatransfer.CompleteCleanupOnRestart) } -func (c *Channels) DataSent(chid datatransfer.ChannelID, k cid.Cid, delta uint64) error { +// Returns true if this is the first time the block has been sent +func (c *Channels) DataSent(chid datatransfer.ChannelID, k cid.Cid, delta uint64) (bool, error) { return c.fireProgressEvent(chid, datatransfer.DataSent, datatransfer.DataSentProgress, k, delta) } -func (c *Channels) DataQueued(chid datatransfer.ChannelID, k cid.Cid, delta uint64) error { +// Returns true if this is the first time the block has been queued +func (c *Channels) DataQueued(chid datatransfer.ChannelID, k cid.Cid, delta uint64) (bool, error) { return c.fireProgressEvent(chid, datatransfer.DataQueued, datatransfer.DataQueuedProgress, k, delta) } -func (c *Channels) DataReceived(chid datatransfer.ChannelID, k cid.Cid, delta uint64) error { +// Returns true if this is the first time the block has been received +func (c *Channels) DataReceived(chid datatransfer.ChannelID, k cid.Cid, delta uint64) (bool, error) { err := c.cidLists.AppendList(chid, k) if err != nil { - return err + return false, err } return c.fireProgressEvent(chid, datatransfer.DataReceived, datatransfer.DataReceivedProgress, k, delta) @@ -357,31 +360,34 @@ func (c *Channels) removeSeenCIDCaches(chid datatransfer.ChannelID) error { return nil } -// onProgress fires an event indicating progress has been made in -// queuing / sending / receiving blocks. -// These events are fired only for new blocks (not for example if -// a block is resent) -func (c *Channels) fireProgressEvent(chid datatransfer.ChannelID, evt datatransfer.EventCode, progressEvt datatransfer.EventCode, k cid.Cid, delta uint64) error { +// fireProgressEvent fires +// - an event for queuing / sending / receiving blocks +// - a corresponding "progress" event if the block has not been seen before +// For example, if a block is being sent for the first time, the method will +// fire both DataSent AND DataSentProgress. +// If a block is resent, the method will fire DataSent but not DataSentProgress. +// Returns true if the block is new (both the event and a progress event were fired). +func (c *Channels) fireProgressEvent(chid datatransfer.ChannelID, evt datatransfer.EventCode, progressEvt datatransfer.EventCode, k cid.Cid, delta uint64) (bool, error) { if err := c.checkChannelExists(chid, evt); err != nil { - return err + return false, err } // Check if the block has already been seen sid := cidsets.SetID(chid.String() + "/" + datatransfer.Events[evt]) seen, err := c.seenCIDs.InsertSetCID(sid, k) if err != nil { - return err + return false, err } // If the block has not been seen before, fire the progress event if !seen { if err := c.stateMachines.Send(chid, progressEvt, delta); err != nil { - return err + return false, err } } // Fire the regular event - return c.stateMachines.Send(chid, evt) + return !seen, c.stateMachines.Send(chid, evt) } func (c *Channels) send(chid datatransfer.ChannelID, code datatransfer.EventCode, args ...interface{}) error { diff --git a/channels/channels_test.go b/channels/channels_test.go index 133d745b..6ad58fad 100644 --- a/channels/channels_test.go +++ b/channels/channels_test.go @@ -145,46 +145,53 @@ func TestChannels(t *testing.T) { require.Equal(t, uint64(0), state.Sent()) require.Empty(t, state.ReceivedCids()) - err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50) + isNew, err := channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50) require.NoError(t, err) _ = checkEvent(ctx, t, received, datatransfer.DataReceivedProgress) + require.True(t, isNew) state = checkEvent(ctx, t, received, datatransfer.DataReceived) require.Equal(t, uint64(50), state.Received()) require.Equal(t, uint64(0), state.Sent()) require.Equal(t, []cid.Cid{cids[0]}, state.ReceivedCids()) - err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 100) + isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 100) require.NoError(t, err) _ = checkEvent(ctx, t, received, datatransfer.DataSentProgress) + require.True(t, isNew) state = checkEvent(ctx, t, received, datatransfer.DataSent) require.Equal(t, uint64(50), state.Received()) require.Equal(t, uint64(100), state.Sent()) require.Equal(t, []cid.Cid{cids[0]}, state.ReceivedCids()) // errors if channel does not exist - err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200) + isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200) require.True(t, xerrors.As(err, new(*channels.ErrNotFound))) - err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200) + require.False(t, isNew) + isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200) require.True(t, xerrors.As(err, new(*channels.ErrNotFound))) require.Equal(t, []cid.Cid{cids[0]}, state.ReceivedCids()) + require.False(t, isNew) - err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 50) + isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 50) require.NoError(t, err) _ = checkEvent(ctx, t, received, datatransfer.DataReceivedProgress) + require.True(t, isNew) state = checkEvent(ctx, t, received, datatransfer.DataReceived) require.Equal(t, uint64(100), state.Received()) require.Equal(t, uint64(100), state.Sent()) require.Equal(t, []cid.Cid{cids[0], cids[1]}, state.ReceivedCids()) - err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 25) + isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 25) require.NoError(t, err) + require.False(t, isNew) state = checkEvent(ctx, t, received, datatransfer.DataSent) require.Equal(t, uint64(100), state.Received()) require.Equal(t, uint64(100), state.Sent()) require.Equal(t, []cid.Cid{cids[0], cids[1]}, state.ReceivedCids()) - err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50) + isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50) require.NoError(t, err) + require.False(t, isNew) state = checkEvent(ctx, t, received, datatransfer.DataReceived) require.Equal(t, uint64(100), state.Received()) require.Equal(t, uint64(100), state.Sent()) diff --git a/impl/events.go b/impl/events.go index fe49b57d..68e70413 100644 --- a/impl/events.go +++ b/impl/events.go @@ -28,63 +28,101 @@ func (m *manager) OnChannelOpened(chid datatransfer.ChannelID) error { return nil } +// OnDataReceived is called when the transport layer reports that it has +// received some data from the sender. +// It fires an event on the channel, updating the sum of received data and +// calls revalidators so they can pause / resume the channel or send a +// message over the transport. func (m *manager) OnDataReceived(chid datatransfer.ChannelID, link ipld.Link, size uint64) error { - err := m.channels.DataReceived(chid, link.(cidlink.Link).Cid, size) + isNew, err := m.channels.DataReceived(chid, link.(cidlink.Link).Cid, size) if err != nil { return err } - if chid.Initiator != m.peerID { - var result datatransfer.VoucherResult - var err error - var handled bool - _ = m.revalidators.Each(func(_ datatransfer.TypeIdentifier, _ encoding.Decoder, processor registry.Processor) error { - revalidator := processor.(datatransfer.Revalidator) - handled, result, err = revalidator.OnPushDataReceived(chid, size) - if handled { - return errors.New("stop processing") - } - return nil - }) - if err != nil || result != nil { - msg, err := m.processRevalidationResult(chid, result, err) - if msg != nil { - if err := m.dataTransferNetwork.SendMessage(context.TODO(), chid.Initiator, msg); err != nil { - return err - } + // If this block has already been received on the channel, take no further + // action (this can happen when the data-transfer channel is restarted) + if !isNew { + return nil + } + + // If this node initiated the data transfer, there's nothing more to do + if chid.Initiator == m.peerID { + return nil + } + + // Check each revalidator to see if they want to pause / resume, or send + // a message over the transport + var result datatransfer.VoucherResult + var handled bool + _ = m.revalidators.Each(func(_ datatransfer.TypeIdentifier, _ encoding.Decoder, processor registry.Processor) error { + revalidator := processor.(datatransfer.Revalidator) + handled, result, err = revalidator.OnPushDataReceived(chid, size) + if handled { + return errors.New("stop processing") + } + return nil + }) + if err != nil || result != nil { + msg, err := m.processRevalidationResult(chid, result, err) + if msg != nil { + if err := m.dataTransferNetwork.SendMessage(context.TODO(), chid.Initiator, msg); err != nil { + return err } - return err } + return err } + return nil } +// OnDataQueued is called when the transport layer reports that it has queued +// up some data to be sent to the requester. +// It fires an event on the channel, updating the sum of queued data and calls +// revalidators so they can pause / resume or send a message over the transport. func (m *manager) OnDataQueued(chid datatransfer.ChannelID, link ipld.Link, size uint64) (datatransfer.Message, error) { - if err := m.channels.DataQueued(chid, link.(cidlink.Link).Cid, size); err != nil { + // The transport layer reports that some data has been queued up to be sent + // to the requester, so fire a DataQueued event on the channels state + // machine. + isNew, err := m.channels.DataQueued(chid, link.(cidlink.Link).Cid, size) + if err != nil { return nil, err } - if chid.Initiator != m.peerID { - var result datatransfer.VoucherResult - var err error - var handled bool - _ = m.revalidators.Each(func(_ datatransfer.TypeIdentifier, _ encoding.Decoder, processor registry.Processor) error { - revalidator := processor.(datatransfer.Revalidator) - handled, result, err = revalidator.OnPullDataSent(chid, size) - if handled { - return errors.New("stop processing") - } - return nil - }) - if err != nil || result != nil { - return m.processRevalidationResult(chid, result, err) + + // If this block has already been queued on the channel, take no further + // action (this can happen when the data-transfer channel is restarted) + if !isNew { + return nil, nil + } + + // If this node initiated the data transfer, there's nothing more to do + if chid.Initiator == m.peerID { + return nil, nil + } + + // Check each revalidator to see if they want to pause / resume, or send + // a message over the transport. + // For example if the data-sender is waiting for the receiver to pay for + // data they may pause the data-transfer. + var result datatransfer.VoucherResult + var handled bool + _ = m.revalidators.Each(func(_ datatransfer.TypeIdentifier, _ encoding.Decoder, processor registry.Processor) error { + revalidator := processor.(datatransfer.Revalidator) + handled, result, err = revalidator.OnPullDataSent(chid, size) + if handled { + return errors.New("stop processing") } + return nil + }) + if err != nil || result != nil { + return m.processRevalidationResult(chid, result, err) } return nil, nil } func (m *manager) OnDataSent(chid datatransfer.ChannelID, link ipld.Link, size uint64) error { - return m.channels.DataSent(chid, link.(cidlink.Link).Cid, size) + _, err := m.channels.DataSent(chid, link.(cidlink.Link).Cid, size) + return err } func (m *manager) OnRequestReceived(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) { diff --git a/impl/integration_test.go b/impl/integration_test.go index 77bd3ade..061596e7 100644 --- a/impl/integration_test.go +++ b/impl/integration_test.go @@ -41,6 +41,7 @@ import ( ) const loremFile = "lorem.txt" +const loremFileTransferBytes = 20439 // nil means use the default protocols // tests data transfer for the following protocol combinations: @@ -520,6 +521,66 @@ func (dc *disconnectCoordinator) onDisconnect() { close(dc.disconnected) } +type restartRevalidator struct { + *testutil.StubbedRevalidator + pullDataSent map[datatransfer.ChannelID][]uint64 + pushDataRcvd map[datatransfer.ChannelID][]uint64 +} + +func newRestartRevalidator() *restartRevalidator { + return &restartRevalidator{ + StubbedRevalidator: testutil.NewStubbedRevalidator(), + pullDataSent: make(map[datatransfer.ChannelID][]uint64), + pushDataRcvd: make(map[datatransfer.ChannelID][]uint64), + } +} + +func (r *restartRevalidator) OnPullDataSent(chid datatransfer.ChannelID, additionalBytesSent uint64) (bool, datatransfer.VoucherResult, error) { + chSent, ok := r.pullDataSent[chid] + if !ok { + chSent = []uint64{} + } + chSent = append(chSent, additionalBytesSent) + r.pullDataSent[chid] = chSent + + return true, nil, nil +} + +func (r *restartRevalidator) pullDataSum(chid datatransfer.ChannelID) uint64 { + pullDataSent, ok := r.pullDataSent[chid] + var total uint64 + if !ok { + return total + } + for _, sent := range pullDataSent { + total += sent + } + return total +} + +func (r *restartRevalidator) OnPushDataReceived(chid datatransfer.ChannelID, additionalBytesReceived uint64) (bool, datatransfer.VoucherResult, error) { + chRcvd, ok := r.pushDataRcvd[chid] + if !ok { + chRcvd = []uint64{} + } + chRcvd = append(chRcvd, additionalBytesReceived) + r.pushDataRcvd[chid] = chRcvd + + return true, nil, nil +} + +func (r *restartRevalidator) pushDataSum(chid datatransfer.ChannelID) uint64 { + pushDataRcvd, ok := r.pushDataRcvd[chid] + var total uint64 + if !ok { + return total + } + for _, rcvd := range pushDataRcvd { + total += rcvd + } + return total +} + // TestAutoRestart tests that if the connection for a push or pull request // goes down, it will automatically restart (given the right config options) func TestAutoRestart(t *testing.T) { @@ -714,6 +775,10 @@ func TestAutoRestart(t *testing.T) { require.NoError(t, initiator.RegisterVoucherType(&testutil.FakeDTType{}, sv)) require.NoError(t, responder.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + // Register a revalidator that records calls to OnPullDataSent and OnPushDataReceived + srv := newRestartRevalidator() + require.NoError(t, responder.RegisterRevalidator(testutil.NewFakeDTType(), srv)) + // If the test case needs to subscribe to response events, provide // the test case with the responder if tc.registerResponder != nil { @@ -795,6 +860,14 @@ func TestAutoRestart(t *testing.T) { } })() + // Verify that the total amount of data sent / received that was + // reported to the revalidator is correct + if tc.isPush { + require.EqualValues(t, loremFileTransferBytes, srv.pushDataSum(chid)) + } else { + require.EqualValues(t, loremFileTransferBytes, srv.pullDataSum(chid)) + } + // Verify that the file was transferred to the destination node testutil.VerifyHasFile(ctx, t, destDagService, root, origBytes) })