diff --git a/async_processor.go b/async_processor.go index 660b03bc..2a0a9f0a 100644 --- a/async_processor.go +++ b/async_processor.go @@ -10,10 +10,11 @@ import ( type asyncProcessor struct { bufferSize int - running bool - buffer *ringbuffer.RingBuffer + running bool + buffer *ringbuffer.RingBuffer + stopError error - chError chan error + stopped chan struct{} } func (w *asyncProcessor) initialize() { @@ -22,22 +23,21 @@ func (w *asyncProcessor) initialize() { func (w *asyncProcessor) start() { w.running = true - w.chError = make(chan error) + w.stopped = make(chan struct{}) go w.run() } func (w *asyncProcessor) stop() { if w.running { w.buffer.Close() - <-w.chError + <-w.stopped w.running = false } } func (w *asyncProcessor) run() { - err := w.runInner() - w.chError <- err - close(w.chError) + w.stopError = w.runInner() + close(w.stopped) } func (w *asyncProcessor) runInner() error { diff --git a/async_processor_test.go b/async_processor_test.go new file mode 100644 index 00000000..2f881bd6 --- /dev/null +++ b/async_processor_test.go @@ -0,0 +1,24 @@ +package gortsplib + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAsyncProcessorStopAfterError(t *testing.T) { + p := &asyncProcessor{bufferSize: 8} + p.initialize() + + p.push(func() error { + return fmt.Errorf("ok") + }) + + p.start() + + <-p.stopped + require.EqualError(t, p.stopError, "ok") + + p.stop() +} diff --git a/client.go b/client.go index e1df861d..be014ed7 100644 --- a/client.go +++ b/client.go @@ -559,9 +559,9 @@ func (c *Client) runInner() error { return nil }() - chWriterError := func() chan error { - if c.writer != nil { - return c.writer.chError + chWriterError := func() chan struct{} { + if c.writer != nil && c.writer.running { + return c.writer.stopped } return nil }() @@ -637,8 +637,8 @@ func (c *Client) runInner() error { } c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) - case err := <-chWriterError: - return err + case <-chWriterError: + return c.writer.stopError case err := <-chReaderError: c.reader = nil diff --git a/client_play_test.go b/client_play_test.go index 87f86cbf..05a6dc61 100644 --- a/client_play_test.go +++ b/client_play_test.go @@ -1961,8 +1961,8 @@ func TestClientPlayPause(t *testing.T) { }) require.NoError(t, err2) - req, err = conn.ReadRequest() - require.NoError(t, err) + req, err2 = conn.ReadRequest() + require.NoError(t, err2) require.Equal(t, base.Play, req.Method) err2 = conn.WriteResponse(&base.Response{ diff --git a/server_session.go b/server_session.go index 549f388b..1c3c77bc 100644 --- a/server_session.go +++ b/server_session.go @@ -626,9 +626,9 @@ func (ss *ServerSession) run() { func (ss *ServerSession) runInner() error { for { - chWriterError := func() chan error { - if ss.writer != nil { - return ss.writer.chError + chWriterError := func() chan struct{} { + if ss.writer != nil && ss.writer.running { + return ss.writer.stopped } return nil }() @@ -729,8 +729,8 @@ func (ss *ServerSession) runInner() error { ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod) - case err := <-chWriterError: - return err + case <-chWriterError: + return ss.writer.stopError case <-ss.ctx.Done(): return liberrors.ErrServerTerminated{} @@ -1306,7 +1306,6 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( } ss.writer.stop() - ss.writer = nil ss.timeDecoder = nil