Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Drain() infinite loop and add test for concurrent Next() calls #1525

Merged
merged 6 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions jetstream/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ type (
closed uint32
draining uint32
done chan struct{}
drained chan struct{}
connStatusChanged chan nats.Status
fetchNext chan *pullRequest
consumeOpts *consumeOpts
Expand Down Expand Up @@ -476,7 +475,6 @@ func (p *pullConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error
id: consumeID,
consumer: p,
done: make(chan struct{}, 1),
drained: make(chan struct{}, 1),
msgs: msgs,
errs: make(chan error, 1),
fetchNext: make(chan *pullRequest, 1),
Expand Down Expand Up @@ -560,12 +558,6 @@ func (s *pullSubscription) Next() (Msg, error) {
for {
s.checkPending()
select {
case <-s.done:
drainMode := atomic.LoadUint32(&s.draining) == 1
if drainMode {
continue
}
return nil, ErrMsgIteratorClosed
case msg, ok := <-s.msgs:
if !ok {
// if msgs channel is closed, it means that subscription was either drained or stopped
Expand Down Expand Up @@ -914,8 +906,11 @@ func (s *pullSubscription) scheduleHeartbeatCheck(dur time.Duration) *hbMonitor
}

func (s *pullSubscription) cleanup() {
s.Lock()
defer s.Unlock()
// For now this function does not need to hold the lock.
// Holding the lock here might cause a deadlock if Next()
// is already holding the lock and waiting.
// The fields that are read (subscription, hbMonitor)
// are read only (Only written on creation of pullSubscription).
if s.subscription == nil || !s.subscription.IsValid() {
return
}
Expand Down
217 changes: 166 additions & 51 deletions jetstream/test/pull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,111 @@ func TestPullConsumerMessages(t *testing.T) {
}
})

t.Run("with auto unsubscribe concurrent", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "test", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

it, err := c.Messages(jetstream.StopAfter(50), jetstream.PullMaxMessages(40))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

for i := 0; i < 100; i++ {
if _, err := js.Publish(ctx, "FOO.A", []byte("msg")); err != nil {
t.Fatalf("Unexpected error during publish: %s", err)
}
}

var mu sync.Mutex // Mutex to guard the msgs slice.
msgs := make([]jetstream.Msg, 0)
var wg sync.WaitGroup

wg.Add(50)
for i := 0; i < 50; i++ {
go func() {
defer wg.Done()

msg, err := it.Next()
if err != nil {
return
}

ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := msg.DoubleAck(ctx); err == nil {
// Only append the msg if ack is successful.
mu.Lock()
msgs = append(msgs, msg)
mu.Unlock()
}
}()
}

wg.Wait()

// Call Next in a goroutine so we can timeout if it doesn't return.
errs := make(chan error)
go func() {
// This call should return the error ErrMsgIteratorClosed.
_, err := it.Next()
errs <- err
}()

timer := time.NewTimer(5 * time.Second)
defer timer.Stop()

select {
case <-timer.C:
t.Fatal("Timed out waiting for Next() to return")
case err := <-errs:
if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
t.Fatalf("Unexpected error: %v", err)
}
}

mu.Lock()
wantLen, gotLen := 50, len(msgs)
mu.Unlock()
if wantLen != gotLen {
t.Fatalf("Unexpected received message count; want %d; got %d", wantLen, gotLen)
}

ci, err := c.Info(ctx)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if ci.NumPending != 50 {
t.Fatalf("Unexpected number of pending messages; want 50; got %d", ci.NumPending)
}
if ci.NumAckPending != 0 {
t.Fatalf("Unexpected number of ack pending messages; want 0; got %d", ci.NumAckPending)
}
if ci.NumWaiting != 0 {
t.Fatalf("Unexpected number of waiting pull requests; want 0; got %d", ci.NumWaiting)
}
})

t.Run("create iterator, stop, then create again", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
Expand Down Expand Up @@ -1293,69 +1398,79 @@ func TestPullConsumerMessages(t *testing.T) {
})

t.Run("with graceful shutdown", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)

nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
cases := map[string]func(jetstream.MessagesContext){
"stop": func(mc jetstream.MessagesContext) { mc.Stop() },
"drain": func(mc jetstream.MessagesContext) { mc.Drain() },
}

js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
for name, unsubscribe := range cases {
t.Run(name, func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

it, err := c.Messages()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

publishTestMsgs(t, nc)
js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()

errs := make(chan error)
msgs := make([]jetstream.Msg, 0)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

go func() {
for {
msg, err := it.Next()
it, err := c.Messages()
if err != nil {
errs <- err
return
t.Fatalf("Unexpected error: %v", err)
}
msg.Ack()
msgs = append(msgs, msg)
}
}()

time.Sleep(10 * time.Millisecond)
it.Stop() // Next() should return ErrMsgIteratorClosed
publishTestMsgs(t, nc)

timeout := time.NewTimer(5 * time.Second)
errs := make(chan error)
msgs := make([]jetstream.Msg, 0)

select {
case <-timeout.C:
t.Fatal("Timed out waiting for Next() to return after Stop()")
case err := <-errs:
if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
t.Fatalf("Unexpected error: %v", err)
}
go func() {
for {
msg, err := it.Next()
if err != nil {
errs <- err
return
}
msg.Ack()
msgs = append(msgs, msg)
}
}()

if len(msgs) != len(testMsgs) {
t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
}
time.Sleep(10 * time.Millisecond)
unsubscribe(it) // Next() should return ErrMsgIteratorClosed

timer := time.NewTimer(5 * time.Second)
defer timer.Stop()

select {
case <-timer.C:
t.Fatal("Timed out waiting for Next() to return")
case err := <-errs:
if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
t.Fatalf("Unexpected error: %v", err)
}

if len(msgs) != len(testMsgs) {
t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
}
}
})
}
})

Expand Down