diff --git a/jetstream/publish.go b/jetstream/publish.go index f41b06fd1..1a3a4fdbe 100644 --- a/jetstream/publish.go +++ b/jetstream/publish.go @@ -378,6 +378,7 @@ func (js *jetStream) handleAsyncReply(m *nats.Msg) { cb := js.publisher.asyncPublisherOpts.aecb js.publisher.Unlock() if cb != nil { + paf.msg.Reply = "" cb(js, paf.msg, err) } } @@ -388,6 +389,12 @@ func (js *jetStream) handleAsyncReply(m *nats.Msg) { paf.retries++ paf.msg.Reply = m.Subject time.AfterFunc(paf.retryWait, func() { + js.publisher.Lock() + paf := js.getPAF(id) + js.publisher.Unlock() + if paf == nil { + return + } _, err := js.PublishMsgAsync(paf.msg, func(po *pubOpts) error { po.pafRetry = paf return nil @@ -453,10 +460,21 @@ func (js *jetStream) resetPendingAcksOnReconnect() { return } js.publisher.Lock() - for _, paf := range js.publisher.acks { + errCb := js.publisher.asyncPublisherOpts.aecb + for id, paf := range js.publisher.acks { paf.err = nats.ErrDisconnected + if paf.errCh != nil { + paf.errCh <- paf.err + } + if errCb != nil { + js.publisher.Unlock() + // clear reply subject so that new one is created on republish + paf.msg.Reply = "" + errCb(js, paf.msg, nats.ErrDisconnected) + js.publisher.Lock() + } + delete(js.publisher.acks, id) } - js.publisher.acks = nil if js.publisher.doneCh != nil { close(js.publisher.doneCh) js.publisher.doneCh = nil diff --git a/jetstream/test/publish_test.go b/jetstream/test/publish_test.go index a7ba8c31a..2233a16ce 100644 --- a/jetstream/test/publish_test.go +++ b/jetstream/test/publish_test.go @@ -16,8 +16,10 @@ package test import ( "context" "errors" + "fmt" "os" "reflect" + "sync" "testing" "time" @@ -1330,7 +1332,6 @@ func TestPublishMsgAsyncWithPendingMsgs(t *testing.T) { func TestPublishAsyncResetPendingOnReconnect(t *testing.T) { s := RunBasicJetStreamServer() - defer shutdownJSServerAndRemoveStorage(t, s) nc, err := nats.Connect(s.ClientURL()) if err != nil { @@ -1352,6 +1353,7 @@ func TestPublishAsyncResetPendingOnReconnect(t *testing.T) { errs := make(chan error, 1) done := make(chan struct{}, 1) acks := make(chan jetstream.PubAckFuture, 100) + wg := sync.WaitGroup{} go func() { for i := 0; i < 100; i++ { if ack, err := js.PublishAsync("FOO.A", []byte("hello")); err != nil { @@ -1360,6 +1362,7 @@ func TestPublishAsyncResetPendingOnReconnect(t *testing.T) { } else { acks <- ack } + wg.Add(1) } close(acks) done <- struct{}{} @@ -1371,28 +1374,32 @@ func TestPublishAsyncResetPendingOnReconnect(t *testing.T) { case <-time.After(5 * time.Second): t.Fatalf("Did not receive completion signal") } - s.Shutdown() - time.Sleep(100 * time.Millisecond) - if pending := js.PublishAsyncPending(); pending != 0 { - t.Fatalf("Expected no pending messages after server shutdown; got: %d", pending) + for ack := range acks { + go func(paf jetstream.PubAckFuture) { + select { + case <-paf.Ok(): + case err := <-paf.Err(): + if !errors.Is(err, nats.ErrDisconnected) && !errors.Is(err, nats.ErrNoResponders) { + errs <- fmt.Errorf("Expected error: %v or %v; got: %v", nats.ErrDisconnected, nats.ErrNoResponders, err) + } + case <-time.After(5 * time.Second): + errs <- fmt.Errorf("Did not receive completion signal") + } + wg.Done() + }(ack) } - s = RunBasicJetStreamServer() + s = restartBasicJSServer(t, s) defer shutdownJSServerAndRemoveStorage(t, s) - for ack := range acks { - select { - case <-ack.Ok(): - case err := <-ack.Err(): - if !errors.Is(err, nats.ErrDisconnected) && !errors.Is(err, nats.ErrNoResponders) { - t.Fatalf("Expected error: %v or %v; got: %v", nats.ErrDisconnected, nats.ErrNoResponders, err) - } - case <-time.After(5 * time.Second): - t.Fatalf("Did not receive completion signal") - } + wg.Wait() + select { + case err := <-errs: + t.Fatalf("Unexpected error: %v", err) + default: } } -func TestAsyncPublishRetry(t *testing.T) { +func TestPublishAsyncRetry(t *testing.T) { tests := []struct { name string pubOpts []jetstream.PublishOpt @@ -1472,3 +1479,69 @@ func TestAsyncPublishRetry(t *testing.T) { }) } } + +func TestPublishAsyncRetryInErrHandler(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, err := nats.Connect(s.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + streamCreated := make(chan struct{}) + errCB := func(js jetstream.JetStream, m *nats.Msg, e error) { + <-streamCreated + _, err := js.PublishMsgAsync(m) + if err != nil { + t.Fatalf("Unexpected error when republishing: %v", err) + } + } + + js, err := jetstream.New(nc, jetstream.WithPublishAsyncErrHandler(errCB)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + errs := make(chan error, 1) + done := make(chan struct{}, 1) + go func() { + for i := 0; i < 10; i++ { + if _, err := js.PublishAsync("FOO.A", []byte("hello"), jetstream.WithRetryAttempts(0)); err != nil { + errs <- err + return + } + } + done <- struct{}{} + }() + select { + case <-done: + case err := <-errs: + t.Fatalf("Unexpected error during publish: %v", err) + case <-time.After(5 * time.Second): + t.Fatalf("Did not receive completion signal") + } + stream, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + close(streamCreated) + select { + case <-js.PublishAsyncComplete(): + case <-time.After(5 * time.Second): + t.Fatalf("Did not receive completion signal") + } + + info, err := stream.Info(context.Background()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if info.State.Msgs != 10 { + t.Fatalf("Expected 10 messages in the stream; got: %d", info.State.Msgs) + } +} diff --git a/js.go b/js.go index 97ce71d26..f5250348c 100644 --- a/js.go +++ b/js.go @@ -712,10 +712,20 @@ func (js *js) resetPendingAcksOnReconnect() { return } js.mu.Lock() - for _, paf := range js.pafs { + errCb := js.opts.aecb + for id, paf := range js.pafs { paf.err = ErrDisconnected + if paf.errCh != nil { + paf.errCh <- paf.err + } + if errCb != nil { + // clear reply subject so that new one is created on republish + js.mu.Unlock() + errCb(js, paf.msg, ErrDisconnected) + js.mu.Lock() + } + delete(js.pafs, id) } - js.pafs = nil if js.dch != nil { close(js.dch) js.dch = nil diff --git a/test/js_test.go b/test/js_test.go index 279736c16..a77c3993c 100644 --- a/test/js_test.go +++ b/test/js_test.go @@ -8076,6 +8076,70 @@ func TestPublishAsyncResetPendingOnReconnect(t *testing.T) { } } +func TestPublishAsyncRetryInErrHandler(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, err := nats.Connect(s.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + streamCreated := make(chan struct{}) + errCB := func(js nats.JetStream, m *nats.Msg, e error) { + <-streamCreated + _, err := js.PublishMsgAsync(m) + if err != nil { + t.Fatalf("Unexpected error when republishing: %v", err) + } + } + + js, err := nc.JetStream(nats.PublishAsyncErrHandler(errCB)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + errs := make(chan error, 1) + done := make(chan struct{}, 1) + go func() { + for i := 0; i < 10; i++ { + if _, err := js.PublishAsync("FOO.A", []byte("hello")); err != nil { + errs <- err + return + } + } + done <- struct{}{} + }() + select { + case <-done: + case err := <-errs: + t.Fatalf("Unexpected error during publish: %v", err) + case <-time.After(5 * time.Second): + t.Fatalf("Did not receive completion signal") + } + _, err = js.AddStream(&nats.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + close(streamCreated) + select { + case <-js.PublishAsyncComplete(): + case <-time.After(5 * time.Second): + t.Fatalf("Did not receive completion signal") + } + + info, err := js.StreamInfo("foo") + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if info.State.Msgs != 10 { + t.Fatalf("Expected 10 messages in the stream; got: %d", info.State.Msgs) + } +} + func TestJetStreamPublishAsyncPerf(t *testing.T) { // Comment out below to run this benchmark. t.SkipNow()