diff --git a/errors.go b/errors.go index 38c77f46..cd2e1cf1 100644 --- a/errors.go +++ b/errors.go @@ -24,3 +24,12 @@ const ErrPause = errorType("pause channel") // ErrResume is a special error that the RequestReceived / ResponseReceived hooks can // use to resume the channel const ErrResume = errorType("resume channel") + +// ErrIncomplete indicates a channel did not finish transferring data successfully +const ErrIncomplete = errorType("incomplete response") + +// ErrRejected indicates a request was not accepted +const ErrRejected = errorType("response rejected") + +// ErrUnsupported indicates an operation is not supported by the transport protocol +const ErrUnsupported = errorType("unsupported") diff --git a/impl/events.go b/impl/events.go index e1479571..3642a0b4 100644 --- a/impl/events.go +++ b/impl/events.go @@ -32,10 +32,14 @@ func (m *manager) OnDataReceived(chid datatransfer.ChannelID, link ipld.Link, si if chid.Initiator != m.peerID { var result datatransfer.VoucherResult var err error + var handled bool _ = m.revalidators.Each(func(_ datatransfer.TypeIdentifier, _ encoding.Decoder, processor registry.Processor) error { revalidator := processor.(datatransfer.Revalidator) - result, err = revalidator.OnPushDataReceived(chid, size) - return err + handled, result, err = revalidator.OnPushDataReceived(chid, size) + if handled { + return errors.New("stop processing") + } + return nil }) if err != nil || result != nil { msg, err := m.processRevalidationResult(chid, result, err) @@ -58,10 +62,14 @@ func (m *manager) OnDataSent(chid datatransfer.ChannelID, link ipld.Link, size u if chid.Initiator != m.peerID { var result datatransfer.VoucherResult var err error + var handled bool _ = m.revalidators.Each(func(_ datatransfer.TypeIdentifier, _ encoding.Decoder, processor registry.Processor) error { revalidator := processor.(datatransfer.Revalidator) - result, err = revalidator.OnPullDataSent(chid, size) - return err + handled, result, err = revalidator.OnPullDataSent(chid, size) + if handled { + return errors.New("stop processing") + } + return nil }) if err != nil || result != nil { return m.processRevalidationResult(chid, result, err) @@ -115,7 +123,7 @@ func (m *manager) OnResponseReceived(chid datatransfer.ChannelID, response datat } } if !response.Accepted() { - return m.channels.Error(chid, errors.New("Response Rejected")) + return m.channels.Error(chid, datatransfer.ErrRejected) } if response.IsNew() { err := m.channels.Accept(chid) @@ -162,7 +170,15 @@ func (m *manager) OnChannelCompleted(chid datatransfer.ChannelID, success bool) } return m.channels.FinishTransfer(chid) } - return m.channels.Error(chid, errors.New("incomplete response")) + chst, err := m.channels.GetByID(context.TODO(), chid) + if err != nil { + return err + } + // send an error, but only if we haven't already errored for some reason + if chst.Status() != datatransfer.Failing && chst.Status() != datatransfer.Failed { + return m.channels.Error(chid, datatransfer.ErrIncomplete) + } + return nil } func (m *manager) receiveNewRequest( @@ -330,10 +346,14 @@ func (m *manager) processRevalidationResult(chid datatransfer.ChannelID, result func (m *manager) completeMessage(chid datatransfer.ChannelID) (datatransfer.Response, error) { var result datatransfer.VoucherResult var resultErr error + var handled bool _ = m.revalidators.Each(func(_ datatransfer.TypeIdentifier, _ encoding.Decoder, processor registry.Processor) error { revalidator := processor.(datatransfer.Revalidator) - result, resultErr = revalidator.OnComplete(chid) - return resultErr + handled, result, resultErr = revalidator.OnComplete(chid) + if handled { + return errors.New("stop processing") + } + return nil }) if result != nil { err := m.channels.NewVoucherResult(chid, result) diff --git a/impl/impl.go b/impl/impl.go index 7d684f30..a4e27abc 100644 --- a/impl/impl.go +++ b/impl/impl.go @@ -231,7 +231,7 @@ func (m *manager) PauseDataTransferChannel(ctx context.Context, chid datatransfe pausable, ok := m.transport.(datatransfer.PauseableTransport) if !ok { - return errors.New("unsupported") + return datatransfer.ErrUnsupported } err := pausable.PauseChannel(ctx, chid) @@ -252,7 +252,7 @@ func (m *manager) PauseDataTransferChannel(ctx context.Context, chid datatransfe func (m *manager) ResumeDataTransferChannel(ctx context.Context, chid datatransfer.ChannelID) error { pausable, ok := m.transport.(datatransfer.PauseableTransport) if !ok { - return errors.New("unsupported") + return datatransfer.ErrUnsupported } err := pausable.ResumeChannel(ctx, m.resumeMessage(chid), chid) diff --git a/impl/integration_test.go b/impl/integration_test.go index 5f64cbb0..fdb0e636 100644 --- a/impl/integration_test.go +++ b/impl/integration_test.go @@ -593,21 +593,21 @@ type retrievalRevalidator struct { finalVoucher datatransfer.VoucherResult } -func (r *retrievalRevalidator) OnPullDataSent(chid datatransfer.ChannelID, additionalBytesSent uint64) (datatransfer.VoucherResult, error) { +func (r *retrievalRevalidator) OnPullDataSent(chid datatransfer.ChannelID, additionalBytesSent uint64) (bool, datatransfer.VoucherResult, error) { r.dataSoFar += additionalBytesSent if r.providerPausePoint < len(r.pausePoints) && r.dataSoFar >= r.pausePoints[r.providerPausePoint] { r.providerPausePoint++ - return testutil.NewFakeDTType(), datatransfer.ErrPause + return true, testutil.NewFakeDTType(), datatransfer.ErrPause } - return nil, nil + return true, nil, nil } -func (r *retrievalRevalidator) OnPushDataReceived(chid datatransfer.ChannelID, additionalBytesReceived uint64) (datatransfer.VoucherResult, error) { - return nil, nil +func (r *retrievalRevalidator) OnPushDataReceived(chid datatransfer.ChannelID, additionalBytesReceived uint64) (bool, datatransfer.VoucherResult, error) { + return false, nil, nil } -func (r *retrievalRevalidator) OnComplete(chid datatransfer.ChannelID) (datatransfer.VoucherResult, error) { - return r.finalVoucher, datatransfer.ErrPause +func (r *retrievalRevalidator) OnComplete(chid datatransfer.ChannelID) (bool, datatransfer.VoucherResult, error) { + return true, r.finalVoucher, datatransfer.ErrPause } func TestSimulatedRetrievalFlow(t *testing.T) { @@ -890,6 +890,83 @@ func TestPauseAndResume(t *testing.T) { } } +func TestUnrecognizedVoucherRoundTrip(t *testing.T) { + ctx := context.Background() + testCases := map[string]bool{ + "push requests": false, + "pull requests": true, + } + for testCase, isPull := range testCases { + t.Run(testCase, func(t *testing.T) { + // ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + // defer cancel() + + gsData := testutil.NewGraphsyncTestingData(ctx, t) + host1 := gsData.Host1 // initiator, data sender + host2 := gsData.Host2 // data recipient + + tp1 := gsData.SetupGSTransportHost1() + tp2 := gsData.SetupGSTransportHost2() + + dt1, err := NewDataTransfer(gsData.DtDs1, gsData.DtNet1, tp1, gsData.StoredCounter1) + require.NoError(t, err) + err = dt1.Start(ctx) + require.NoError(t, err) + dt2, err := NewDataTransfer(gsData.DtDs2, gsData.DtNet2, tp2, gsData.StoredCounter2) + require.NoError(t, err) + err = dt2.Start(ctx) + require.NoError(t, err) + + finished := make(chan struct{}, 2) + errChan := make(chan string, 2) + opened := make(chan struct{}, 2) + var subscriber datatransfer.Subscriber = func(event datatransfer.Event, channelState datatransfer.ChannelState) { + if channelState.Status() == datatransfer.Failed { + finished <- struct{}{} + } + if event.Code == datatransfer.Error { + errChan <- channelState.Message() + } + if event.Code == datatransfer.Open { + opened <- struct{}{} + } + } + dt1.SubscribeToEvents(subscriber) + dt2.SubscribeToEvents(subscriber) + voucher := testutil.FakeDTType{Data: "applesauce"} + + root, _ := testutil.LoadUnixFSFile(ctx, t, gsData.DagService1) + rootCid := root.(cidlink.Link).Cid + + if isPull { + _, err = dt2.OpenPullDataChannel(ctx, host1.ID(), &voucher, rootCid, gsData.AllSelector) + } else { + _, err = dt1.OpenPushDataChannel(ctx, host2.ID(), &voucher, rootCid, gsData.AllSelector) + } + require.NoError(t, err) + opens := 0 + var errMessages []string + finishes := 0 + for opens < 1 || finishes < 1 { + select { + case <-ctx.Done(): + t.Fatal("Did not complete succcessful data transfer") + case <-finished: + finishes++ + case <-opened: + opens++ + case errMessage := <-errChan: + require.Equal(t, errMessage, datatransfer.ErrRejected.Error()) + errMessages = append(errMessages, errMessage) + if len(errMessages) > 1 { + t.Fatal("too many errors") + } + } + } + }) + } +} + func TestDataTransferSubscribing(t *testing.T) { // create network ctx := context.Background() diff --git a/impl/responding_test.go b/impl/responding_test.go index 45c4a8ed..546ed103 100644 --- a/impl/responding_test.go +++ b/impl/responding_test.go @@ -482,6 +482,27 @@ func TestDataTransferResponding(t *testing.T) { require.False(t, response.IsPaused()) }, }, + "validated, incomplete response": { + expectedEvents: []datatransfer.EventCode{ + datatransfer.Open, + datatransfer.NewVoucherResult, + datatransfer.Accept, + datatransfer.Error, + datatransfer.CleanupComplete, + }, + configureValidator: func(sv *testutil.StubbedValidator) { + sv.ExpectSuccessPull() + sv.StubResult(testutil.NewFakeDTType()) + }, + configureRevalidator: func(srv *testutil.StubbedRevalidator) { + }, + verify: func(t *testing.T, h *receiverHarness) { + _, err := h.transport.EventHandler.OnRequestReceived(channelID(h.id, h.peers), h.pullRequest) + require.NoError(t, err) + err = h.transport.EventHandler.OnChannelCompleted(channelID(h.id, h.peers), false) + require.NoError(t, err) + }, + }, "new push request, customized transport": { expectedEvents: []datatransfer.EventCode{datatransfer.Open, datatransfer.NewVoucherResult, datatransfer.Accept}, configureValidator: func(sv *testutil.StubbedValidator) { diff --git a/manager.go b/manager.go index 64210f80..2f00bae7 100644 --- a/manager.go +++ b/manager.go @@ -31,20 +31,30 @@ type Revalidator interface { // Revalidate revalidates a request with a new voucher Revalidate(channelID ChannelID, voucher Voucher) (VoucherResult, error) // OnPullDataSent is called on the responder side when more bytes are sent - // for a given pull request. It should return a VoucherResult + ErrPause to + // for a given pull request. The first value indicates whether the request was + // recognized by this revalidator and should be considered 'handled'. If true, + // the remaining two values are interpreted. If 'false' the request is passed on + // to the next revalidators. + // It should return a VoucherResult + ErrPause to // request revalidation or nil to continue uninterrupted, - // other errors will terminate the request - OnPullDataSent(chid ChannelID, additionalBytesSent uint64) (VoucherResult, error) + // other errors will terminate the request. + OnPullDataSent(chid ChannelID, additionalBytesSent uint64) (bool, VoucherResult, error) // OnPushDataReceived is called on the responder side when more bytes are received - // for a given push request. It should return a VoucherResult + ErrPause to + // for a given push request. The first value indicates whether the request was + // recognized by this revalidator and should be considered 'handled'. If true, + // the remaining two values are interpreted. If 'false' the request is passed on + // to the next revalidators. It should return a VoucherResult + ErrPause to // request revalidation or nil to continue uninterrupted, // other errors will terminate the request - OnPushDataReceived(chid ChannelID, additionalBytesReceived uint64) (VoucherResult, error) + OnPushDataReceived(chid ChannelID, additionalBytesReceived uint64) (bool, VoucherResult, error) // OnComplete is called to make a final request for revalidation -- often for the - // purpose of settlement. + // purpose of settlement. The first value indicates whether the request was + // recognized by this revalidator and should be considered 'handled'. If true, + // the remaining two values are interpreted. If 'false' the request is passed on + // to the next revalidators. // if VoucherResult is non nil, the request will enter a settlement phase awaiting // a final update - OnComplete(chid ChannelID) (VoucherResult, error) + OnComplete(chid ChannelID) (bool, VoucherResult, error) } // TransportConfigurer provides a mechanism to provide transport specific configuration for a given voucher type diff --git a/testutil/stubbedvalidator.go b/testutil/stubbedvalidator.go index bce966b0..917cea65 100644 --- a/testutil/stubbedvalidator.go +++ b/testutil/stubbedvalidator.go @@ -165,21 +165,21 @@ func NewStubbedRevalidator() *StubbedRevalidator { } // OnPullDataSent returns a stubbed result for checking when pull data is sent -func (srv *StubbedRevalidator) OnPullDataSent(chid datatransfer.ChannelID, additionalBytesSent uint64) (datatransfer.VoucherResult, error) { +func (srv *StubbedRevalidator) OnPullDataSent(chid datatransfer.ChannelID, additionalBytesSent uint64) (bool, datatransfer.VoucherResult, error) { srv.didPullCheck = true - return srv.revalidationResult, srv.pullCheckError + return srv.expectPullCheck, srv.revalidationResult, srv.pullCheckError } // OnPushDataReceived returns a stubbed result for checking when push data is received -func (srv *StubbedRevalidator) OnPushDataReceived(chid datatransfer.ChannelID, additionalBytesReceived uint64) (datatransfer.VoucherResult, error) { +func (srv *StubbedRevalidator) OnPushDataReceived(chid datatransfer.ChannelID, additionalBytesReceived uint64) (bool, datatransfer.VoucherResult, error) { srv.didPushCheck = true - return srv.revalidationResult, srv.pushCheckError + return srv.expectPushCheck, srv.revalidationResult, srv.pushCheckError } // OnComplete returns a stubbed result for checking when the requests completes -func (srv *StubbedRevalidator) OnComplete(chid datatransfer.ChannelID) (datatransfer.VoucherResult, error) { +func (srv *StubbedRevalidator) OnComplete(chid datatransfer.ChannelID) (bool, datatransfer.VoucherResult, error) { srv.didComplete = true - return srv.revalidationResult, srv.completeError + return srv.expectComplete, srv.revalidationResult, srv.completeError } // Revalidate returns a stubbed result for revalidating a request