diff --git a/server/jetstream_cluster.go b/server/jetstream_cluster.go
index c46f2900f96..961061830ba 100644
--- a/server/jetstream_cluster.go
+++ b/server/jetstream_cluster.go
@@ -2983,6 +2983,17 @@ func (js *jetStream) applyStreamEntries(mset *stream, ce *CommittedEntry, isReco
 					mset.clMu.Unlock()
 				}
 
+				if mset.inflightSubjects != nil {
+					mset.clMu.Lock()
+					n := mset.inflightSubjects[subject]
+					if n > 1 {
+						mset.inflightSubjects[subject]--
+					} else {
+						delete(mset.inflightSubjects, subject)
+					}
+					mset.clMu.Unlock()
+				}
+
 				if err != nil {
 					if err == errLastSeqMismatch {
 
@@ -7751,7 +7762,7 @@ func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg [
 	name, stype, store := mset.cfg.Name, mset.cfg.Storage, mset.store
 	s, js, jsa, st, r, tierName, outq, node := mset.srv, mset.js, mset.jsa, mset.cfg.Storage, mset.cfg.Replicas, mset.tier, mset.outq, mset.node
 	maxMsgSize, lseq := int(mset.cfg.MaxMsgSize), mset.lseq
-	interestPolicy, discard, maxMsgs, maxBytes := mset.cfg.Retention != LimitsPolicy, mset.cfg.Discard, mset.cfg.MaxMsgs, mset.cfg.MaxBytes
+	interestPolicy, discard, maxMsgs, maxBytes, discardNewPerSubject, maxMsgsPer := mset.cfg.Retention != LimitsPolicy, mset.cfg.Discard, mset.cfg.MaxMsgs, mset.cfg.MaxBytes, mset.cfg.DiscardNewPer, mset.cfg.MaxMsgsPer
 	isLeader, isSealed, compressOK := mset.isLeader(), mset.cfg.Sealed, mset.compressOK
 	mset.mu.RUnlock()
 
@@ -7914,27 +7925,36 @@ func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg [
 		mset.clseq = lseq + mset.clfs
 	}
 
-	// Check if we have an interest policy and discard new with max msgs or bytes.
+	// Check if we have an interest or working queue retention and discard new with max msgs or bytes.
 	// We need to deny here otherwise it could succeed on some peers and not others
 	// depending on consumer ack state. So we deny here, if we allow that means we know
 	// it would succeed on every peer.
-	if interestPolicy && discard == DiscardNew && (maxMsgs > 0 || maxBytes > 0) {
+	if interestPolicy && discard == DiscardNew && (maxMsgs > 0 || maxBytes > 0 || (maxMsgsPer > 0 && discardNewPerSubject)) {
 		// Track inflight.
 		if mset.inflight == nil {
 			mset.inflight = make(map[uint64]uint64)
 		}
+
 		if stype == FileStorage {
 			mset.inflight[mset.clseq] = fileStoreMsgSize(subject, hdr, msg)
 		} else {
 			mset.inflight[mset.clseq] = memStoreMsgSize(subject, hdr, msg)
 		}
 
+		if mset.inflightSubjects == nil {
+			mset.inflightSubjects = make(map[string]uint64)
+		}
+
+		mset.inflightSubjects[subject]++
+
 		var state StreamState
 		mset.store.FastState(&state)
 
 		var err error
-		if maxMsgs > 0 && state.Msgs+uint64(len(mset.inflight)) > uint64(maxMsgs) {
-			err = ErrMaxMsgs
+		if maxMsgs > 0 {
+			if state.Msgs+uint64(len(mset.inflight)) > uint64(maxMsgs) {
+				err = ErrMaxMsgs
+			}
 		} else if maxBytes > 0 {
 			// TODO(dlc) - Could track this rollup independently.
 			var bytesPending uint64
@@ -7944,9 +7964,24 @@ func (mset *stream) processClusteredInboundMsg(subject, reply string, hdr, msg [
 			if state.Bytes+bytesPending > uint64(maxBytes) {
 				err = ErrMaxBytes
 			}
+		} else if maxMsgsPer > 0 && discardNewPerSubject {
+			totals := mset.store.SubjectsTotals(subject)
+			total := totals[subject]
+			if (total + mset.inflightSubjects[subject]) > uint64(maxMsgsPer) {
+				err = ErrMaxMsgsPerSubject
+			}
 		}
+
 		if err != nil {
 			delete(mset.inflight, mset.clseq)
+			n := mset.inflightSubjects[subject]
+
+			if n > 1 {
+				mset.inflightSubjects[subject] = n - 1
+			} else {
+				delete(mset.inflightSubjects, subject)
+			}
+
 			mset.clMu.Unlock()
 			if canRespond {
 				var resp = &JSPubAckResponse{PubAck: &PubAck{Stream: name}}
diff --git a/server/jetstream_cluster_4_test.go b/server/jetstream_cluster_4_test.go
index e2f55d8cabc..fab1b3d3f81 100644
--- a/server/jetstream_cluster_4_test.go
+++ b/server/jetstream_cluster_4_test.go
@@ -252,13 +252,32 @@ func TestJetStreamClusterSourceWorkingQueueWithLimit(t *testing.T) {
 		Sources: []*nats.StreamSource{{Name: "test"}}, Replicas: 3})
 	require_NoError(t, err)
 
+	_, err = js.AddStream(&nats.StreamConfig{Name: "wq3", MaxMsgsPerSubject: maxMsgs, Discard: nats.DiscardNew, DiscardNewPerSubject: true, Retention: nats.WorkQueuePolicy,
+		Sources: []*nats.StreamSource{{Name: "test"}}, Replicas: 3})
+	require_NoError(t, err)
+
 	sendBatch := func(subject string, n int) {
 		for i := 0; i < n; i++ {
 			_, err = js.Publish(subject, []byte(fmt.Sprintf(msgPayloadFormat, i)))
 			require_NoError(t, err)
 		}
 	}
-	// Populate each one.
+
+	f := func(ss *nats.Subscription, done chan bool) {
+		for i := 0; i < totalMsgs; i++ {
+			m, err := ss.Fetch(1, nats.MaxWait(3*time.Second))
+			require_NoError(t, err)
+			p, err := strconv.Atoi(string(m[0].Data))
+			require_NoError(t, err)
+			require_Equal(t, p, i)
+			time.Sleep(11 * time.Millisecond)
+			err = m[0].Ack()
+			require_NoError(t, err)
+		}
+		done <- true
+	}
+
+	// Populate the sourced stream.
 	sendBatch("test", totalMsgs)
 
 	checkFor(t, 3*time.Second, 250*time.Millisecond, func() error {
@@ -279,27 +298,21 @@ func TestJetStreamClusterSourceWorkingQueueWithLimit(t *testing.T) {
 		return nil
 	})
 
+	checkFor(t, 3*time.Second, 250*time.Millisecond, func() error {
+		si, err := js.StreamInfo("wq3")
+		require_NoError(t, err)
+		if si.State.Msgs != maxMsgs {
+			return fmt.Errorf("expected %d msgs on stream wq3, got state: %+v", maxMsgs, si.State)
+		}
+		return nil
+	})
+
 	_, err = js.AddConsumer("wq", &nats.ConsumerConfig{Durable: "wqc", FilterSubject: "test", AckPolicy: nats.AckExplicitPolicy})
 	require_NoError(t, err)
 
 	ss1, err := js.PullSubscribe("test", "wqc", nats.Bind("wq", "wqc"))
 	require_NoError(t, err)
 
-	// we must have at least one message on the transformed subject name (ie no timeout)
-	f := func(ss *nats.Subscription, done chan bool) {
-		for i := 0; i < totalMsgs; i++ {
-			m, err := ss.Fetch(1, nats.MaxWait(3*time.Second))
-			require_NoError(t, err)
-			p, err := strconv.Atoi(string(m[0].Data))
-			require_NoError(t, err)
-			require_Equal(t, p, i)
-			time.Sleep(11 * time.Millisecond)
-			err = m[0].Ack()
-			require_NoError(t, err)
-		}
-		done <- true
-	}
-
 	var doneChan1 = make(chan bool)
 	go f(ss1, doneChan1)
 
@@ -349,6 +362,34 @@ func TestJetStreamClusterSourceWorkingQueueWithLimit(t *testing.T) {
 	case <-time.After(20 * time.Second):
 		t.Fatalf("Did not receive completion signal")
 	}
+
+	_, err = js.AddConsumer("wq3", &nats.ConsumerConfig{Durable: "wqc", FilterSubject: "test", AckPolicy: nats.AckExplicitPolicy})
+	require_NoError(t, err)
+
+	ss3, err := js.PullSubscribe("test", "wqc", nats.Bind("wq3", "wqc"))
+	require_NoError(t, err)
+
+	var doneChan3 = make(chan bool)
+	go f(ss3, doneChan3)
+
+	checkFor(t, 10*time.Second, 250*time.Millisecond, func() error {
+		si, err := js.StreamInfo("wq3")
+		require_NoError(t, err)
+		if si.State.Msgs > 0 && si.State.Msgs <= maxMsgs {
+			return fmt.Errorf("expected 0 msgs on stream wq3, got: %d", si.State.Msgs)
+		} else if si.State.Msgs > maxMsgs {
+			t.Fatalf("got more than our %d message limit on stream wq3: %+v", maxMsgs, si.State)
+		}
+
+		return nil
+	})
+
+	select {
+	case <-doneChan3:
+		ss3.Drain()
+	case <-time.After(10 * time.Second):
+		t.Fatalf("Did not receive completion signal")
+	}
 }
 
 func TestJetStreamClusterConsumerPauseViaConfig(t *testing.T) {
diff --git a/server/stream.go b/server/stream.go
index 5c1f131da85..eb4ec606480 100644
--- a/server/stream.go
+++ b/server/stream.go
@@ -286,20 +286,21 @@ type stream struct {
 
 	// TODO(dlc) - Hide everything below behind two pointers.
 	// Clustered mode.
-	sa         *streamAssignment // What the meta controller uses to assign streams to peers.
-	node       RaftNode          // Our RAFT node for the stream's group.
-	catchup    atomic.Bool       // Used to signal we are in catchup mode.
-	catchups   map[string]uint64 // The number of messages that need to be caught per peer.
-	syncSub    *subscription     // Internal subscription for sync messages (on "$JSC.SYNC").
-	infoSub    *subscription     // Internal subscription for stream info requests.
-	clMu       sync.Mutex        // The mutex for clseq and clfs.
-	clseq      uint64            // The current last seq being proposed to the NRG layer.
-	clfs       uint64            // The count (offset) of the number of failed NRG sequences used to compute clseq.
-	inflight   map[uint64]uint64 // Inflight message sizes per clseq.
-	lqsent     time.Time         // The time at which the last lost quorum advisory was sent. Used to rate limit.
-	uch        chan struct{}     // The channel to signal updates to the monitor routine.
-	compressOK bool              // True if we can do message compression in RAFT and catchup logic
-	inMonitor  bool              // True if the monitor routine has been started.
+	sa               *streamAssignment // What the meta controller uses to assign streams to peers.
+	node             RaftNode          // Our RAFT node for the stream's group.
+	catchup          atomic.Bool       // Used to signal we are in catchup mode.
+	catchups         map[string]uint64 // The number of messages that need to be caught per peer.
+	syncSub          *subscription     // Internal subscription for sync messages (on "$JSC.SYNC").
+	infoSub          *subscription     // Internal subscription for stream info requests.
+	clMu             sync.Mutex        // The mutex for clseq, clfs, inflight and inflightSubjects.
+	clseq            uint64            // The current last seq being proposed to the NRG layer.
+	clfs             uint64            // The count (offset) of the number of failed NRG sequences used to compute clseq.
+	inflight         map[uint64]uint64 // Inflight message sizes per clseq.
+	inflightSubjects map[string]uint64 // Inflight number of messages per subject.
+	lqsent           time.Time         // The time at which the last lost quorum advisory was sent. Used to rate limit.
+	uch              chan struct{}     // The channel to signal updates to the monitor routine.
+	compressOK       bool              // True if we can do message compression in RAFT and catchup logic
+	inMonitor        bool              // True if the monitor routine has been started.
 
 	// Direct get subscription.
 	directSub *subscription
@@ -3376,7 +3377,7 @@ func (mset *stream) processInboundSourceMsg(si *sourceInfo, m *inMsg) bool {
 		// Can happen temporarily all the time during normal operations when the sourcing stream
 		// is working queue/interest with a limit and discard new.
 		// TODO - Improve sourcing to WQ with limit and new to use flow control rather than re-creating the consumer.
-		if errors.Is(err, ErrMaxMsgs) || errors.Is(err, ErrMaxBytes) {
+		if errors.Is(err, ErrMaxMsgs) || errors.Is(err, ErrMaxBytes) || errors.Is(err, ErrMaxMsgsPerSubject) {
			// Do not need to do a full retry that includes finding the last sequence in the stream
			// for that source. Just re-create starting with the seq we couldn't store instead.
			mset.mu.Lock()
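Reviewer note (not part of the patch): below is a minimal, self-contained Go sketch of the per-subject inflight accounting this diff adds around processClusteredInboundMsg and applyStreamEntries. The subjectLimiter type, the TryReserve/Release names and the example subject are illustrative assumptions, not nats-server APIs; in the real change the map lives directly on the stream under clMu and the stored total comes from store.SubjectsTotals.

// A minimal sketch, assuming a hypothetical subjectLimiter type; this is not
// nats-server code, only the counting pattern the patch applies to
// DiscardNew + MaxMsgsPerSubject on clustered streams.
package main

import (
	"errors"
	"fmt"
	"sync"
)

var errMaxMsgsPerSubject = errors.New("maximum messages per subject exceeded")

// subjectLimiter tracks messages that have been proposed to the replication
// layer but not yet applied, keyed by subject.
type subjectLimiter struct {
	mu       sync.Mutex
	inflight map[string]uint64
}

// TryReserve counts one inflight message for subject and rejects the publish
// when stored + inflight would exceed maxPer, mirroring the leader-side check.
func (l *subjectLimiter) TryReserve(subject string, storedTotal, maxPer uint64) error {
	l.mu.Lock()
	defer l.mu.Unlock()
	if l.inflight == nil {
		l.inflight = make(map[string]uint64)
	}
	l.inflight[subject]++
	if storedTotal+l.inflight[subject] > maxPer {
		l.release(subject) // undo the reservation when the message is denied
		return errMaxMsgsPerSubject
	}
	return nil
}

// Release drops one inflight count once the proposal has been applied or failed.
func (l *subjectLimiter) Release(subject string) {
	l.mu.Lock()
	defer l.mu.Unlock()
	l.release(subject)
}

// release must be called with the lock held.
func (l *subjectLimiter) release(subject string) {
	if n := l.inflight[subject]; n > 1 {
		l.inflight[subject] = n - 1
	} else {
		delete(l.inflight, subject)
	}
}

func main() {
	var l subjectLimiter
	// One message already stored on the subject, limit of two per subject.
	fmt.Println(l.TryReserve("orders.new", 1, 2)) // <nil>: 1 stored + 1 inflight fits
	fmt.Println(l.TryReserve("orders.new", 1, 2)) // rejected: would exceed the limit
	l.Release("orders.new")                       // first proposal applied, count released
}

Reserving at proposal time and releasing when the entry is applied (or the proposal fails) is what keeps the DiscardNew check deterministic on the proposing leader, matching the patch's intent that a publish is either denied up front or known to succeed on every peer.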