diff --git a/channelmonitor/channelmonitor.go b/channelmonitor/channelmonitor.go index 02a29034..03e99252 100644 --- a/channelmonitor/channelmonitor.go +++ b/channelmonitor/channelmonitor.go @@ -38,7 +38,8 @@ type Monitor struct { } type Config struct { - // Max time to wait for other side to accept open channel request before attempting restart + // Max time to wait for other side to accept open channel request before attempting restart. + // Set to 0 to disable timeout. AcceptTimeout time.Duration // Debounce when restart is triggered by multiple errors RestartDebounce time.Duration @@ -47,7 +48,8 @@ type Config struct { // Number of times to try to restart before failing MaxConsecutiveRestarts uint32 // Max time to wait for the responder to send a Complete message once all - // data has been sent + // data has been sent. + // Set to 0 to disable timeout. CompleteTimeout time.Duration // Called when a restart completes successfully OnRestartComplete func(id datatransfer.ChannelID) @@ -71,14 +73,14 @@ func checkConfig(cfg *Config) { } prefix := "data-transfer channel monitor config " - if cfg.AcceptTimeout <= 0 { - panic(fmt.Sprintf(prefix+"AcceptTimeout is %s but must be > 0", cfg.AcceptTimeout)) + if cfg.AcceptTimeout < 0 { + panic(fmt.Sprintf(prefix+"AcceptTimeout is %s but must be >= 0", cfg.AcceptTimeout)) } if cfg.MaxConsecutiveRestarts == 0 { panic(fmt.Sprintf(prefix+"MaxConsecutiveRestarts is %d but must be > 0", cfg.MaxConsecutiveRestarts)) } - if cfg.CompleteTimeout <= 0 { - panic(fmt.Sprintf(prefix+"CompleteTimeout is %s but must be > 0", cfg.CompleteTimeout)) + if cfg.CompleteTimeout < 0 { + panic(fmt.Sprintf(prefix+"CompleteTimeout is %s but must be >= 0", cfg.CompleteTimeout)) } } @@ -269,6 +271,11 @@ func (mc *monitoredChannel) start() { // an Accept to our open channel request before the accept timeout. // Returns a function that can be used to cancel the timer. func (mc *monitoredChannel) watchForResponderAccept() func() { + // Check if the accept timeout is disabled + if mc.cfg.AcceptTimeout == 0 { + return func() {} + } + // Start a timer for the accept timeout timer := time.NewTimer(mc.cfg.AcceptTimeout) @@ -291,6 +298,11 @@ func (mc *monitoredChannel) watchForResponderAccept() func() { // Wait up to the configured timeout for the responder to send a Complete message func (mc *monitoredChannel) watchForResponderComplete() { + // Check if the complete timeout is disabled + if mc.cfg.CompleteTimeout == 0 { + return + } + // Start a timer for the complete timeout timer := time.NewTimer(mc.cfg.CompleteTimeout) defer timer.Stop() @@ -302,7 +314,7 @@ func (mc *monitoredChannel) watchForResponderComplete() { case <-timer.C: // Timer expired before we received a Complete message from the responder err := xerrors.Errorf("%s: timed out waiting %s for Complete message from remote peer", - mc.chid, mc.cfg.AcceptTimeout) + mc.chid, mc.cfg.CompleteTimeout) mc.closeChannelAndShutdown(err) } } diff --git a/channelmonitor/channelmonitor_test.go b/channelmonitor/channelmonitor_test.go index 70f95642..eefe3fc7 100644 --- a/channelmonitor/channelmonitor_test.go +++ b/channelmonitor/channelmonitor_test.go @@ -250,9 +250,11 @@ func TestChannelMonitorQueuedRestart(t *testing.T) { func TestChannelMonitorTimeouts(t *testing.T) { type testCase struct { - name string - expectAccept bool - expectComplete bool + name string + expectAccept bool + expectComplete bool + acceptTimeoutDisabled bool + completeTimeoutDisabled bool } testCases := []testCase{{ name: "accept in time", @@ -261,6 +263,10 @@ func TestChannelMonitorTimeouts(t *testing.T) { }, { name: "accept too late", expectAccept: false, + }, { + name: "disable accept timeout", + acceptTimeoutDisabled: true, + expectAccept: true, }, { name: "complete in time", expectAccept: true, @@ -269,6 +275,11 @@ func TestChannelMonitorTimeouts(t *testing.T) { name: "complete too late", expectAccept: true, expectComplete: false, + }, { + name: "disable complete timeout", + completeTimeoutDisabled: true, + expectAccept: true, + expectComplete: true, }} runTest := func(name string, isPush bool) { @@ -286,6 +297,12 @@ func TestChannelMonitorTimeouts(t *testing.T) { acceptTimeout := 10 * time.Millisecond completeTimeout := 10 * time.Millisecond + if tc.acceptTimeoutDisabled { + acceptTimeout = 0 + } + if tc.completeTimeoutDisabled { + completeTimeout = 0 + } m := NewMonitor(mockAPI, &Config{ AcceptTimeout: acceptTimeout, MaxConsecutiveRestarts: 1,