diff --git a/nats.go b/nats.go index c8ab3bd49..bdec2cfd2 100644 --- a/nats.go +++ b/nats.go @@ -1,4 +1,4 @@ -// Copyright 2012-2023 The NATS Authors +// Copyright 2012-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -607,14 +607,17 @@ type Subscription struct { // For holding information about a JetStream consumer. jsi *jsSub - delivered uint64 - max uint64 - conn *Conn - mcb MsgHandler - mch chan *Msg - closed bool - sc bool - connClosed bool + delivered uint64 + max uint64 + conn *Conn + mcb MsgHandler + mch chan *Msg + closed bool + sc bool + connClosed bool + draining bool + status SubStatus + statListeners map[chan SubStatus][]SubStatus // Type of Subscription typ SubscriptionType @@ -635,6 +638,30 @@ type Subscription struct { dropped int } +// Status represents the state of the connection. +type SubStatus int + +const ( + SubscriptionActive = SubStatus(iota) + SubscriptionDraining + SubscriptionClosed + SubscriptionSlowConsumer +) + +func (s SubStatus) String() string { + switch s { + case SubscriptionActive: + return "Active" + case SubscriptionDraining: + return "Draining" + case SubscriptionClosed: + return "Closed" + case SubscriptionSlowConsumer: + return "SlowConsumer" + } + return "unknown status" +} + // Msg represents a message delivered by NATS. This structure is used // by Subscribers and PublishMsg(). // @@ -3291,6 +3318,9 @@ func (nc *Conn) processMsg(data []byte) { } // Clear any SlowConsumer status. + if sub.sc { + sub.changeSubStatus(SubscriptionActive) + } sub.sc = false sub.mu.Unlock() @@ -3314,8 +3344,9 @@ slowConsumer: sub.pMsgs-- sub.pBytes -= len(m.Data) } - sub.mu.Unlock() if sc { + sub.changeSubStatus(SubscriptionSlowConsumer) + sub.mu.Unlock() // Now we need connection's lock and we may end-up in the situation // that we were trying to avoid, except that in this case, the client // is already experiencing client-side slow consumer situation. @@ -3325,6 +3356,8 @@ slowConsumer: nc.ach.push(func() { nc.Opts.AsyncErrorCB(nc, sub, ErrSlowConsumer) }) } nc.mu.Unlock() + } else { + sub.mu.Unlock() } } @@ -4294,6 +4327,7 @@ func (nc *Conn) subscribeLocked(subj, queue string, cb MsgHandler, ch chan *Msg, nc.kickFlusher() } + sub.changeSubStatus(SubscriptionActive) return sub, nil } @@ -4337,6 +4371,7 @@ func (nc *Conn) removeSub(s *Subscription) { } // Mark as invalid s.closed = true + s.changeSubStatus(SubscriptionClosed) if s.pCond != nil { s.pCond.Broadcast() } @@ -4406,6 +4441,91 @@ func (s *Subscription) Drain() error { return conn.unsubscribe(s, 0, true) } +// IsDraining returns a boolean indicating whether the subscription +// is being drained. +// This will return false if the subscription has already been closed. +func (s *Subscription) IsDraining() bool { + if s == nil { + return false + } + s.mu.Lock() + defer s.mu.Unlock() + return s.draining +} + +// StatusChanged returns a channel on which given list of subscription status +// changes will be sent. If no status is provided, all status changes will be sent. +// Available statuses are SubscriptionActive, SubscriptionDraining, SubscriptionClosed, +// and SubscriptionSlowConsumer. +// The returned channel will be closed when the subscription is closed. +func (s *Subscription) StatusChanged(statuses ...SubStatus) <-chan SubStatus { + if len(statuses) == 0 { + statuses = []SubStatus{SubscriptionActive, SubscriptionDraining, SubscriptionClosed, SubscriptionSlowConsumer} + } + ch := make(chan SubStatus, 10) + for _, status := range statuses { + s.registerStatusChangeListener(status, ch) + // initial status + if status == s.status { + ch <- status + } + } + return ch +} + +// registerStatusChangeListener registers a channel waiting for a specific status change event. +// Status change events are non-blocking - if no receiver is waiting for the status change, +// it will not be sent on the channel. Closed channels are ignored. +func (s *Subscription) registerStatusChangeListener(status SubStatus, ch chan SubStatus) { + s.mu.Lock() + defer s.mu.Unlock() + if s.statListeners == nil { + s.statListeners = make(map[chan SubStatus][]SubStatus) + } + if _, ok := s.statListeners[ch]; !ok { + s.statListeners[ch] = make([]SubStatus, 0) + } + s.statListeners[ch] = append(s.statListeners[ch], status) +} + +// sendStatusEvent sends subscription status event to all channels. +// If there is no listener, sendStatusEvent +// will not block. Lock should be held entering. +func (s *Subscription) sendStatusEvent(status SubStatus) { + for ch, statuses := range s.statListeners { + if !containsStatus(statuses, status) { + continue + } + // only send event if someone's listening + select { + case ch <- status: + default: + } + if status == SubscriptionClosed { + close(ch) + } + } +} + +func containsStatus(statuses []SubStatus, status SubStatus) bool { + for _, s := range statuses { + if s == status { + return true + } + } + return false +} + +// changeSubStatus changes subscription status and sends events +// to all listeners. Lock should be held entering. +func (s *Subscription) changeSubStatus(status SubStatus) { + if s == nil { + return + } + s.sendStatusEvent(status) + s.status = status +} + // Unsubscribe will remove interest in the given subject. // // For a JetStream subscription, if the library has created the JetStream @@ -4444,6 +4564,11 @@ func (s *Subscription) Unsubscribe() error { // checkDrained will watch for a subscription to be fully drained // and then remove it. func (nc *Conn) checkDrained(sub *Subscription) { + defer func() { + sub.mu.Lock() + defer sub.mu.Unlock() + sub.draining = false + }() if nc == nil || sub == nil { return } @@ -4553,6 +4678,10 @@ func (nc *Conn) unsubscribe(sub *Subscription, max int, drainMode bool) error { } if drainMode { + s.mu.Lock() + s.draining = true + sub.changeSubStatus(SubscriptionDraining) + s.mu.Unlock() go nc.checkDrained(sub) } @@ -4655,6 +4784,7 @@ func (s *Subscription) validateNextMsgState(pullSubInternal bool) error { return ErrSyncSubRequired } if s.sc { + s.changeSubStatus(SubscriptionActive) s.sc = false return ErrSlowConsumer } diff --git a/test/drain_test.go b/test/drain_test.go index 1168f617f..c53305e17 100644 --- a/test/drain_test.go +++ b/test/drain_test.go @@ -1,4 +1,4 @@ -// Copyright 2018-2023 The NATS Authors +// Copyright 2018-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -55,6 +55,9 @@ func TestDrain(t *testing.T) { // Drain it and make sure we receive all messages. sub.Drain() + if !sub.IsDraining() { + t.Fatalf("Expected to be draining") + } select { case <-done: break @@ -64,6 +67,10 @@ func TestDrain(t *testing.T) { t.Fatalf("Did not receive all messages: %d of %d", r, expected) } } + time.Sleep(100 * time.Millisecond) + if sub.IsDraining() { + t.Fatalf("Expected to be done draining") + } } func TestDrainQueueSub(t *testing.T) { diff --git a/test/js_test.go b/test/js_test.go index 900792a34..540ae41c9 100644 --- a/test/js_test.go +++ b/test/js_test.go @@ -1,4 +1,4 @@ -// Copyright 2020-2023 The NATS Authors +// Copyright 2020-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at diff --git a/test/sub_test.go b/test/sub_test.go index c359639df..0bf2880c1 100644 --- a/test/sub_test.go +++ b/test/sub_test.go @@ -1,4 +1,4 @@ -// Copyright 2013-2023 The NATS Authors +// Copyright 2013-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -14,6 +14,7 @@ package test import ( + "errors" "fmt" "sync" "sync/atomic" @@ -568,7 +569,7 @@ func TestAsyncErrHandler(t *testing.T) { if s != sub { t.Fatal("Did not receive proper subscription") } - if e != nats.ErrSlowConsumer { + if !errors.Is(e, nats.ErrSlowConsumer) { t.Fatalf("Did not receive proper error: %v vs %v", e, nats.ErrSlowConsumer) } // Suppress additional calls @@ -636,7 +637,7 @@ func TestAsyncErrHandlerChanSubscription(t *testing.T) { nc.SetErrorHandler(func(c *nats.Conn, s *nats.Subscription, e error) { atomic.AddInt64(&aeCalled, 1) - if e != nats.ErrSlowConsumer { + if !errors.Is(e, nats.ErrSlowConsumer) { t.Fatalf("Did not receive proper error: %v vs %v", e, nats.ErrSlowConsumer) } @@ -1614,3 +1615,137 @@ func TestSubscribe_ClosedHandler(t *testing.T) { t.Fatal("Did not receive closed callback") } } + +func TestSubscriptionEvents(t *testing.T) { + + waitForStatus := func(t *testing.T, ch <-chan nats.SubStatus, expected nats.SubStatus) { + t.Helper() + select { + case s := <-ch: + if s != expected { + t.Fatalf("Expected status: %s; got: %s", expected, s) + } + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for status %q", expected) + } + } + t.Run("default events", func(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc := NewDefaultConnection(t) + // disable slow consumer prints + nc.SetErrorHandler(func(c *nats.Conn, s *nats.Subscription, e error) {}) + defer nc.Close() + + blockChan := make(chan struct{}) + sub, err := nc.Subscribe("foo", func(_ *nats.Msg) { + // block in subscription callback + // to force slow consumer + <-blockChan + }) + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + sub.SetPendingLimits(10, 1024) + status := sub.StatusChanged() + + // initial status + waitForStatus(t, status, nats.SubscriptionActive) + + for i := 0; i < 11; i++ { + nc.Publish("foo", []byte("Hello")) + } + waitForStatus(t, status, nats.SubscriptionSlowConsumer) + close(blockChan) + + sub.Drain() + + waitForStatus(t, status, nats.SubscriptionDraining) + + waitForStatus(t, status, nats.SubscriptionClosed) + }) + + t.Run("slow consumer event only", func(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc := NewDefaultConnection(t) + defer nc.Close() + + blockChan := make(chan struct{}) + sub, err := nc.Subscribe("foo", func(_ *nats.Msg) { + // block in subscription callback + // to force slow consumer + <-blockChan + }) + // disable slow consumer prints + nc.SetErrorHandler(func(c *nats.Conn, s *nats.Subscription, e error) {}) + defer sub.Unsubscribe() + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + sub.SetPendingLimits(10, 1024) + status := sub.StatusChanged(nats.SubscriptionSlowConsumer) + + for i := 0; i < 20; i++ { + nc.Publish("foo", []byte("Hello")) + } + waitForStatus(t, status, nats.SubscriptionSlowConsumer) + close(blockChan) + + // now try with sync sub + sub, err = nc.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + defer sub.Unsubscribe() + sub.SetPendingLimits(10, 1024) + status = sub.StatusChanged(nats.SubscriptionSlowConsumer) + + for i := 0; i < 20; i++ { + nc.Publish("foo", []byte("Hello")) + } + waitForStatus(t, status, nats.SubscriptionSlowConsumer) + }) + + t.Run("do not block channel if it's not read", func(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc := NewDefaultConnection(t) + // disable slow consumer prints + nc.SetErrorHandler(func(c *nats.Conn, s *nats.Subscription, e error) {}) + defer nc.Close() + + blockChan := make(chan struct{}) + sub, err := nc.Subscribe("foo", func(_ *nats.Msg) { + // block in subscription callback + // to force slow consumer + <-blockChan + }) + defer sub.Unsubscribe() + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + sub.SetPendingLimits(10, 1024) + status := sub.StatusChanged() + waitForStatus(t, status, nats.SubscriptionActive) + + // chan length is 10, so make sure we switch state more times + for i := 0; i < 20; i++ { + // subscription will enter slow consumer state + for i := 0; i < 11; i++ { + nc.Publish("foo", []byte("Hello")) + } + + // messages flow normally, status flips to active + for i := 0; i < 10; i++ { + nc.Publish("foo", []byte("Hello")) + blockChan <- struct{}{} + } + } + // do not read from subscription + close(blockChan) + }) +}