Skip to content

Commit

Permalink
fix: dont double count data sent
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkmc committed Apr 1, 2021
1 parent de9804f commit 3c0714a
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 56 deletions.
30 changes: 17 additions & 13 deletions channels/channels.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,18 +226,21 @@ func (c *Channels) CompleteCleanupOnRestart(chid datatransfer.ChannelID) error {
return c.send(chid, datatransfer.CompleteCleanupOnRestart)
}

func (c *Channels) DataSent(chid datatransfer.ChannelID, k cid.Cid, delta uint64) error {
// Returns true if this is the first time the block has been sent
func (c *Channels) DataSent(chid datatransfer.ChannelID, k cid.Cid, delta uint64) (bool, error) {
return c.fireProgressEvent(chid, datatransfer.DataSent, datatransfer.DataSentProgress, k, delta)
}

func (c *Channels) DataQueued(chid datatransfer.ChannelID, k cid.Cid, delta uint64) error {
// Returns true if this is the first time the block has been queued
func (c *Channels) DataQueued(chid datatransfer.ChannelID, k cid.Cid, delta uint64) (bool, error) {
return c.fireProgressEvent(chid, datatransfer.DataQueued, datatransfer.DataQueuedProgress, k, delta)
}

func (c *Channels) DataReceived(chid datatransfer.ChannelID, k cid.Cid, delta uint64) error {
// Returns true if this is the first time the block has been received
func (c *Channels) DataReceived(chid datatransfer.ChannelID, k cid.Cid, delta uint64) (bool, error) {
err := c.cidLists.AppendList(chid, k)
if err != nil {
return err
return false, err
}

return c.fireProgressEvent(chid, datatransfer.DataReceived, datatransfer.DataReceivedProgress, k, delta)
Expand Down Expand Up @@ -357,31 +360,32 @@ func (c *Channels) removeSeenCIDCaches(chid datatransfer.ChannelID) error {
return nil
}

// onProgress fires an event indicating progress has been made in
// queuing / sending / receiving blocks.
// These events are fired only for new blocks (not for example if
// a block is resent)
func (c *Channels) fireProgressEvent(chid datatransfer.ChannelID, evt datatransfer.EventCode, progressEvt datatransfer.EventCode, k cid.Cid, delta uint64) error {
// fireProgressEvent fires
// - an event for queuing / sending / receiving blocks
// - a corresponding "progress" event if the block has not been seen before
// For example if a block is resent, the method will fire DataSent but not DataSentProgress.
// Returns true if a progress event was fired.
func (c *Channels) fireProgressEvent(chid datatransfer.ChannelID, evt datatransfer.EventCode, progressEvt datatransfer.EventCode, k cid.Cid, delta uint64) (bool, error) {
if err := c.checkChannelExists(chid, evt); err != nil {
return err
return false, err
}

// Check if the block has already been seen
sid := cidsets.SetID(chid.String() + "/" + datatransfer.Events[evt])
seen, err := c.seenCIDs.InsertSetCID(sid, k)
if err != nil {
return err
return false, err
}

// If the block has not been seen before, fire the progress event
if !seen {
if err := c.stateMachines.Send(chid, progressEvt, delta); err != nil {
return err
return false, err
}
}

// Fire the regular event
return c.stateMachines.Send(chid, evt)
return !seen, c.stateMachines.Send(chid, evt)
}

func (c *Channels) send(chid datatransfer.ChannelID, code datatransfer.EventCode, args ...interface{}) error {
Expand Down
20 changes: 13 additions & 7 deletions channels/channels_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,46 +145,52 @@ func TestChannels(t *testing.T) {
require.Equal(t, uint64(0), state.Sent())
require.Empty(t, state.ReceivedCids())

err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50)
var isNew bool
isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50)
require.NoError(t, err)
_ = checkEvent(ctx, t, received, datatransfer.DataReceivedProgress)
require.True(t, isNew)
state = checkEvent(ctx, t, received, datatransfer.DataReceived)
require.Equal(t, uint64(50), state.Received())
require.Equal(t, uint64(0), state.Sent())
require.Equal(t, []cid.Cid{cids[0]}, state.ReceivedCids())

err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 100)
isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 100)
require.NoError(t, err)
_ = checkEvent(ctx, t, received, datatransfer.DataSentProgress)
require.True(t, isNew)
state = checkEvent(ctx, t, received, datatransfer.DataSent)
require.Equal(t, uint64(50), state.Received())
require.Equal(t, uint64(100), state.Sent())
require.Equal(t, []cid.Cid{cids[0]}, state.ReceivedCids())

// errors if channel does not exist
err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200)
isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200)
require.True(t, xerrors.As(err, new(*channels.ErrNotFound)))
err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200)
isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200)
require.True(t, xerrors.As(err, new(*channels.ErrNotFound)))
require.Equal(t, []cid.Cid{cids[0]}, state.ReceivedCids())

err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 50)
isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 50)
require.NoError(t, err)
_ = checkEvent(ctx, t, received, datatransfer.DataReceivedProgress)
require.True(t, isNew)
state = checkEvent(ctx, t, received, datatransfer.DataReceived)
require.Equal(t, uint64(100), state.Received())
require.Equal(t, uint64(100), state.Sent())
require.Equal(t, []cid.Cid{cids[0], cids[1]}, state.ReceivedCids())

err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 25)
isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 25)
require.NoError(t, err)
require.False(t, isNew)
state = checkEvent(ctx, t, received, datatransfer.DataSent)
require.Equal(t, uint64(100), state.Received())
require.Equal(t, uint64(100), state.Sent())
require.Equal(t, []cid.Cid{cids[0], cids[1]}, state.ReceivedCids())

err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50)
isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50)
require.NoError(t, err)
require.False(t, isNew)
state = checkEvent(ctx, t, received, datatransfer.DataReceived)
require.Equal(t, uint64(100), state.Received())
require.Equal(t, uint64(100), state.Sent())
Expand Down
110 changes: 74 additions & 36 deletions impl/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,63 +28,101 @@ func (m *manager) OnChannelOpened(chid datatransfer.ChannelID) error {
return nil
}

// OnDataReceived is called when the transport layer reports that it has
// received some data from the sender.
// It fires an event on the channel, updating the sum of received data and
// calls revalidators so they can pause / resume the channel or send a
// message over the transport.
func (m *manager) OnDataReceived(chid datatransfer.ChannelID, link ipld.Link, size uint64) error {
err := m.channels.DataReceived(chid, link.(cidlink.Link).Cid, size)
isNew, err := m.channels.DataReceived(chid, link.(cidlink.Link).Cid, size)
if err != nil {
return err
}

if chid.Initiator != m.peerID {
var result datatransfer.VoucherResult
var err error
var handled bool
_ = m.revalidators.Each(func(_ datatransfer.TypeIdentifier, _ encoding.Decoder, processor registry.Processor) error {
revalidator := processor.(datatransfer.Revalidator)
handled, result, err = revalidator.OnPushDataReceived(chid, size)
if handled {
return errors.New("stop processing")
}
return nil
})
if err != nil || result != nil {
msg, err := m.processRevalidationResult(chid, result, err)
if msg != nil {
if err := m.dataTransferNetwork.SendMessage(context.TODO(), chid.Initiator, msg); err != nil {
return err
}
// If this block has already been received on the channel, take no further
// action (this can happen when the data-transfer channel is restarted)
if !isNew {
return nil
}

// If this node initiated the data transfer, there's nothing more to do
if chid.Initiator == m.peerID {
return nil
}

// Check each revalidator to see if they want to pause / resume, or send
// a message over the transport
var result datatransfer.VoucherResult
var handled bool
_ = m.revalidators.Each(func(_ datatransfer.TypeIdentifier, _ encoding.Decoder, processor registry.Processor) error {
revalidator := processor.(datatransfer.Revalidator)
handled, result, err = revalidator.OnPushDataReceived(chid, size)
if handled {
return errors.New("stop processing")
}
return nil
})
if err != nil || result != nil {
msg, err := m.processRevalidationResult(chid, result, err)
if msg != nil {
if err := m.dataTransferNetwork.SendMessage(context.TODO(), chid.Initiator, msg); err != nil {
return err
}
return err
}
return err
}

return nil
}

// OnDataQueued is called when the transport layer reports that it has queued
// up some data to be sent to the requester.
// It fires an event on the channel, updating the sum of queued data and calls
// revalidators so they can pause / resume or send a message over the transport.
func (m *manager) OnDataQueued(chid datatransfer.ChannelID, link ipld.Link, size uint64) (datatransfer.Message, error) {
if err := m.channels.DataQueued(chid, link.(cidlink.Link).Cid, size); err != nil {
// The transport layer reports that some data has been queued up to be sent
// to the requester, so fire a DataQueued event on the channels state
// machine.
isNew, err := m.channels.DataQueued(chid, link.(cidlink.Link).Cid, size)
if err != nil {
return nil, err
}
if chid.Initiator != m.peerID {
var result datatransfer.VoucherResult
var err error
var handled bool
_ = m.revalidators.Each(func(_ datatransfer.TypeIdentifier, _ encoding.Decoder, processor registry.Processor) error {
revalidator := processor.(datatransfer.Revalidator)
handled, result, err = revalidator.OnPullDataSent(chid, size)
if handled {
return errors.New("stop processing")
}
return nil
})
if err != nil || result != nil {
return m.processRevalidationResult(chid, result, err)

// If this block has already been queued on the channel, take no further
// action (this can happen when the data-transfer channel is restarted)
if !isNew {
return nil, nil
}

// If this node initiated the data transfer, there's nothing more to do
if chid.Initiator == m.peerID {
return nil, nil
}

// Check each revalidator to see if they want to pause / resume, or send
// a message over the transport.
// For example if the data-sender is waiting for the receiver to pay for
// data they may pause the data-transfer.
var result datatransfer.VoucherResult
var handled bool
_ = m.revalidators.Each(func(_ datatransfer.TypeIdentifier, _ encoding.Decoder, processor registry.Processor) error {
revalidator := processor.(datatransfer.Revalidator)
handled, result, err = revalidator.OnPullDataSent(chid, size)
if handled {
return errors.New("stop processing")
}
return nil
})
if err != nil || result != nil {
return m.processRevalidationResult(chid, result, err)
}

return nil, nil
}

func (m *manager) OnDataSent(chid datatransfer.ChannelID, link ipld.Link, size uint64) error {
return m.channels.DataSent(chid, link.(cidlink.Link).Cid, size)
_, err := m.channels.DataSent(chid, link.(cidlink.Link).Cid, size)
return err
}

func (m *manager) OnRequestReceived(chid datatransfer.ChannelID, request datatransfer.Request) (datatransfer.Response, error) {
Expand Down

0 comments on commit 3c0714a

Please sign in to comment.