diff --git a/jetstream/ordered.go b/jetstream/ordered.go index 469624477..85b7ea9e9 100644 --- a/jetstream/ordered.go +++ b/jetstream/ordered.go @@ -69,7 +69,10 @@ const ( consumerTypeFetch ) -var errOrderedSequenceMismatch = errors.New("sequence mismatch") +var ( + errOrderedSequenceMismatch = errors.New("sequence mismatch") + errOrderedConsumerClosed = errors.New("ordered consumer closed") +) // Consume can be used to continuously receive messages and handle them // with the provided callback function. Consume cannot be used concurrently @@ -142,6 +145,9 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt select { case <-c.doReset: if err := c.reset(); err != nil { + if errors.Is(err, errOrderedConsumerClosed) { + continue + } c.errHandler(c.serial)(c.currentSub, err) } if c.withStopAfter { @@ -173,6 +179,12 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt c.Unlock() } case <-sub.done: + s := sub.consumer.currentSub + if s != nil { + sub.consumer.Lock() + s.Stop() + sub.consumer.Unlock() + } return case msgsLeft, ok := <-c.stopAfterMsgsLeft: if !ok { @@ -276,6 +288,9 @@ func (s *orderedSubscription) Next() (Msg, error) { s.opts[len(s.opts)-1] = StopAfter(s.consumer.stopAfter) } if err := s.consumer.reset(); err != nil { + if errors.Is(err, errOrderedConsumerClosed) { + return nil, ErrMsgIteratorClosed + } return nil, err } cc, err := s.consumer.currentConsumer.Messages(s.opts...) @@ -297,6 +312,9 @@ func (s *orderedSubscription) Next() (Msg, error) { dseq := meta.Sequence.Consumer if dseq != s.consumer.cursor.deliverSeq+1 { if err := s.consumer.reset(); err != nil { + if errors.Is(err, errOrderedConsumerClosed) { + return nil, ErrMsgIteratorClosed + } return nil, err } cc, err := s.consumer.currentConsumer.Messages(s.opts...) @@ -318,7 +336,9 @@ func (s *orderedSubscription) Stop() { } s.consumer.Lock() defer s.consumer.Unlock() - s.consumer.currentSub.Stop() + if s.consumer.currentSub != nil { + s.consumer.currentSub.Stop() + } close(s.done) } @@ -326,9 +346,11 @@ func (s *orderedSubscription) Drain() { if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { return } - s.consumer.currentConsumer.Lock() - defer s.consumer.currentConsumer.Unlock() - s.consumer.currentSub.Drain() + if s.consumer.currentSub != nil { + s.consumer.currentConsumer.Lock() + s.consumer.currentSub.Drain() + s.consumer.currentConsumer.Unlock() + } close(s.done) } @@ -504,7 +526,7 @@ func (c *orderedConsumer) reset() error { err = retryWithBackoff(func(attempt int) (bool, error) { isClosed := atomic.LoadUint32(&c.subscription.closed) == 1 if isClosed { - return false, nil + return false, errOrderedConsumerClosed } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -512,7 +534,6 @@ func (c *orderedConsumer) reset() error { if err != nil { return true, err } - c.currentConsumer = cons.(*pullConsumer) return false, nil }, backoffOpts) if err != nil { diff --git a/jetstream/test/kv_test.go b/jetstream/test/kv_test.go index 26a038c7b..010a7d5b3 100644 --- a/jetstream/test/kv_test.go +++ b/jetstream/test/kv_test.go @@ -967,6 +967,7 @@ func TestKeyValueListKeys(t *testing.T) { func TestKeyValueCrossAccounts(t *testing.T) { conf := createConfFile(t, []byte(` + listen: 127.0.0.1:-1 jetstream: enabled accounts: { A: { diff --git a/jetstream/test/ordered_test.go b/jetstream/test/ordered_test.go index 01a67a8ba..522c92196 100644 --- a/jetstream/test/ordered_test.go +++ b/jetstream/test/ordered_test.go @@ -539,6 +539,45 @@ func TestOrderedConsumerConsume(t *testing.T) { cc.Drain() wg.Wait() }) + + t.Run("stop consume during reset", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + for i := 0; i < 10; i++ { + c, err := s.OrderedConsumer(context.Background(), jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + cc, err := c.Consume(func(msg jetstream.Msg) { + msg.Ack() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if err := s.DeleteConsumer(context.Background(), c.CachedInfo().Name); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + cc.Stop() + time.Sleep(50 * time.Millisecond) + } + }) } func TestOrderedConsumerMessages(t *testing.T) {