Skip to content

Commit

Permalink
[ADDED] ConsumeContext.Closed() method for waiting for consume to be …
Browse files Browse the repository at this point in the history
…closed/drained (#1691)

Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
  • Loading branch information
piotrpio committed Dec 13, 2024
1 parent f188ceb commit bf49163
Show file tree
Hide file tree
Showing 4 changed files with 353 additions and 15 deletions.
61 changes: 46 additions & 15 deletions jetstream/ordered.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ type (
cfg *OrderedConsumerConfig
stream string
currentConsumer *pullConsumer
currentSub ConsumeContext
currentSub *pullSubscription
cursor cursor
namePrefix string
serial int
consumerType consumerType
doReset chan struct{}
resetInProgress uint32
resetInProgress atomic.Uint32
userErrHandler ConsumeErrHandlerFunc
stopAfter int
stopAfterMsgsLeft chan int
Expand All @@ -52,7 +52,7 @@ type (
consumer *orderedConsumer
opts []PullMessagesOpt
done chan struct{}
closed uint32
closed atomic.Uint32
}

cursor struct {
Expand Down Expand Up @@ -138,7 +138,7 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt
if err != nil {
return nil, err
}
c.currentSub = cc
c.currentSub = cc.(*pullSubscription)

go func() {
for {
Expand Down Expand Up @@ -175,7 +175,7 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt
c.errHandler(c.serial)(cc, err)
} else {
c.Lock()
c.currentSub = cc
c.currentSub = cc.(*pullSubscription)
c.Unlock()
}
case <-sub.done:
Expand Down Expand Up @@ -210,8 +210,8 @@ func (c *orderedConsumer) errHandler(serial int) func(cc ConsumeContext, err err
errors.Is(err, ErrConsumerDeleted) ||
errors.Is(err, errConnected) {
// only reset if serial matches the current consumer serial and there is no reset in progress
if serial == c.serial && atomic.LoadUint32(&c.resetInProgress) == 0 {
atomic.StoreUint32(&c.resetInProgress, 1)
if serial == c.serial && c.resetInProgress.Load() == 0 {
c.resetInProgress.Store(1)
c.doReset <- struct{}{}
}
}
Expand Down Expand Up @@ -256,7 +256,7 @@ func (c *orderedConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, er
if err != nil {
return nil, err
}
c.currentSub = cc
c.currentSub = cc.(*pullSubscription)

sub := &orderedSubscription{
consumer: c,
Expand All @@ -270,7 +270,7 @@ func (c *orderedConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, er

func (s *orderedSubscription) Next() (Msg, error) {
for {
msg, err := s.consumer.currentSub.(*pullSubscription).Next()
msg, err := s.consumer.currentSub.Next()
if err != nil {
if errors.Is(err, ErrMsgIteratorClosed) {
s.Stop()
Expand All @@ -297,7 +297,7 @@ func (s *orderedSubscription) Next() (Msg, error) {
if err != nil {
return nil, err
}
s.consumer.currentSub = cc
s.consumer.currentSub = cc.(*pullSubscription)
continue
}

Expand All @@ -321,7 +321,7 @@ func (s *orderedSubscription) Next() (Msg, error) {
if err != nil {
return nil, err
}
s.consumer.currentSub = cc
s.consumer.currentSub = cc.(*pullSubscription)
continue
}
s.consumer.cursor.deliverSeq = dseq
Expand All @@ -331,7 +331,7 @@ func (s *orderedSubscription) Next() (Msg, error) {
}

func (s *orderedSubscription) Stop() {
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
if !s.closed.CompareAndSwap(0, 1) {
return
}
s.consumer.Lock()
Expand All @@ -343,7 +343,7 @@ func (s *orderedSubscription) Stop() {
}

func (s *orderedSubscription) Drain() {
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
if !s.closed.CompareAndSwap(0, 1) {
return
}
if s.consumer.currentSub != nil {
Expand All @@ -354,6 +354,37 @@ func (s *orderedSubscription) Drain() {
close(s.done)
}

// Closed returns a channel that is closed when the consuming is
// fully stopped/drained. When the channel is closed, no more messages
// will be received and processing is complete.
func (s *orderedSubscription) Closed() <-chan struct{} {
s.consumer.Lock()
defer s.consumer.Unlock()
closedCh := make(chan struct{})

go func() {
for {
s.consumer.Lock()
if s.consumer.currentSub == nil {
return
}

closed := s.consumer.currentSub.Closed()
s.consumer.Unlock()

// wait until the underlying pull consumer is closed
<-closed
// if the subscription is closed and ordered consumer is closed as well,
// send a signal that the Consume() is fully stopped
if s.closed.Load() == 1 {
close(closedCh)
return
}
}
}()
return closedCh
}

// Fetch is used to retrieve up to a provided number of messages from a
// stream. This method will always send a single request and wait until
// either all messages are retrieved or request times out.
Expand Down Expand Up @@ -495,7 +526,7 @@ func serialNumberFromConsumer(name string) int {
func (c *orderedConsumer) reset() error {
c.Lock()
defer c.Unlock()
defer atomic.StoreUint32(&c.resetInProgress, 0)
defer c.resetInProgress.Store(0)
if c.currentConsumer != nil {
c.currentConsumer.Lock()
if c.currentSub != nil {
Expand Down Expand Up @@ -524,7 +555,7 @@ func (c *orderedConsumer) reset() error {
cancel: c.subscription.done,
}
err = retryWithBackoff(func(attempt int) (bool, error) {
isClosed := atomic.LoadUint32(&c.subscription.closed) == 1
isClosed := c.subscription.closed.Load() == 1
if isClosed {
return false, errOrderedConsumerClosed
}
Expand Down
30 changes: 30 additions & 0 deletions jetstream/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ type (
// Drain unsubscribes from the stream and cancels subscription.
// All messages that are already in the buffer will be processed in callback function.
Drain()

// Closed returns a channel that is closed when the consuming is
// fully stopped/drained. When the channel is closed, no more messages
// will be received and processing is complete.
Closed() <-chan struct{}
}

// MessageHandler is a handler function used as callback in [Consume].
Expand Down Expand Up @@ -125,6 +130,7 @@ type (
fetchNext chan *pullRequest
consumeOpts *consumeOpts
delivered int
closedCh chan struct{}
}

pendingMsgs struct {
Expand Down Expand Up @@ -257,6 +263,12 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) (
return func(subject string) {
p.subs.Delete(sid)
sub.draining.CompareAndSwap(1, 0)
sub.Lock()
if sub.closedCh != nil {
close(sub.closedCh)
sub.closedCh = nil
}
sub.Unlock()
}
}(sub.id))

Expand Down Expand Up @@ -649,6 +661,24 @@ func (s *pullSubscription) Drain() {
}
}

// Closed returns a channel that is closed when consuming is
// fully stopped/drained. When the channel is closed, no more messages
// will be received and processing is complete.
func (s *pullSubscription) Closed() <-chan struct{} {
s.Lock()
defer s.Unlock()
closedCh := s.closedCh
if closedCh == nil {
closedCh = make(chan struct{})
s.closedCh = closedCh
}
if !s.subscription.IsValid() {
close(s.closedCh)
s.closedCh = nil
}
return closedCh
}

// Fetch sends a single request to retrieve given number of messages.
// It will wait up to provided expiry time if not all messages are available.
func (p *pullConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, error) {
Expand Down
129 changes: 129 additions & 0 deletions jetstream/test/ordered_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,135 @@ func TestOrderedConsumerConsume(t *testing.T) {
time.Sleep(50 * time.Millisecond)
}
})

t.Run("wait for closed after drain", func(t *testing.T) {
for i := 0; i < 10; i++ {
t.Run(fmt.Sprintf("run %d", i), func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
msgs := make([]jetstream.Msg, 0)
lock := sync.Mutex{}
publishTestMsgs(t, js)
cc, err := c.Consume(func(msg jetstream.Msg) {
time.Sleep(50 * time.Millisecond)
msg.Ack()
lock.Lock()
msgs = append(msgs, msg)
lock.Unlock()
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
closed := cc.Closed()
time.Sleep(100 * time.Millisecond)
if err := s.DeleteConsumer(context.Background(), c.CachedInfo().Name); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
publishTestMsgs(t, js)

// wait for the consumer to be recreated before calling drain
for i := 0; i < 5; i++ {
_, err = c.Info(ctx)
if err != nil {
if errors.Is(err, jetstream.ErrConsumerNotFound) {
time.Sleep(100 * time.Millisecond)
continue
}
t.Fatalf("Unexpected error: %v", err)
}
break
}

cc.Drain()

select {
case <-closed:
case <-time.After(5 * time.Second):
t.Fatalf("Timeout waiting for consume to be closed")
}

if len(msgs) != 2*len(testMsgs) {
t.Fatalf("Unexpected received message count after consume closed; want %d; got %d", 2*len(testMsgs), len(msgs))
}
})
}
})

t.Run("wait for closed on already closed consume", func(t *testing.T) {
for i := 0; i < 10; i++ {
t.Run(fmt.Sprintf("run %d", i), func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
msgs := make([]jetstream.Msg, 0)
lock := sync.Mutex{}
publishTestMsgs(t, js)
cc, err := c.Consume(func(msg jetstream.Msg) {
time.Sleep(50 * time.Millisecond)
msg.Ack()
lock.Lock()
msgs = append(msgs, msg)
lock.Unlock()
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
time.Sleep(100 * time.Millisecond)
if err := s.DeleteConsumer(context.Background(), c.CachedInfo().Name); err != nil {
t.Fatalf("Unexpected error: %v", err)
}

cc.Stop()

time.Sleep(100 * time.Millisecond)

select {
case <-cc.Closed():
case <-time.After(5 * time.Second):
t.Fatalf("Timeout waiting for consume to be closed")
}
})
}
})
}

func TestOrderedConsumerMessages(t *testing.T) {
Expand Down
Loading

0 comments on commit bf49163

Please sign in to comment.