diff --git a/go_test.mod b/go_test.mod index f5b731dd0..20e1ab730 100644 --- a/go_test.mod +++ b/go_test.mod @@ -1,23 +1,25 @@ module github.com/nats-io/nats.go -go 1.19 +go 1.21 + +toolchain go1.22.5 require ( github.com/golang/protobuf v1.4.2 - github.com/klauspost/compress v1.17.8 + github.com/klauspost/compress v1.17.9 github.com/nats-io/jwt v1.2.2 - github.com/nats-io/nats-server/v2 v2.10.16 + github.com/nats-io/nats-server/v2 v2.10.17 github.com/nats-io/nkeys v0.4.7 github.com/nats-io/nuid v1.0.1 go.uber.org/goleak v1.3.0 - golang.org/x/text v0.15.0 + golang.org/x/text v0.16.0 google.golang.org/protobuf v1.23.0 ) require ( github.com/minio/highwayhash v1.0.2 // indirect github.com/nats-io/jwt/v2 v2.5.7 // indirect - golang.org/x/crypto v0.23.0 // indirect - golang.org/x/sys v0.20.0 // indirect + golang.org/x/crypto v0.24.0 // indirect + golang.org/x/sys v0.21.0 // indirect golang.org/x/time v0.5.0 // indirect ) diff --git a/go_test.sum b/go_test.sum index f89d755ba..df0ef6d7c 100644 --- a/go_test.sum +++ b/go_test.sum @@ -1,4 +1,5 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= @@ -10,38 +11,40 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= -github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= github.com/nats-io/jwt v1.2.2 h1:w3GMTO969dFg+UOKTmmyuu7IGdusK+7Ytlt//OYH/uU= github.com/nats-io/jwt v1.2.2/go.mod h1:/xX356yQA6LuXI9xWW7mZNpxgF2mBmGecH+Fj34sP5Q= github.com/nats-io/jwt/v2 v2.5.7 h1:j5lH1fUXCnJnY8SsQeB/a/z9Azgu2bYIDvtPVNdxe2c= github.com/nats-io/jwt/v2 v2.5.7/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A= -github.com/nats-io/nats-server/v2 v2.10.16 h1:2jXaiydp5oB/nAx/Ytf9fdCi9QN6ItIc9eehX8kwVV0= -github.com/nats-io/nats-server/v2 v2.10.16/go.mod h1:Pksi38H2+6xLe1vQx0/EA4bzetM0NqyIHcIbmgXSkIU= +github.com/nats-io/nats-server/v2 v2.10.17 h1:PTVObNBD3TZSNUDgzFb1qQsQX4mOgFmOuG9vhT+KBUY= +github.com/nats-io/nats-server/v2 v2.10.17/go.mod h1:5OUyc4zg42s/p2i92zbbqXvUNsbF0ivdTLKshVMn2YQ= github.com/nats-io/nkeys v0.2.0/go.mod h1:XdZpAbhgyyODYqjTawOnIOI7VlbKSarI9Gfy1tqEu/s= github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI= github.com/nats-io/nkeys v0.4.7/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDmGD0nc= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sys v0.0.0-20190130150945-aca44879d564/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= @@ -54,3 +57,4 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/syncx/map.go b/internal/syncx/map.go new file mode 100644 index 000000000..d2278e62a --- /dev/null +++ b/internal/syncx/map.go @@ -0,0 +1,73 @@ +// Copyright 2024 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncx + +import "sync" + +// Map is a type-safe wrapper around sync.Map. +// It is safe for concurrent use. +// The zero value of Map is an empty map ready to use. +type Map[K comparable, V any] struct { + m sync.Map +} + +func (m *Map[K, V]) Load(key K) (V, bool) { + v, ok := m.m.Load(key) + if !ok { + var empty V + return empty, false + } + return v.(V), true +} + +func (m *Map[K, V]) Store(key K, value V) { + m.m.Store(key, value) +} + +func (m *Map[K, V]) Delete(key K) { + m.m.Delete(key) +} + +func (m *Map[K, V]) Range(f func(key K, value V) bool) { + m.m.Range(func(key, value any) bool { + return f(key.(K), value.(V)) + }) +} + +func (m *Map[K, V]) LoadOrStore(key K, value V) (V, bool) { + v, loaded := m.m.LoadOrStore(key, value) + return v.(V), loaded +} + +func (m *Map[K, V]) LoadAndDelete(key K) (V, bool) { + v, ok := m.m.LoadAndDelete(key) + if !ok { + var empty V + return empty, false + } + return v.(V), true +} + +func (m *Map[K, V]) CompareAndSwap(key K, old, new V) bool { + return m.m.CompareAndSwap(key, old, new) +} + +func (m *Map[K, V]) CompareAndDelete(key K, value V) bool { + return m.m.CompareAndDelete(key, value) +} + +func (m *Map[K, V]) Swap(key K, value V) (V, bool) { + previous, loaded := m.m.Swap(key, value) + return previous.(V), loaded +} diff --git a/internal/syncx/map_test.go b/internal/syncx/map_test.go new file mode 100644 index 000000000..df34b2f2f --- /dev/null +++ b/internal/syncx/map_test.go @@ -0,0 +1,152 @@ +// Copyright 2024 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncx + +import ( + "testing" +) + +func TestMapLoad(t *testing.T) { + var m Map[int, string] + m.Store(1, "one") + + v, ok := m.Load(1) + if !ok || v != "one" { + t.Errorf("Load(1) = %v, %v; want 'one', true", v, ok) + } + + v, ok = m.Load(2) + if ok || v != "" { + t.Errorf("Load(2) = %v, %v; want '', false", v, ok) + } +} + +func TestMapStore(t *testing.T) { + var m Map[int, string] + m.Store(1, "one") + + v, ok := m.Load(1) + if !ok || v != "one" { + t.Errorf("Load(1) after Store(1, 'one') = %v, %v; want 'one', true", v, ok) + } +} + +func TestMapDelete(t *testing.T) { + var m Map[int, string] + m.Store(1, "one") + m.Delete(1) + + v, ok := m.Load(1) + if ok || v != "" { + t.Errorf("Load(1) after Delete(1) = %v, %v; want '', false", v, ok) + } +} + +func TestMapRange(t *testing.T) { + var m Map[int, string] + m.Store(1, "one") + m.Store(2, "two") + + var keys []int + var values []string + m.Range(func(key int, value string) bool { + keys = append(keys, key) + values = append(values, value) + return true + }) + + if len(keys) != 2 || len(values) != 2 { + t.Errorf("Range() keys = %v, values = %v; want 2 keys and 2 values", keys, values) + } +} + +func TestMapLoadOrStore(t *testing.T) { + var m Map[int, string] + + v, loaded := m.LoadOrStore(1, "one") + if loaded || v != "one" { + t.Errorf("LoadOrStore(1, 'one') = %v, %v; want 'one', false", v, loaded) + } + + v, loaded = m.LoadOrStore(1, "uno") + if !loaded || v != "one" { + t.Errorf("LoadOrStore(1, 'uno') = %v, %v; want 'one', true", v, loaded) + } +} + +func TestMapLoadAndDelete(t *testing.T) { + var m Map[int, string] + m.Store(1, "one") + + v, ok := m.LoadAndDelete(1) + if !ok || v != "one" { + t.Errorf("LoadAndDelete(1) = %v, %v; want 'one', true", v, ok) + } + + v, ok = m.Load(1) + if ok || v != "" { + t.Errorf("Load(1) after LoadAndDelete(1) = %v, %v; want '', false", v, ok) + } + + // Test that LoadAndDelete on a missing key returns the zero value. + v, ok = m.LoadAndDelete(2) + if ok || v != "" { + t.Errorf("LoadAndDelete(2) = %v, %v; want '', false", v, ok) + } +} + +func TestMapCompareAndSwap(t *testing.T) { + var m Map[int, string] + m.Store(1, "one") + + ok := m.CompareAndSwap(1, "one", "uno") + if !ok { + t.Errorf("CompareAndSwap(1, 'one', 'uno') = false; want true") + } + + v, _ := m.Load(1) + if v != "uno" { + t.Errorf("Load(1) after CompareAndSwap = %v; want 'uno'", v) + } +} + +func TestMapCompareAndDelete(t *testing.T) { + var m Map[int, string] + m.Store(1, "one") + + ok := m.CompareAndDelete(1, "one") + if !ok { + t.Errorf("CompareAndDelete(1, 'one') = false; want true") + } + + v, _ := m.Load(1) + if v != "" { + t.Errorf("Load(1) after CompareAndDelete = %v; want ''", v) + } +} + +func TestMapSwap(t *testing.T) { + var m Map[int, string] + m.Store(1, "one") + + v, loaded := m.Swap(1, "uno") + if !loaded || v != "one" { + t.Errorf("Swap(1, 'uno') = %v, %v; want 'one', true", v, loaded) + } + + v, _ = m.Load(1) + if v != "uno" { + t.Errorf("Load(1) after Swap = %v; want 'uno'", v) + } +} diff --git a/jetstream/consumer.go b/jetstream/consumer.go index aa9003f3d..ee48a1ec3 100644 --- a/jetstream/consumer.go +++ b/jetstream/consumer.go @@ -20,6 +20,7 @@ import ( "fmt" "strings" + "github.com/nats-io/nats.go/internal/syncx" "github.com/nats-io/nuid" ) @@ -233,12 +234,12 @@ func upsertConsumer(ctx context.Context, js *jetStream, stream string, cfg Consu } return &pullConsumer{ - jetStream: js, - stream: stream, - name: resp.Name, - durable: cfg.Durable != "", - info: resp.ConsumerInfo, - subscriptions: make(map[string]*pullSubscription), + jetStream: js, + stream: stream, + name: resp.Name, + durable: cfg.Durable != "", + info: resp.ConsumerInfo, + subs: syncx.Map[string, *pullSubscription]{}, }, nil } @@ -285,12 +286,12 @@ func getConsumer(ctx context.Context, js *jetStream, stream, name string) (Consu } cons := &pullConsumer{ - jetStream: js, - stream: stream, - name: name, - durable: resp.Config.Durable != "", - info: resp.ConsumerInfo, - subscriptions: make(map[string]*pullSubscription, 0), + jetStream: js, + stream: stream, + name: name, + durable: resp.Config.Durable != "", + info: resp.ConsumerInfo, + subs: syncx.Map[string, *pullSubscription]{}, } return cons, nil diff --git a/jetstream/jetstream_test.go b/jetstream/jetstream_test.go index 878d361e4..58f906423 100644 --- a/jetstream/jetstream_test.go +++ b/jetstream/jetstream_test.go @@ -276,9 +276,11 @@ func TestRetryWithBackoff(t *testing.T) { } func TestPullConsumer_checkPending(t *testing.T) { + tests := []struct { name string givenSub *pullSubscription + fetchInProgress bool shouldSend bool expectedPullRequest *pullRequest }{ @@ -292,7 +294,6 @@ func TestPullConsumer_checkPending(t *testing.T) { ThresholdMessages: 5, MaxMessages: 10, }, - fetchInProgress: 0, }, shouldSend: false, }, @@ -307,7 +308,6 @@ func TestPullConsumer_checkPending(t *testing.T) { ThresholdMessages: 5, MaxMessages: 10, }, - fetchInProgress: 0, }, shouldSend: true, expectedPullRequest: &pullRequest{ @@ -325,9 +325,9 @@ func TestPullConsumer_checkPending(t *testing.T) { ThresholdMessages: 5, MaxMessages: 10, }, - fetchInProgress: 1, }, - shouldSend: false, + fetchInProgress: true, + shouldSend: false, }, { name: "pending bytes below threshold, send pull request", @@ -341,7 +341,6 @@ func TestPullConsumer_checkPending(t *testing.T) { ThresholdBytes: 500, MaxBytes: 1000, }, - fetchInProgress: 0, }, shouldSend: true, expectedPullRequest: &pullRequest{ @@ -359,7 +358,6 @@ func TestPullConsumer_checkPending(t *testing.T) { ThresholdBytes: 500, MaxBytes: 1000, }, - fetchInProgress: 0, }, shouldSend: false, }, @@ -373,9 +371,9 @@ func TestPullConsumer_checkPending(t *testing.T) { ThresholdBytes: 500, MaxBytes: 1000, }, - fetchInProgress: 1, }, - shouldSend: false, + fetchInProgress: true, + shouldSend: false, }, { name: "StopAfter set, pending msgs below StopAfter, send pull request", @@ -388,8 +386,7 @@ func TestPullConsumer_checkPending(t *testing.T) { MaxMessages: 10, StopAfter: 8, }, - fetchInProgress: 0, - delivered: 2, + delivered: 2, }, shouldSend: true, expectedPullRequest: &pullRequest{ @@ -408,8 +405,7 @@ func TestPullConsumer_checkPending(t *testing.T) { MaxMessages: 10, StopAfter: 6, }, - fetchInProgress: 0, - delivered: 0, + delivered: 0, }, shouldSend: false, }, @@ -419,6 +415,9 @@ func TestPullConsumer_checkPending(t *testing.T) { t.Run(test.name, func(t *testing.T) { prChan := make(chan *pullRequest, 1) test.givenSub.fetchNext = prChan + if test.fetchInProgress { + test.givenSub.fetchInProgress.Store(1) + } errs := make(chan error, 1) ok := make(chan struct{}, 1) go func() { diff --git a/jetstream/ordered.go b/jetstream/ordered.go index 998b83dc3..469624477 100644 --- a/jetstream/ordered.go +++ b/jetstream/ordered.go @@ -32,6 +32,7 @@ type ( cfg *OrderedConsumerConfig stream string currentConsumer *pullConsumer + currentSub ConsumeContext cursor cursor namePrefix string serial int @@ -116,19 +117,11 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt } meta, err := msg.Metadata() if err != nil { - sub, ok := c.currentConsumer.getSubscription("") - if !ok { - return - } - c.errHandler(serial)(sub, err) + c.errHandler(serial)(c.currentSub, err) return } dseq := meta.Sequence.Consumer if dseq != c.cursor.deliverSeq+1 { - sub, ok := c.currentConsumer.getSubscription("") - if !ok { - return - } c.errHandler(serial)(sub, errOrderedSequenceMismatch) return } @@ -138,21 +131,18 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt } } - _, err = c.currentConsumer.Consume(internalHandler(c.serial), opts...) + cc, err := c.currentConsumer.Consume(internalHandler(c.serial), opts...) if err != nil { return nil, err } + c.currentSub = cc go func() { for { select { case <-c.doReset: if err := c.reset(); err != nil { - sub, ok := c.currentConsumer.getSubscription("") - if !ok { - return - } - c.errHandler(c.serial)(sub, err) + c.errHandler(c.serial)(c.currentSub, err) } if c.withStopAfter { select { @@ -175,12 +165,12 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt if c.withStopAfter { opts = append(opts, consumeStopAfterNotify(c.stopAfter, c.stopAfterMsgsLeft)) } - if _, err := c.currentConsumer.Consume(internalHandler(c.serial), opts...); err != nil { - sub, ok := c.currentConsumer.getSubscription("") - if !ok { - return - } - c.errHandler(c.serial)(sub, err) + if cc, err := c.currentConsumer.Consume(internalHandler(c.serial), opts...); err != nil { + c.errHandler(c.serial)(cc, err) + } else { + c.Lock() + c.currentSub = cc + c.Unlock() } case <-sub.done: return @@ -250,10 +240,11 @@ func (c *orderedConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, er if c.stopAfter > 0 { opts = append(opts, messagesStopAfterNotify(c.stopAfter, c.stopAfterMsgsLeft)) } - _, err = c.currentConsumer.Messages(opts...) + cc, err := c.currentConsumer.Messages(opts...) if err != nil { return nil, err } + c.currentSub = cc sub := &orderedSubscription{ consumer: c, @@ -267,12 +258,7 @@ func (c *orderedConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, er func (s *orderedSubscription) Next() (Msg, error) { for { - currentConsumer := s.consumer.currentConsumer - sub, ok := currentConsumer.getSubscription("") - if !ok { - return nil, ErrMsgIteratorClosed - } - msg, err := sub.Next() + msg, err := s.consumer.currentSub.(*pullSubscription).Next() if err != nil { if errors.Is(err, ErrMsgIteratorClosed) { s.Stop() @@ -292,10 +278,11 @@ func (s *orderedSubscription) Next() (Msg, error) { if err := s.consumer.reset(); err != nil { return nil, err } - _, err := s.consumer.currentConsumer.Messages(s.opts...) + cc, err := s.consumer.currentConsumer.Messages(s.opts...) if err != nil { return nil, err } + s.consumer.currentSub = cc continue } @@ -312,10 +299,11 @@ func (s *orderedSubscription) Next() (Msg, error) { if err := s.consumer.reset(); err != nil { return nil, err } - _, err := s.consumer.currentConsumer.Messages(s.opts...) + cc, err := s.consumer.currentConsumer.Messages(s.opts...) if err != nil { return nil, err } + s.consumer.currentSub = cc continue } s.consumer.cursor.deliverSeq = dseq @@ -328,13 +316,9 @@ func (s *orderedSubscription) Stop() { if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { return } - sub, ok := s.consumer.currentConsumer.getSubscription("") - if !ok { - return - } - s.consumer.currentConsumer.Lock() - defer s.consumer.currentConsumer.Unlock() - sub.Stop() + s.consumer.Lock() + defer s.consumer.Unlock() + s.consumer.currentSub.Stop() close(s.done) } @@ -342,13 +326,9 @@ func (s *orderedSubscription) Drain() { if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { return } - sub, ok := s.consumer.currentConsumer.getSubscription("") - if !ok { - return - } s.consumer.currentConsumer.Lock() defer s.consumer.currentConsumer.Unlock() - sub.Drain() + s.consumer.currentSub.Drain() close(s.done) } @@ -495,10 +475,9 @@ func (c *orderedConsumer) reset() error { defer c.Unlock() defer atomic.StoreUint32(&c.resetInProgress, 0) if c.currentConsumer != nil { - sub, ok := c.currentConsumer.getSubscription("") c.currentConsumer.Lock() - if ok { - sub.Stop() + if c.currentSub != nil { + c.currentSub.Stop() } consName := c.currentConsumer.CachedInfo().Name c.currentConsumer.Unlock() diff --git a/jetstream/pull.go b/jetstream/pull.go index 001a0d183..a510c7c8c 100644 --- a/jetstream/pull.go +++ b/jetstream/pull.go @@ -23,6 +23,7 @@ import ( "time" "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/internal/syncx" "github.com/nats-io/nuid" ) @@ -75,12 +76,12 @@ type ( pullConsumer struct { sync.Mutex - jetStream *jetStream - stream string - durable bool - name string - info *ConsumerInfo - subscriptions map[string]*pullSubscription + jetStream *jetStream + stream string + durable bool + name string + info *ConsumerInfo + subs syncx.Map[string, *pullSubscription] } pullRequest struct { @@ -116,9 +117,9 @@ type ( errs chan error pending pendingMsgs hbMonitor *hbMonitor - fetchInProgress uint32 - closed uint32 - draining uint32 + fetchInProgress atomic.Uint32 + closed atomic.Uint32 + draining atomic.Uint32 done chan struct{} connStatusChanged chan nats.Status fetchNext chan *pullRequest @@ -181,12 +182,7 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( subject := apiSubj(p.jetStream.apiPrefix, fmt.Sprintf(apiRequestNextT, p.stream, p.name)) - // for single consume, use empty string as id - // this is useful for ordered consumer, where only a single subscription is valid - var consumeID string - if len(p.subscriptions) > 0 { - consumeID = nuid.Next() - } + consumeID := nuid.Next() sub := &pullSubscription{ id: consumeID, consumer: p, @@ -199,7 +195,7 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( sub.hbMonitor = sub.scheduleHeartbeatCheck(consumeOpts.Heartbeat) - p.subscriptions[sub.id] = sub + p.subs.Store(sub.id, sub) p.Unlock() internalHandler := func(msg *nats.Msg) { @@ -232,7 +228,7 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( sub.Unlock() if err != nil { - if atomic.LoadUint32(&sub.closed) == 1 { + if sub.closed.Load() == 1 { return } if sub.consumeOpts.ErrHandler != nil { @@ -259,10 +255,8 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( } sub.subscription.SetClosedHandler(func(sid string) func(string) { return func(subject string) { - p.Lock() - defer p.Unlock() - delete(p.subscriptions, sid) - atomic.CompareAndSwapUint32(&sub.draining, 1, 0) + p.subs.Delete(sid) + sub.draining.CompareAndSwap(1, 0) } }(sub.id)) @@ -286,7 +280,7 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( go func() { isConnected := true for { - if atomic.LoadUint32(&sub.closed) == 1 { + if sub.closed.Load() == 1 { return } select { @@ -383,7 +377,7 @@ func (s *pullSubscription) incrementDeliveredMsgs() { func (s *pullSubscription) checkPending() { if (s.pending.msgCount < s.consumeOpts.ThresholdMessages || (s.pending.byteCount < s.consumeOpts.ThresholdBytes && s.consumeOpts.MaxBytes != 0)) && - atomic.LoadUint32(&s.fetchInProgress) == 0 { + s.fetchInProgress.Load() == 0 { var batchSize, maxBytes int if s.consumeOpts.MaxBytes == 0 { @@ -427,12 +421,7 @@ func (p *pullConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error msgs := make(chan *nats.Msg, consumeOpts.MaxMessages) - // for single consume, use empty string as id - // this is useful for ordered consumer, where only a single subscription is valid - var consumeID string - if len(p.subscriptions) > 0 { - consumeID = nuid.Next() - } + consumeID := nuid.Next() sub := &pullSubscription{ id: consumeID, consumer: p, @@ -451,20 +440,18 @@ func (p *pullConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error } sub.subscription.SetClosedHandler(func(sid string) func(string) { return func(subject string) { - p.Lock() - defer p.Unlock() - if atomic.LoadUint32(&sub.draining) != 1 { + if sub.draining.Load() != 1 { // if we're not draining, subscription can be closed as soon // as closed handler is called // otherwise, we need to wait until all messages are drained // in Next - delete(p.subscriptions, sid) + p.subs.Delete(sid) } close(msgs) } }(sub.id)) - p.subscriptions[sub.id] = sub + p.subs.Store(sub.id, sub) p.Unlock() go sub.pullMessages(subject) @@ -502,8 +489,8 @@ var ( func (s *pullSubscription) Next() (Msg, error) { s.Lock() defer s.Unlock() - drainMode := atomic.LoadUint32(&s.draining) == 1 - closed := atomic.LoadUint32(&s.closed) == 1 + drainMode := s.draining.Load() == 1 + closed := s.closed.Load() == 1 if closed && !drainMode { return nil, ErrMsgIteratorClosed } @@ -526,8 +513,8 @@ func (s *pullSubscription) Next() (Msg, error) { case msg, ok := <-s.msgs: if !ok { // if msgs channel is closed, it means that subscription was either drained or stopped - delete(s.consumer.subscriptions, s.id) - atomic.CompareAndSwapUint32(&s.draining, 1, 0) + s.consumer.subs.Delete(s.id) + s.draining.CompareAndSwap(1, 0) return nil, ErrMsgIteratorClosed } if hbMonitor != nil { @@ -630,7 +617,7 @@ func (hb *hbMonitor) Reset(dur time.Duration) { // Next after calling Stop will return ErrMsgIteratorClosed error. // All messages that are already in the buffer are discarded. func (s *pullSubscription) Stop() { - if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { + if !s.closed.CompareAndSwap(0, 1) { return } close(s.done) @@ -648,10 +635,10 @@ func (s *pullSubscription) Stop() { // subsequent calls to Next. After the buffer is drained, Next will // return ErrMsgIteratorClosed error. func (s *pullSubscription) Drain() { - if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { + if !s.closed.CompareAndSwap(0, 1) { return } - atomic.StoreUint32(&s.draining, 1) + s.draining.Store(1) close(s.done) if s.consumeOpts.stopAfterMsgsLeft != nil { if s.delivered >= s.consumeOpts.StopAfter { @@ -840,7 +827,7 @@ func (s *pullSubscription) pullMessages(subject string) { for { select { case req := <-s.fetchNext: - atomic.StoreUint32(&s.fetchInProgress, 1) + s.fetchInProgress.Store(1) if err := s.pull(req, subject); err != nil { if errors.Is(err, ErrMsgIteratorClosed) { @@ -849,7 +836,7 @@ func (s *pullSubscription) pullMessages(subject string) { } s.errs <- err } - atomic.StoreUint32(&s.fetchInProgress, 0) + s.fetchInProgress.Store(0) case <-s.done: s.cleanup() return @@ -880,13 +867,13 @@ func (s *pullSubscription) cleanup() { if s.hbMonitor != nil { s.hbMonitor.Stop() } - drainMode := atomic.LoadUint32(&s.draining) == 1 + drainMode := s.draining.Load() == 1 if drainMode { s.subscription.Drain() } else { s.subscription.Unsubscribe() } - atomic.StoreUint32(&s.closed, 1) + s.closed.Store(1) } // pull sends a pull request to the server and waits for messages using a subscription from [pullSubscription]. @@ -894,7 +881,7 @@ func (s *pullSubscription) cleanup() { func (s *pullSubscription) pull(req *pullRequest, subject string) error { s.consumer.Lock() defer s.consumer.Unlock() - if atomic.LoadUint32(&s.closed) == 1 { + if s.closed.Load() == 1 { return ErrMsgIteratorClosed } if req.Batch < 1 { @@ -994,10 +981,3 @@ func (consumeOpts *consumeOpts) setDefaults(ordered bool) error { } return nil } - -func (c *pullConsumer) getSubscription(id string) (*pullSubscription, bool) { - c.Lock() - defer c.Unlock() - sub, ok := c.subscriptions[id] - return sub, ok -}