From 7458bb1a3e660eb15f3a01e256fe3c3b554e0362 Mon Sep 17 00:00:00 2001 From: Piotr Piotrowski Date: Tue, 12 Mar 2024 11:48:10 +0100 Subject: [PATCH] [IMPROVED] Fetch and FetchBatch for draining and closed subscriptions Signed-off-by: Piotr Piotrowski --- js.go | 16 +++++++- jserrors.go | 3 ++ test/js_test.go | 107 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 124 insertions(+), 2 deletions(-) diff --git a/js.go b/js.go index 462fea17e..acdb7bd07 100644 --- a/js.go +++ b/js.go @@ -2861,7 +2861,13 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { } var hbTimer *time.Timer var hbErr error - if err == nil && len(msgs) < batch { + sub.mu.Lock() + subClosed := sub.closed || sub.draining + sub.mu.Unlock() + if subClosed { + err = ErrSubscriptionClosed + } + if err == nil && len(msgs) < batch && !subClosed { // For batch real size of 1, it does not make sense to set no_wait in // the request. noWait := batch-len(msgs) > 1 @@ -3129,8 +3135,14 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e result.msgs <- msg } } - if len(result.msgs) == batch || result.err != nil { + sub.mu.Lock() + subClosed := sub.closed || sub.draining + sub.mu.Unlock() + if len(result.msgs) == batch || result.err != nil || subClosed { close(result.msgs) + if subClosed && len(result.msgs) == 0 { + return nil, ErrSubscriptionClosed + } result.done <- struct{}{} return result, nil } diff --git a/jserrors.go b/jserrors.go index b5c968465..2a160405c 100644 --- a/jserrors.go +++ b/jserrors.go @@ -141,6 +141,9 @@ var ( // ErrNoHeartbeat is returned when no heartbeat is received from server when sending requests with pull consumer. ErrNoHeartbeat JetStreamError = &jsError{message: "no heartbeat received"} + // ErrSubscriptionClosed is returned when attempting to send pull request to a closed subscription + ErrSubscriptionClosed JetStreamError = &jsError{message: "subscription closed"} + // DEPRECATED: ErrInvalidDurableName is no longer returned and will be removed in future releases. // Use ErrInvalidConsumerName instead. ErrInvalidDurableName = errors.New("nats: invalid durable name") diff --git a/test/js_test.go b/test/js_test.go index 540ae41c9..722c59cf7 100644 --- a/test/js_test.go +++ b/test/js_test.go @@ -1239,6 +1239,64 @@ func TestPullSubscribeFetchWithHeartbeat(t *testing.T) { } } +func TestPullSubscribeFetchDrain(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, js := jsClient(t, s) + defer nc.Close() + + _, err := js.AddStream(&nats.StreamConfig{ + Name: "TEST", + Subjects: []string{"foo"}, + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + defer js.PurgeStream("TEST") + sub, err := js.PullSubscribe("foo", "") + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + for i := 0; i < 100; i++ { + if _, err := js.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + } + // fill buffer with messages + cinfo, err := sub.ConsumerInfo() + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + nextSubject := fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.TEST.%s", cinfo.Name) + replySubject := strings.Replace(sub.Subject, "*", "abc", 1) + payload := `{"batch":10,"no_wait":true}` + if err := nc.PublishRequest(nextSubject, replySubject, []byte(payload)); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + time.Sleep(100 * time.Millisecond) + + // now drain the subscription, messages should be in the buffer + sub.Drain() + msgs, err := sub.Fetch(100) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + for _, msg := range msgs { + msg.Ack() + } + if len(msgs) != 10 { + t.Fatalf("Expected %d messages; got: %d", 10, len(msgs)) + } + + // subsequent fetch should return error, subscription is already drained + msgs, err = sub.Fetch(10, nats.MaxWait(100*time.Millisecond)) + if !errors.Is(err, nats.ErrSubscriptionClosed) { + t.Fatalf("Expected error: %s; got: %s", nats.ErrSubscriptionClosed, err) + } +} + func TestPullSubscribeFetchBatchWithHeartbeat(t *testing.T) { s := RunBasicJetStreamServer() defer shutdownJSServerAndRemoveStorage(t, s) @@ -1761,6 +1819,55 @@ func TestPullSubscribeFetchBatch(t *testing.T) { t.Errorf("Expected error: %s; got: %s", nats.ErrNoDeadlineContext, err) } }) + + t.Run("close subscription", func(t *testing.T) { + defer js.PurgeStream("TEST") + sub, err := js.PullSubscribe("foo", "") + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + for i := 0; i < 100; i++ { + if _, err := js.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + } + // fill buffer with messages + cinfo, err := sub.ConsumerInfo() + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + nextSubject := fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.TEST.%s", cinfo.Name) + replySubject := strings.Replace(sub.Subject, "*", "abc", 1) + payload := `{"batch":10,"no_wait":true}` + if err := nc.PublishRequest(nextSubject, replySubject, []byte(payload)); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + time.Sleep(100 * time.Millisecond) + + // now drain the subscription, messages should be in the buffer + sub.Drain() + res, err := sub.FetchBatch(100) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + msgs := make([]*nats.Msg, 0) + for msg := range res.Messages() { + msgs = append(msgs, msg) + msg.Ack() + } + if res.Error() != nil { + t.Fatalf("Unexpected error: %s", res.Error()) + } + if len(msgs) != 10 { + t.Fatalf("Expected %d messages; got: %d", 10, len(msgs)) + } + + // subsequent fetch should return error, subscription is already drained + res, err = sub.FetchBatch(10, nats.MaxWait(100*time.Millisecond)) + if !errors.Is(err, nats.ErrSubscriptionClosed) { + t.Fatalf("Expected error: %s; got: %s", nats.ErrSubscriptionClosed, err) + } + }) } func TestPullSubscribeConsumerDeleted(t *testing.T) {