diff --git a/common/membership/interfaces.go b/common/membership/interfaces.go
index a1a225039dd..66d93d1b580 100644
--- a/common/membership/interfaces.go
+++ b/common/membership/interfaces.go
@@ -82,6 +82,9 @@ type (
 	ServiceResolver interface {
 		// Lookup looks up the host that currently owns the resource identified by the given key.
 		Lookup(key string) (HostInfo, error)
+		// LookupN looks up the n hosts that own the resource identified by the given key. If n is greater than
+		// the total number of hosts, all hosts are returned.
+		LookupN(key string, n int) []HostInfo
 		// AddListener adds a listener which will get notified on the given channel whenever membership changes.
 		AddListener(name string, notifyChannel chan<- *ChangedEvent) error
 		// RemoveListener removes a listener for this service.
diff --git a/common/membership/interfaces_mock.go b/common/membership/interfaces_mock.go
index a64f4476743..a7c8811010d 100644
--- a/common/membership/interfaces_mock.go
+++ b/common/membership/interfaces_mock.go
@@ -181,6 +181,20 @@ func (mr *MockServiceResolverMockRecorder) Lookup(key interface{}) *gomock.Call
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Lookup", reflect.TypeOf((*MockServiceResolver)(nil).Lookup), key)
 }
 
+// LookupN mocks base method.
+func (m *MockServiceResolver) LookupN(key string, n int) []HostInfo {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "LookupN", key, n)
+	ret0, _ := ret[0].([]HostInfo)
+	return ret0
+}
+
+// LookupN indicates an expected call of LookupN.
+func (mr *MockServiceResolverMockRecorder) LookupN(key, n interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupN", reflect.TypeOf((*MockServiceResolver)(nil).LookupN), key, n)
+}
+
 // MemberCount mocks base method.
 func (m *MockServiceResolver) MemberCount() int {
 	m.ctrl.T.Helper()
diff --git a/common/membership/ringpop/service_resolver.go b/common/membership/ringpop/service_resolver.go
index f24c25b45bb..4dbedb135bf 100644
--- a/common/membership/ringpop/service_resolver.go
+++ b/common/membership/ringpop/service_resolver.go
@@ -45,6 +45,7 @@ import (
 	"go.temporal.io/server/common/log/tag"
 	"go.temporal.io/server/common/membership"
 	"go.temporal.io/server/common/primitives"
+	"go.temporal.io/server/common/util"
 )
 
 const (
@@ -62,10 +63,6 @@ const (
 	replicaPoints = 100
 )
 
-type membershipManager interface {
-	AddListener()
-}
-
 type serviceResolver struct {
 	service primitives.ServiceName
 	port    int
@@ -154,6 +151,19 @@ func (r *serviceResolver) Lookup(key string) (membership.HostInfo, error) {
 	return newHostInfo(addr, r.getLabelsMap()), nil
 }
 
+func (r *serviceResolver) LookupN(key string, n int) []membership.HostInfo {
+	if n <= 0 {
+		return nil
+	}
+	addresses := r.ring().LookupN(key, n)
+	if len(addresses) == 0 {
+		r.RequestRefresh()
+		return nil
+	}
+	labels := r.getLabelsMap()
+	return util.MapSlice(addresses, func(address string) membership.HostInfo { return newHostInfo(address, labels) })
+}
+
 func (r *serviceResolver) AddListener(
 	name string,
 	notifyChannel chan<- *membership.ChangedEvent,
diff --git a/common/util/util.go b/common/util/util.go
index e3375717b52..87d8b8d1680 100644
--- a/common/util/util.go
+++ b/common/util/util.go
@@ -117,6 +117,18 @@ func MapConcurrent[IN any, OUT any](input []IN, mapper func(IN) (OUT, error)) ([
 	return results, nil
 }
 
+// MapSlice applies f to every element of xs, producing a new slice of the results; a nil input yields a nil result.
+func MapSlice[T, S any](xs []T, f func(T) S) []S {
+	if xs == nil {
+		return nil
+	}
+	result := make([]S, len(xs))
+	for i, s := range xs {
+		result[i] = f(s)
+	}
+	return result
+}
+
 // FilterSlice iterates over elements of a slice, returning a new slice of all elements predicate returns true for.
 func FilterSlice[T any](in []T, predicate func(T) bool) []T {
 	var out []T
@@ -128,8 +140,8 @@ func FilterSlice[T any](in []T, predicate func(T) bool) []T {
 	return out
 }
 
-// ReduceSlice reduces a slice using given reducer function and initial value.
-func ReduceSlice[T any, A any](in []T, initializer A, reducer func(A, T) A) A {
+// FoldSlice folds a slice from the left using the given reducer function and initial value.
+func FoldSlice[T any, A any](in []T, initializer A, reducer func(A, T) A) A {
 	acc := initializer
 	for _, val := range in {
 		acc = reducer(acc, val)
@@ -137,6 +149,19 @@ func ReduceSlice[T any, A any](in []T, initializer A, reducer func(A, T) A) A {
 	return acc
 }
 
+// RepeatSlice returns a new slice containing the elements of xs repeated n times.
+// It returns nil if xs is nil or n is non-positive.
+func RepeatSlice[T any](xs []T, n int) []T {
+	if xs == nil || n <= 0 {
+		return nil
+	}
+	ys := make([]T, n*len(xs))
+	for i := 0; i < n; i++ {
+		copy(ys[i*len(xs):], xs)
+	}
+	return ys
+}
+
 // Coalesce returns the first non-zero value of its arguments, or the zero value for the type
 // if all are zero.
 func Coalesce[T comparable](vals ...T) T {
diff --git a/common/util/util_test.go b/common/util/util_test.go
new file mode 100644
index 00000000000..fffff88734f
--- /dev/null
+++ b/common/util/util_test.go
@@ -0,0 +1,91 @@
+// The MIT License
+//
+// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
+//
+// Copyright (c) 2020 Uber Technologies, Inc.
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package util + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRepeatSlice(t *testing.T) { + t.Run("when input slice is nil should return nil", func(t *testing.T) { + got := RepeatSlice[int](nil, 5) + require.Nil(t, got, "RepeatSlice produced non-nil slice from nil input") + }) + t.Run("when input slice is empty should return empty", func(t *testing.T) { + empty := []int{} + got := RepeatSlice(empty, 5) + require.Len(t, got, 0, "RepeatSlice filled empty slice") + }) + t.Run("when requested repeat number equal 0 should return empty slice", func(t *testing.T) { + xs := []int{1, 2, 3, 4, 5} + got := RepeatSlice(xs, 0) + require.Len(t, got, 0, "RepeatSlice with repeat count 0 returned non-empty slice") + }) + t.Run("when requested repeat number is less than 0 should return empty slice", func(t *testing.T) { + xs := []int{1, 2, 3, 4, 5} + got := RepeatSlice(xs, -1) + require.Len(t, got, 0, "RepeatSlice with repeat count -1 returned non-empty slice") + }) + t.Run("when requested repeat number is 3 should return slice three times the input", func(t *testing.T) { + xs := []int{1, 2, 3, 4, 5} + got := RepeatSlice(xs, 3) + require.Len(t, got, len(xs)*3, "RepeatSlice produced slice of wrong length: expected %d got %d", len(xs)*3, len(got)) + for i, v := range got { + require.Equal(t, xs[i%len(xs)], v, "RepeatSlice wrong value in result: expected %d at index %d but got %d", xs[i%len(xs)], i, v) + } + }) + t.Run("should not change the input slice when truncating", func(t *testing.T) { + xs := []int{1, 2, 3, 4, 5} + _ = RepeatSlice(xs, 0) + require.Len(t, xs, 5, "Repeat slice truncated the original slice: expected {1, 2, 3, 4, 5}, got %v", xs) + }) + t.Run("should not change the input slice when replicating", func(t *testing.T) { + xs := []int{1, 2, 3, 4, 5} + _ = RepeatSlice(xs, 5) + require.Len(t, xs, 5, "Repeat slice changed the original slice: expected {1, 2, 3, 4, 5}, got %v", xs) + }) +} + +func TestMapSlice(t *testing.T) { + t.Run("when given nil as slice should return nil", func(t *testing.T) { + ys := MapSlice(nil, func(x int) uint32 { return uint32(x) }) + require.Nil(t, ys, "mapping over nil produced non nil got %v", ys) + }) + t.Run("when given an empty slice should return empty slice", func(t *testing.T) { + xs := []int{} + var ys []uint32 + ys = MapSlice(xs, func(x int) uint32 { return uint32(x) }) + require.Len(t, ys, 0, "mapping over empty slice produced non empty slice got %v", ys) 
+ }) + t.Run("when given a slice and a function should apply function to every element of the original slice", func(t *testing.T) { + xs := []int{1, 2, 3, 4, 5} + ys := MapSlice(xs, func(x int) int { return x + 1 }) + for i, y := range ys { + require.Equal(t, xs[i]+1, y, "mapping over slice did not apply function expected {2, 3, 4, 5} got %v", ys) + } + }) +} diff --git a/service/frontend/task_reachability.go b/service/frontend/task_reachability.go index 091650857b2..2b3327f7454 100644 --- a/service/frontend/task_reachability.go +++ b/service/frontend/task_reachability.go @@ -187,7 +187,7 @@ func (wh *WorkflowHandler) getTaskQueueReachability(ctx context.Context, request reachableByNewWorkflows = true } else { // If the queue became versioned just recently, consider the unversioned build id reachable. - queueBecameVersionedAt := util.ReduceSlice(versionSets, &hlc.Clock{WallClock: math.MaxInt64}, func(c *hlc.Clock, set *persistencespb.CompatibleVersionSet) *hlc.Clock { + queueBecameVersionedAt := util.FoldSlice(versionSets, &hlc.Clock{WallClock: math.MaxInt64}, func(c *hlc.Clock, set *persistencespb.CompatibleVersionSet) *hlc.Clock { return hlc.Min(c, set.BecameDefaultTimestamp) }) reachableByNewWorkflows = time.Since(hlc.UTC(queueBecameVersionedAt)) < wh.config.ReachabilityQuerySetDurationSinceDefault() diff --git a/service/worker/pernamespaceworker.go b/service/worker/pernamespaceworker.go index f163e9c227b..17d036747cb 100644 --- a/service/worker/pernamespaceworker.go +++ b/service/worker/pernamespaceworker.go @@ -123,10 +123,17 @@ type ( StickyScheduleToStartTimeout string // parse into time.Duration StickyScheduleToStartTimeoutDuration time.Duration } + + workerAllocation struct { + Total int + Local int + } ) var ( errNoWorkerNeeded = errors.New("no worker needed") // sentinel value, not a real error + // errInvalidConfiguration indicates that the value provided by dynamic config is not legal + errInvalidConfiguration = errors.New("invalid dynamic configuration") ) func NewPerNamespaceWorkerManager(params perNamespaceWorkerManagerInitParams) *perNamespaceWorkerManager { @@ -260,24 +267,43 @@ func (wm *perNamespaceWorkerManager) removeWorker(ns *namespace.Namespace) { delete(wm.workers, ns.ID()) } -func (wm *perNamespaceWorkerManager) getWorkerMultiplicity(ns *namespace.Namespace) (int, int, error) { +func (wm *perNamespaceWorkerManager) getWorkerAllocation(ns *namespace.Namespace) (*workerAllocation, error) { + desiredWorkersCount, err := wm.getConfiguredWorkersCountFor(ns) + if err != nil { + return nil, err + } + if desiredWorkersCount == 0 { + return &workerAllocation{0, 0}, nil + } + localCount, err := wm.getLocallyDesiredWorkersCount(ns, desiredWorkersCount) + if err != nil { + return nil, err + } + return &workerAllocation{desiredWorkersCount, localCount}, nil +} + +func (wm *perNamespaceWorkerManager) getConfiguredWorkersCountFor(ns *namespace.Namespace) (int, error) { totalWorkers := wm.config.PerNamespaceWorkerCount(ns.Name().String()) - // This can result in fewer than the intended number of workers if totalWorkers > 1, because - // multiple lookups might land on the same node. To compensate, we increase the number of - // pollers in that case, but it would be better to try to spread them across our nodes. 
- // TODO: implement this properly using LookupN in serviceResolver - multiplicity := 0 - for i := 0; i < totalWorkers; i++ { - key := fmt.Sprintf("%s/%d", ns.ID().String(), i) - target, err := wm.serviceResolver.Lookup(key) - if err != nil { - return 0, 0, err - } - if target.Identity() == wm.self.Identity() { - multiplicity++ - } + if totalWorkers < 0 { + err := fmt.Errorf("%w namespace %s, workers count %d", errInvalidConfiguration, ns.Name(), totalWorkers) + return 0, err } - return multiplicity, totalWorkers, nil + return totalWorkers, nil +} + +func (wm *perNamespaceWorkerManager) getLocallyDesiredWorkersCount(ns *namespace.Namespace, desiredNumberOfWorkers int) (int, error) { + key := ns.ID().String() + availableHosts := wm.serviceResolver.LookupN(key, desiredNumberOfWorkers) + hostsCount := len(availableHosts) + if hostsCount == 0 { + return 0, membership.ErrInsufficientHosts + } + maxWorkersPerHost := desiredNumberOfWorkers/hostsCount + 1 + desiredDistribution := util.RepeatSlice(availableHosts, maxWorkersPerHost)[:desiredNumberOfWorkers] + + isLocal := func(info membership.HostInfo) bool { return info.Identity() == wm.self.Identity() } + result := len(util.FilterSlice(desiredDistribution, isLocal)) + return result, nil } func (wm *perNamespaceWorkerManager) getWorkerOptions(ns *namespace.Namespace) sdkWorkerOptions { @@ -393,18 +419,18 @@ func (w *perNamespaceWorker) tryRefresh(ns *namespace.Namespace) error { } // check if we are responsible for this namespace at all - multiplicity, totalWorkers, err := w.wm.getWorkerMultiplicity(ns) + workerAllocation, err := w.wm.getWorkerAllocation(ns) if err != nil { w.logger.Error("Failed to look up hosts", tag.Error(err)) // TODO: add metric also return err } - if multiplicity == 0 { + if workerAllocation.Local == 0 { // not ours, don't need a worker return errNoWorkerNeeded } // ensure this changes if multiplicity changes - componentSet += fmt.Sprintf(",%d", multiplicity) + componentSet += fmt.Sprintf(",%d", workerAllocation.Local) // get sdk worker options dcOptions := w.wm.getWorkerOptions(ns) @@ -424,7 +450,7 @@ func (w *perNamespaceWorker) tryRefresh(ns *namespace.Namespace) error { // create new one. note that even before startWorker returns, the worker may have started // and already called the fatal error handler. we need to set w.client+worker+componentSet // before releasing the lock to keep our state consistent. 
-	client, worker, err := w.startWorker(ns, enabledComponents, multiplicity, totalWorkers, dcOptions)
+	client, worker, err := w.startWorker(ns, enabledComponents, workerAllocation, dcOptions)
 	if err != nil {
 		// TODO: add metric also
 		return err
@@ -439,8 +465,7 @@ func (w *perNamespaceWorker) tryRefresh(ns *namespace.Namespace) error {
 func (w *perNamespaceWorker) startWorker(
 	ns *namespace.Namespace,
 	components []workercommon.PerNSWorkerComponent,
-	multiplicity int,
-	totalWorkers int,
+	allocation *workerAllocation,
 	dcOptions sdkWorkerOptions,
 ) (sdkclient.Client, sdkworker.Worker, error) {
 	nsName := ns.Name().String()
@@ -465,19 +490,19 @@ func (w *perNamespaceWorker) startWorker(
 	sdkoptions.BackgroundActivityContext = headers.SetCallerInfo(context.Background(), headers.NewBackgroundCallerInfo(ns.Name().String()))
 	sdkoptions.Identity = fmt.Sprintf("server-worker@%d@%s@%s", os.Getpid(), w.wm.hostName, nsName)
 
-	// increase these if we're supposed to run with more multiplicity
-	sdkoptions.MaxConcurrentWorkflowTaskPollers *= multiplicity
-	sdkoptions.MaxConcurrentActivityTaskPollers *= multiplicity
-	sdkoptions.MaxConcurrentLocalActivityExecutionSize *= multiplicity
-	sdkoptions.MaxConcurrentWorkflowTaskExecutionSize *= multiplicity
-	sdkoptions.MaxConcurrentActivityExecutionSize *= multiplicity
+	// scale these up when more than one worker is allocated to this host
+	sdkoptions.MaxConcurrentWorkflowTaskPollers *= allocation.Local
+	sdkoptions.MaxConcurrentActivityTaskPollers *= allocation.Local
+	sdkoptions.MaxConcurrentLocalActivityExecutionSize *= allocation.Local
+	sdkoptions.MaxConcurrentWorkflowTaskExecutionSize *= allocation.Local
+	sdkoptions.MaxConcurrentActivityExecutionSize *= allocation.Local
 	sdkoptions.OnFatalError = w.onFatalError
 
 	// this should not block because the client already has server capabilities
 	worker := w.wm.sdkClientFactory.NewWorker(client, primitives.PerNSWorkerTaskQueue, sdkoptions)
 	details := workercommon.RegistrationDetails{
-		TotalWorkers: totalWorkers,
-		Multiplicity: multiplicity,
+		TotalWorkers: allocation.Total,
+		Multiplicity: allocation.Local,
 	}
 	for _, cmp := range components {
 		cmp.Register(worker, ns, details)
diff --git a/service/worker/pernamespaceworker_test.go b/service/worker/pernamespaceworker_test.go
index 14cf0ea04b8..0d4db28074f 100644
--- a/service/worker/pernamespaceworker_test.go
+++ b/service/worker/pernamespaceworker_test.go
@@ -85,7 +85,7 @@ func (s *perNsWorkerManagerSuite) SetupTest() {
 			HostName: "self",
 			Config: &Config{
 				PerNamespaceWorkerCount: func(ns string) int {
-					return max(1, map[string]int{"ns1": 1, "ns2": 2, "ns3": 3}[ns])
+					return max(1, map[string]int{"ns1": 1, "ns2": 2, "ns3": 6}[ns])
 				},
 				PerNamespaceWorkerOptions: func(ns string) map[string]any {
 					switch ns {
@@ -159,7 +159,7 @@ func (s *perNsWorkerManagerSuite) TestEnabledButResolvedToOther() {
 		Enabled: false,
 	}).AnyTimes()
 
-	s.serviceResolver.EXPECT().Lookup("ns1/0").Return(membership.NewHostInfoFromAddress("other1"), nil)
+	s.serviceResolver.EXPECT().LookupN("ns1", 1).Return([]membership.HostInfo{membership.NewHostInfoFromAddress("other1")})
 
 	s.manager.namespaceCallback(ns, false)
 	// main work happens in a goroutine
@@ -176,7 +176,7 @@ func (s *perNsWorkerManagerSuite) TestEnabled() {
 		Enabled: false,
 	}).AnyTimes()
 
-	s.serviceResolver.EXPECT().Lookup("ns1/0").Return(membership.NewHostInfoFromAddress("self"), nil)
+	s.serviceResolver.EXPECT().LookupN("ns1", 1).Return([]membership.HostInfo{membership.NewHostInfoFromAddress("self")})
 	cli1 := mocksdk.NewMockClient(s.controller)
 	s.cfactory.EXPECT().NewClient(matchOptions("ns1")).Return(cli1)
 	wkr1 := mocksdk.NewMockWorker(s.controller)
@@ -192,8 +192,22 @@
 	cli1.EXPECT().Close()
 }
 
+/*
+	Given this host owns the namespace
+	When a namespace change is reported
+	Then a worker should be started with settings proportional to the administratively configured worker count,
+	ensuring a fair distribution across all hosts that own the namespace
+*/
 func (s *perNsWorkerManagerSuite) TestMultiplicity() {
-	ns := testns("ns3", enumspb.NAMESPACE_STATE_REGISTERED) // three workers
+	machinesOwningNamespace := []membership.HostInfo{
+		membership.NewHostInfoFromAddress("other-1"),
+		membership.NewHostInfoFromAddress("other-2"),
+		membership.NewHostInfoFromAddress("self"),
+	}
+	desiredWorkersNumber := 6 // must match the ns3 value configured in SetupTest
+	expectedMultiplicity := 2
+
+	ns := testns("ns3", enumspb.NAMESPACE_STATE_REGISTERED)
 
 	s.cmp1.EXPECT().DedicatedWorkerOptions(gomock.Any()).Return(&workercommon.PerNSDedicatedWorkerOptions{
 		Enabled: true,
@@ -202,9 +216,7 @@
 		Enabled: false,
 	}).AnyTimes()
 
-	s.serviceResolver.EXPECT().Lookup("ns3/0").Return(membership.NewHostInfoFromAddress("self"), nil)
-	s.serviceResolver.EXPECT().Lookup("ns3/1").Return(membership.NewHostInfoFromAddress("other"), nil)
-	s.serviceResolver.EXPECT().Lookup("ns3/2").Return(membership.NewHostInfoFromAddress("self"), nil)
+	s.serviceResolver.EXPECT().LookupN("ns3", desiredWorkersNumber).Return(machinesOwningNamespace)
 	cli1 := mocksdk.NewMockClient(s.controller)
 	s.cfactory.EXPECT().NewClient(matchOptions("ns3")).Return(cli1)
 	wkr1 := mocksdk.NewMockWorker(s.controller)
@@ -215,10 +227,11 @@
 		s.Equal(2000, options.MaxConcurrentLocalActivityExecutionSize)
 		s.Equal(2000, options.MaxConcurrentActivityExecutionSize)
 	}).Return(wkr1)
-	s.cmp1.EXPECT().Register(wkr1, ns, workercommon.RegistrationDetails{TotalWorkers: 3, Multiplicity: 2})
+	s.cmp1.EXPECT().Register(wkr1, ns, workercommon.RegistrationDetails{TotalWorkers: desiredWorkersNumber, Multiplicity: expectedMultiplicity})
 	wkr1.EXPECT().Start()
 
 	s.manager.namespaceCallback(ns, false)
+
 	time.Sleep(50 * time.Millisecond)
 
 	wkr1.EXPECT().Stop()
@@ -233,30 +246,42 @@ func (s *perNsWorkerManagerSuite) TestOptions() {
 	s.cmp1.EXPECT().DedicatedWorkerOptions(gomock.Any()).Return(&workercommon.PerNSDedicatedWorkerOptions{
 		Enabled: true,
 	}).AnyTimes()
+
 	s.cmp2.EXPECT().DedicatedWorkerOptions(gomock.Any()).Return(&workercommon.PerNSDedicatedWorkerOptions{
 		Enabled: false,
 	}).AnyTimes()
 
-	s.serviceResolver.EXPECT().Lookup(gomock.Any()).Return(membership.NewHostInfoFromAddress("self"), nil).AnyTimes()
+	s.serviceResolver.EXPECT().LookupN("ns1", 1).Return([]membership.HostInfo{
+		membership.NewHostInfoFromAddress("self"),
+	}).AnyTimes()
+	s.serviceResolver.EXPECT().LookupN("ns2", 2).Return([]membership.HostInfo{
+		membership.NewHostInfoFromAddress("self"),
+	}).AnyTimes()
+	s.serviceResolver.EXPECT().LookupN("ns3", 6).Return([]membership.HostInfo{
+		membership.NewHostInfoFromAddress("self"),
+	}).AnyTimes()
+
 	cli1 := mocksdk.NewMockClient(s.controller)
 	s.cfactory.EXPECT().NewClient(matchOptions("ns1")).Return(cli1)
-	cli2 := mocksdk.NewMockClient(s.controller)
-	s.cfactory.EXPECT().NewClient(matchOptions("ns2")).Return(cli2)
-	cli3 := mocksdk.NewMockClient(s.controller)
-	s.cfactory.EXPECT().NewClient(matchOptions("ns3")).Return(cli3)
 	wkr := 
mocksdk.NewMockWorker(s.controller) s.cfactory.EXPECT().NewWorker(matchStrict{cli1}, primitives.PerNSWorkerTaskQueue, gomock.Any()).Do(func(_, _ any, options sdkworker.Options) { s.Equal(100, options.MaxConcurrentWorkflowTaskPollers) s.Equal(2, options.MaxConcurrentActivityTaskPollers) s.Equal(0.0, options.WorkerLocalActivitiesPerSecond) }).Return(wkr) + + cli2 := mocksdk.NewMockClient(s.controller) + s.cfactory.EXPECT().NewClient(matchOptions("ns2")).Return(cli2) s.cfactory.EXPECT().NewWorker(matchStrict{cli2}, primitives.PerNSWorkerTaskQueue, gomock.Any()).Do(func(_, _ any, options sdkworker.Options) { s.Equal(4, options.MaxConcurrentWorkflowTaskPollers) s.Equal(200.0, options.WorkerLocalActivitiesPerSecond) s.Equal(7500*time.Millisecond, options.StickyScheduleToStartTimeout) }).Return(wkr) + + cli3 := mocksdk.NewMockClient(s.controller) + s.cfactory.EXPECT().NewClient(matchOptions("ns3")).Return(cli3) s.cfactory.EXPECT().NewWorker(matchStrict{cli3}, primitives.PerNSWorkerTaskQueue, gomock.Any()).Do(func(_, _ any, options sdkworker.Options) { - s.Equal(6, options.MaxConcurrentWorkflowTaskPollers) + s.Equal(12, options.MaxConcurrentWorkflowTaskPollers) s.Equal(0.0, options.WorkerLocalActivitiesPerSecond) s.Equal(0*time.Millisecond, options.StickyScheduleToStartTimeout) }).Return(wkr) @@ -266,6 +291,7 @@ func (s *perNsWorkerManagerSuite) TestOptions() { s.manager.namespaceCallback(ns1, false) s.manager.namespaceCallback(ns2, false) s.manager.namespaceCallback(ns3, false) + time.Sleep(50 * time.Millisecond) wkr.EXPECT().Stop().AnyTimes() @@ -288,9 +314,8 @@ func (s *perNsWorkerManagerSuite) TestTwoNamespacesTwoComponents() { return &workercommon.PerNSDedicatedWorkerOptions{Enabled: ns.Name().String() == "ns1"} }).AnyTimes() - s.serviceResolver.EXPECT().Lookup("ns1/0").Return(membership.NewHostInfoFromAddress("self"), nil) - s.serviceResolver.EXPECT().Lookup("ns2/0").Return(membership.NewHostInfoFromAddress("self"), nil) - s.serviceResolver.EXPECT().Lookup("ns2/1").Return(membership.NewHostInfoFromAddress("self"), nil) + s.serviceResolver.EXPECT().LookupN("ns1", 1).Return([]membership.HostInfo{membership.NewHostInfoFromAddress("self")}) + s.serviceResolver.EXPECT().LookupN("ns2", 2).Return([]membership.HostInfo{membership.NewHostInfoFromAddress("self")}) cli1 := mocksdk.NewMockClient(s.controller) cli2 := mocksdk.NewMockClient(s.controller) @@ -331,7 +356,7 @@ func (s *perNsWorkerManagerSuite) TestDeleteNs() { Enabled: false, }).AnyTimes() - s.serviceResolver.EXPECT().Lookup("ns1/0").Return(membership.NewHostInfoFromAddress("self"), nil) + s.serviceResolver.EXPECT().LookupN("ns1", 1).Return([]membership.HostInfo{membership.NewHostInfoFromAddress("self")}) cli1 := mocksdk.NewMockClient(s.controller) s.cfactory.EXPECT().NewClient(matchOptions("ns1")).Return(cli1) wkr1 := mocksdk.NewMockWorker(s.controller) @@ -351,7 +376,7 @@ func (s *perNsWorkerManagerSuite) TestDeleteNs() { // restore it nsRestored := testns("ns1", enumspb.NAMESPACE_STATE_REGISTERED) - s.serviceResolver.EXPECT().Lookup("ns1/0").Return(membership.NewHostInfoFromAddress("self"), nil) + s.serviceResolver.EXPECT().LookupN("ns1", 1).Return([]membership.HostInfo{membership.NewHostInfoFromAddress("self")}) cli2 := mocksdk.NewMockClient(s.controller) s.cfactory.EXPECT().NewClient(matchOptions("ns1")).Return(cli2) wkr2 := mocksdk.NewMockWorker(s.controller) @@ -380,13 +405,13 @@ func (s *perNsWorkerManagerSuite) TestMembershipChanged() { }).AnyTimes() // we don't own it at first - 
s.serviceResolver.EXPECT().Lookup("ns1/0").Return(membership.NewHostInfoFromAddress("other"), nil) + s.serviceResolver.EXPECT().LookupN("ns1", 1).Return([]membership.HostInfo{membership.NewHostInfoFromAddress("other")}) s.manager.namespaceCallback(ns, false) time.Sleep(50 * time.Millisecond) // now we own it - s.serviceResolver.EXPECT().Lookup("ns1/0").Return(membership.NewHostInfoFromAddress("self"), nil) + s.serviceResolver.EXPECT().LookupN("ns1", 1).Return([]membership.HostInfo{membership.NewHostInfoFromAddress("self")}) cli1 := mocksdk.NewMockClient(s.controller) s.cfactory.EXPECT().NewClient(matchOptions("ns1")).Return(cli1) wkr1 := mocksdk.NewMockWorker(s.controller) @@ -398,7 +423,7 @@ func (s *perNsWorkerManagerSuite) TestMembershipChanged() { time.Sleep(50 * time.Millisecond) // now we don't own it anymore - s.serviceResolver.EXPECT().Lookup("ns1/0").Return(membership.NewHostInfoFromAddress("other"), nil) + s.serviceResolver.EXPECT().LookupN("ns1", 1).Return([]membership.HostInfo{membership.NewHostInfoFromAddress("other")}) wkr1.EXPECT().Stop() cli1.EXPECT().Close() @@ -416,9 +441,9 @@ func (s *perNsWorkerManagerSuite) TestServiceResolverError() { Enabled: false, }).AnyTimes() - s.serviceResolver.EXPECT().Lookup("ns1/0").Return(nil, errors.New("resolver error")) - s.serviceResolver.EXPECT().Lookup("ns1/0").Return(nil, errors.New("resolver error again")) - s.serviceResolver.EXPECT().Lookup("ns1/0").Return(membership.NewHostInfoFromAddress("self"), nil) + s.serviceResolver.EXPECT().LookupN("ns1", 1).Return([]membership.HostInfo{}) + s.serviceResolver.EXPECT().LookupN("ns1", 1).Return([]membership.HostInfo{}) + s.serviceResolver.EXPECT().LookupN("ns1", 1).Return([]membership.HostInfo{membership.NewHostInfoFromAddress("self")}) cli1 := mocksdk.NewMockClient(s.controller) s.cfactory.EXPECT().NewClient(matchOptions("ns1")).Return(cli1) @@ -445,7 +470,7 @@ func (s *perNsWorkerManagerSuite) TestStartWorkerError() { Enabled: false, }).AnyTimes() - s.serviceResolver.EXPECT().Lookup("ns1/0").Return(membership.NewHostInfoFromAddress("self"), nil).AnyTimes() + s.serviceResolver.EXPECT().LookupN("ns1", 1).Return([]membership.HostInfo{membership.NewHostInfoFromAddress("self")}).AnyTimes() cli1 := mocksdk.NewMockClient(s.controller) s.cfactory.EXPECT().NewClient(matchOptions("ns1")).Return(cli1) diff --git a/tests/simple_service_resolver.go b/tests/simple_service_resolver.go index ac0d15c7881..b894d4f336f 100644 --- a/tests/simple_service_resolver.go +++ b/tests/simple_service_resolver.go @@ -88,6 +88,14 @@ func (s *simpleResolver) Lookup(key string) (membership.HostInfo, error) { return s.hostInfos[idx], nil } +func (s *simpleResolver) LookupN(key string, _ int) []membership.HostInfo { + info, err := s.Lookup(key) + if err != nil { + return []membership.HostInfo{} + } + return []membership.HostInfo{info} +} + func (s *simpleResolver) AddListener(name string, notifyChannel chan<- *membership.ChangedEvent) error { s.mu.Lock() defer s.mu.Unlock()
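Note (a standalone sketch, not part of the diff): getLocallyDesiredWorkersCount asks LookupN for up to the configured number of hosts, repeats that host list with RepeatSlice, truncates it to the desired worker count, and then counts how many of those slots land on the local host. The program below reproduces that arithmetic with illustrative host names and the ns3 values used in TestMultiplicity (3 owning hosts, 6 desired workers); repeatSlice is a local copy of util.RepeatSlice, and the string comparison against "self" stands in for info.Identity() == wm.self.Identity().

package main

import "fmt"

// repeatSlice mirrors util.RepeatSlice from the diff: xs repeated n times, nil when xs is nil or n is non-positive.
func repeatSlice[T any](xs []T, n int) []T {
	if xs == nil || n <= 0 {
		return nil
	}
	ys := make([]T, n*len(xs))
	for i := 0; i < n; i++ {
		copy(ys[i*len(xs):], xs)
	}
	return ys
}

func main() {
	// Illustrative inputs: the hosts LookupN returned for the namespace, and the configured worker count.
	hosts := []string{"other-1", "other-2", "self"}
	desired := 6

	// Same arithmetic as getLocallyDesiredWorkersCount: over-repeat the host list, then truncate to the desired count.
	maxWorkersPerHost := desired/len(hosts) + 1
	distribution := repeatSlice(hosts, maxWorkersPerHost)[:desired]

	// Count the slots assigned to this host.
	local := 0
	for _, h := range distribution {
		if h == "self" {
			local++
		}
	}
	fmt.Println(distribution, "local:", local)
	// Output: [other-1 other-2 self other-1 other-2 self] local: 2
}

With these inputs the local host is allocated 2 of the 6 workers, matching expectedMultiplicity in TestMultiplicity; the +1 over-repeat followed by truncation keeps the split within one worker of even when the desired count does not divide evenly across the owning hosts.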