diff --git a/impl/graphsync/graphsync_impl.go b/impl/graphsync/graphsync_impl.go index 4582a5d2..318ae8b4 100644 --- a/impl/graphsync/graphsync_impl.go +++ b/impl/graphsync/graphsync_impl.go @@ -166,6 +166,16 @@ func (impl *graphsyncImpl) OpenPushDataChannel(ctx context.Context, requestTo pe if err != nil { return chid, err } + evt := datatransfer.Event{ + Code: datatransfer.Open, + Message: "New Request Initiated", + Timestamp: time.Now(), + } + chst := impl.channels.GetByIDAndSender(chid, impl.peerID) + err = impl.pubSub.Publish(internalEvent{evt, chst}) + if err != nil { + log.Warnf("err publishing DT event: %s", err.Error()) + } return chid, nil } @@ -183,6 +193,16 @@ func (impl *graphsyncImpl) OpenPullDataChannel(ctx context.Context, requestTo pe if err != nil { return chid, err } + evt := datatransfer.Event{ + Code: datatransfer.Open, + Message: "New Request Initiated", + Timestamp: time.Now(), + } + chst := impl.channels.GetByIDAndSender(chid, requestTo) + err = impl.pubSub.Publish(internalEvent{evt, chst}) + if err != nil { + log.Warnf("err publishing DT event: %s", err.Error()) + } return chid, nil } diff --git a/impl/graphsync/graphsync_impl_test.go b/impl/graphsync/graphsync_impl_test.go index 4e1f9224..bbb7a68c 100644 --- a/impl/graphsync/graphsync_impl_test.go +++ b/impl/graphsync/graphsync_impl_test.go @@ -635,7 +635,7 @@ func TestDataTransferInitiatingPullGraphsyncRequests(t *testing.T) { subscribeCalls := make(chan struct{}, 1) subscribe := func(event datatransfer.Event, channelState datatransfer.ChannelState) { - if event.Code == datatransfer.Error { + if event.Code == datatransfer.Progress { subscribeCalls <- struct{}{} } } @@ -936,6 +936,7 @@ func TestDataTransferPushRoundTrip(t *testing.T) { finished := make(chan struct{}, 2) errChan := make(chan struct{}, 2) + opened := make(chan struct{}, 2) var subscriber datatransfer.Subscriber = func(event datatransfer.Event, channelState datatransfer.ChannelState) { if event.Code == datatransfer.Complete { finished <- struct{}{} @@ -943,6 +944,9 @@ func TestDataTransferPushRoundTrip(t *testing.T) { if event.Code == datatransfer.Error { errChan <- struct{}{} } + if event.Code == datatransfer.Open { + opened <- struct{}{} + } } dt1.SubscribeToEvents(subscriber) dt2.SubscribeToEvents(subscriber) @@ -953,11 +957,16 @@ func TestDataTransferPushRoundTrip(t *testing.T) { chid, err := dt1.OpenPushDataChannel(ctx, host2.ID(), &voucher, rootCid, gsData.AllSelector) require.NoError(t, err) - for i := 0; i < 2; i++ { + opens := 0 + completes := 0 + for opens < 2 || completes < 2 { select { case <-ctx.Done(): t.Fatal("Did not complete succcessful data transfer") case <-finished: + completes++ + case <-opened: + opens++ case <-errChan: t.Fatal("received error on data transfer") } @@ -984,10 +993,18 @@ func TestDataTransferPullRoundTrip(t *testing.T) { dt2 := NewGraphSyncDataTransfer(host2, gs2, gsData.StoredCounter2) finished := make(chan struct{}, 2) + errChan := make(chan struct{}, 2) + opened := make(chan struct{}, 2) var subscriber datatransfer.Subscriber = func(event datatransfer.Event, channelState datatransfer.ChannelState) { if event.Code == datatransfer.Complete { finished <- struct{}{} } + if event.Code == datatransfer.Error { + errChan <- struct{}{} + } + if event.Code == datatransfer.Open { + opened <- struct{}{} + } } dt1.SubscribeToEvents(subscriber) dt2.SubscribeToEvents(subscriber) @@ -998,11 +1015,18 @@ func TestDataTransferPullRoundTrip(t *testing.T) { _, err := dt2.OpenPullDataChannel(ctx, host1.ID(), &voucher, rootCid, gsData.AllSelector) require.NoError(t, err) - for i := 0; i < 2; i++ { + opens := 0 + completes := 0 + for opens < 2 || completes < 2 { select { case <-ctx.Done(): t.Fatal("Did not complete succcessful data transfer") case <-finished: + completes++ + case <-opened: + opens++ + case <-errChan: + t.Fatal("received error on data transfer") } } gsData.VerifyFileTransferred(t, root, true) diff --git a/impl/graphsync/graphsync_receiver.go b/impl/graphsync/graphsync_receiver.go index 449f841c..dd48b411 100644 --- a/impl/graphsync/graphsync_receiver.go +++ b/impl/graphsync/graphsync_receiver.go @@ -43,12 +43,22 @@ func (receiver *graphsyncReceiver) ReceiveRequest( receiver.impl.sendGsRequest(ctx, initiator, incoming.TransferID(), incoming.IsPull(), dataSender, root, stor) } - _, err = receiver.impl.channels.CreateNew(incoming.TransferID(), incoming.BaseCid(), stor, voucher, initiator, dataSender, dataReceiver) + chid, err := receiver.impl.channels.CreateNew(incoming.TransferID(), incoming.BaseCid(), stor, voucher, initiator, dataSender, dataReceiver) if err != nil { log.Error(err) receiver.impl.sendResponse(ctx, false, initiator, incoming.TransferID()) return } + evt := datatransfer.Event{ + Code: datatransfer.Open, + Message: "Incoming request accepted", + Timestamp: time.Now(), + } + chst := receiver.impl.channels.GetByIDAndSender(chid, dataSender) + err = receiver.impl.pubSub.Publish(internalEvent{evt, chst}) + if err != nil { + log.Warnf("err publishing DT event: %s", err.Error()) + } receiver.impl.sendResponse(ctx, true, initiator, incoming.TransferID()) } @@ -109,6 +119,7 @@ func (receiver *graphsyncReceiver) ReceiveResponse( // initiator is us. construct a channel id for a pull request that we initiated and see // if there is one in our saved channel list. otherwise we should not respond. chid := datatransfer.ChannelID{Initiator: receiver.impl.peerID, ID: incoming.TransferID()} + evt.Code = datatransfer.Progress // if we are handling a response to a pull request then they are sending data and the // initiator is us @@ -116,7 +127,6 @@ func (receiver *graphsyncReceiver) ReceiveResponse( baseCid := chst.BaseCID() root := cidlink.Link{Cid: baseCid} receiver.impl.sendGsRequest(ctx, receiver.impl.peerID, incoming.TransferID(), true, sender, root, chst.Selector()) - evt.Code = datatransfer.Progress } } err := receiver.impl.pubSub.Publish(internalEvent{evt, chst})