From b0e79d4f98c29c436e5ac311bc5f7d622e3218e1 Mon Sep 17 00:00:00 2001 From: Zijian Date: Wed, 2 Oct 2024 20:02:38 +0000 Subject: [PATCH] Revert "Refactor PeerProvider & hashring interaction (#6296)" This reverts commit 9807d5df82c495632664cefa0fa0b731adfc203f. --- common/membership/hashring.go | 133 ++++----- common/membership/hashring_test.go | 256 +++++------------- common/membership/peerprovider_mock.go | 8 +- common/membership/resolver.go | 4 - common/membership/resolver_test.go | 1 - .../peerprovider/ringpopprovider/provider.go | 26 +- .../ringpopprovider/provider_test.go | 49 ---- 7 files changed, 141 insertions(+), 336 deletions(-) diff --git a/common/membership/hashring.go b/common/membership/hashring.go index 6eb1f77fba2..751f6031ded 100644 --- a/common/membership/hashring.go +++ b/common/membership/hashring.go @@ -22,7 +22,6 @@ package membership import ( "fmt" - "slices" "sort" "strings" "sync" @@ -45,6 +44,7 @@ import ( // ErrInsufficientHosts is thrown when there are not enough hosts to serve the request var ErrInsufficientHosts = &types.InternalServiceError{Message: "Not enough hosts to serve the request"} +var emptyEvent = &ChangedEvent{} const ( minRefreshInternal = time.Second * 4 @@ -59,14 +59,14 @@ type PeerProvider interface { GetMembers(service string) ([]HostInfo, error) WhoAmI() (HostInfo, error) SelfEvict() error - Subscribe(name string, handler func(ChangedEvent)) error + Subscribe(name string, notifyChannel chan<- *ChangedEvent) error } type ring struct { status int32 service string peerProvider PeerProvider - refreshChan chan struct{} + refreshChan chan *ChangedEvent shutdownCh chan struct{} shutdownWG sync.WaitGroup timeSource clock.TimeSource @@ -99,7 +99,7 @@ func newHashring( service: service, peerProvider: provider, shutdownCh: make(chan struct{}), - refreshChan: make(chan struct{}, 1), + refreshChan: make(chan *ChangedEvent), timeSource: timeSource, logger: logger, scope: scope, @@ -125,11 +125,11 @@ func (r *ring) Start() { ) { return } - if err := r.peerProvider.Subscribe(r.service, r.handleUpdates); err != nil { + if err := r.peerProvider.Subscribe(r.service, r.refreshChan); err != nil { r.logger.Fatal("subscribing to peer provider", tag.Error(err)) } - if err := r.refresh(); err != nil { + if _, err := r.refresh(); err != nil { r.logger.Fatal("failed to start service resolver", tag.Error(err)) } @@ -151,8 +151,8 @@ func (r *ring) Stop() { r.value.Store(emptyHashring()) r.subscribers.Lock() + defer r.subscribers.Unlock() r.subscribers.keys = make(map[string]chan<- *ChangedEvent) - r.subscribers.Unlock() close(r.shutdownCh) if success := common.AwaitWaitGroup(&r.shutdownWG, time.Minute); !success { @@ -166,10 +166,12 @@ func (r *ring) Lookup( ) (HostInfo, error) { addr, found := r.ring().Lookup(key) if !found { - r.signalSelf() + select { + case r.refreshChan <- emptyEvent: + default: + } return HostInfo{}, ErrInsufficientHosts } - r.members.RLock() defer r.members.RUnlock() host, ok := r.members.keys[addr] @@ -193,25 +195,12 @@ func (r *ring) Subscribe(watcher string, notifyChannel chan<- *ChangedEvent) err return nil } -func (r *ring) handleUpdates(event ChangedEvent) { - r.signalSelf() -} - -func (r *ring) signalSelf() { - var event struct{} - select { - case r.refreshChan <- event: - default: // channel already has an event, don't block - } -} - -func (r *ring) notifySubscribers(msg ChangedEvent) { +func (r *ring) notifySubscribers(msg *ChangedEvent) { r.subscribers.Lock() defer r.subscribers.Unlock() - for name, ch := range r.subscribers.keys { select { - case ch <- &msg: + case ch <- msg: default: r.logger.Error("subscriber notification failed", tag.Name(name)) } @@ -251,43 +240,43 @@ func (r *ring) Members() []HostInfo { return hosts } -func (r *ring) refresh() error { +func (r *ring) refresh() (refreshed bool, err error) { if r.members.refreshed.After(r.timeSource.Now().Add(-minRefreshInternal)) { // refreshed too frequently - return nil + return false, nil } members, err := r.peerProvider.GetMembers(r.service) if err != nil { - return fmt.Errorf("getting members from peer provider: %w", err) + return false, fmt.Errorf("getting members from peer provider: %w", err) } - newMembersMap := r.makeMembersMap(members) - - diff := r.diffMembers(newMembersMap) - if diff.Empty() { - return nil + r.members.Lock() + defer r.members.Unlock() + newMembersMap, changed := r.compareMembers(members) + if !changed { + return false, nil } ring := emptyHashring() ring.AddMembers(castToMembers(members)...) + + r.members.keys = newMembersMap + r.members.refreshed = r.timeSource.Now() r.value.Store(ring) r.logger.Info("refreshed ring members", tag.Value(members)) - r.updateMembersMap(newMembersMap) - - r.emitHashIdentifier() - r.notifySubscribers(diff) - - return nil + return true, nil } -func (r *ring) updateMembersMap(newMembers map[string]HostInfo) { - r.members.Lock() - defer r.members.Unlock() - - r.members.keys = newMembers - r.members.refreshed = r.timeSource.Now() +func (r *ring) refreshAndNotifySubscribers(event *ChangedEvent) { + refreshed, err := r.refresh() + if err != nil { + r.logger.Error("refreshing ring", tag.Error(err)) + } + if refreshed { + r.notifySubscribers(event) + } } func (r *ring) refreshRingWorker() { @@ -295,17 +284,15 @@ func (r *ring) refreshRingWorker() { refreshTicker := r.timeSource.NewTicker(defaultRefreshInterval) defer refreshTicker.Stop() - for { select { case <-r.shutdownCh: return - case <-r.refreshChan: // local signal or signal from provider - if err := r.refresh(); err != nil { - r.logger.Error("failed to refresh ring", tag.Error(err)) - } - case <-refreshTicker.Chan(): // periodically force refreshing membership - r.signalSelf() + case event := <-r.refreshChan: // local signal or signal from provider + r.refreshAndNotifySubscribers(event) + case <-refreshTicker.Chan(): // periodically refresh membership + r.emitHashIdentifier() + r.refreshAndNotifySubscribers(emptyEvent) } } } @@ -315,7 +302,11 @@ func (r *ring) ring() *hashring.HashRing { } func (r *ring) emitHashIdentifier() float64 { - members := r.Members() + members, err := r.peerProvider.GetMembers(r.service) + if err != nil { + r.logger.Error("Observed a problem getting peer members while emitting hash identifier metrics", tag.Error(err)) + return -1 + } self, err := r.peerProvider.WhoAmI() if err != nil { r.logger.Error("Observed a problem looking up self from the membership provider while emitting hash identifier metrics", tag.Error(err)) @@ -352,38 +343,22 @@ func (r *ring) emitHashIdentifier() float64 { return trimmedForMetric } -func (r *ring) makeMembersMap(members []HostInfo) map[string]HostInfo { - membersMap := make(map[string]HostInfo, len(members)) - for _, m := range members { - membersMap[m.GetAddress()] = m - } - return membersMap -} - -func (r *ring) diffMembers(newMembers map[string]HostInfo) ChangedEvent { - r.members.RLock() - defer r.members.RUnlock() - - var combinedChange ChangedEvent - - // find newly added hosts - for addr := range newMembers { - if _, found := r.members.keys[addr]; !found { - combinedChange.HostsAdded = append(combinedChange.HostsAdded, addr) +func (r *ring) compareMembers(members []HostInfo) (map[string]HostInfo, bool) { + changed := false + newMembersMap := make(map[string]HostInfo, len(members)) + for _, member := range members { + newMembersMap[member.GetAddress()] = member + if _, ok := r.members.keys[member.GetAddress()]; !ok { + changed = true } } - // find removed hosts for addr := range r.members.keys { - if _, found := newMembers[addr]; !found { - combinedChange.HostsRemoved = append(combinedChange.HostsRemoved, addr) + if _, ok := newMembersMap[addr]; !ok { + changed = true + break } } - - // order since it will most probably used in logs - slices.Sort(combinedChange.HostsAdded) - slices.Sort(combinedChange.HostsUpdated) - slices.Sort(combinedChange.HostsRemoved) - return combinedChange + return newMembersMap, changed } func castToMembers[T membership.Member](members []T) []membership.Member { diff --git a/common/membership/hashring_test.go b/common/membership/hashring_test.go index 507715c4dc6..f4f4b757877 100644 --- a/common/membership/hashring_test.go +++ b/common/membership/hashring_test.go @@ -24,31 +24,24 @@ package membership import ( "errors" - "fmt" "math/rand" - "runtime" "sync" "testing" "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" "go.uber.org/goleak" "go.uber.org/zap/zaptest/observer" "github.com/uber/cadence/common" "github.com/uber/cadence/common/clock" - "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/log/testlogger" "github.com/uber/cadence/common/metrics" ) var letters = []rune("abcdefghijklmnopqrstuvwxyz") -const maxTestDuration = 5 * time.Second - func randSeq(n int) string { b := make([]rune, n) for i := range b { @@ -65,56 +58,46 @@ func randomHostInfo(n int) []HostInfo { return res } -func TestDiffMemberMakesCorrectDiff(t *testing.T) { +func testCompareMembers(t *testing.T, curr []HostInfo, new []HostInfo, hasDiff bool) { + hashring := &ring{} + currMembers := make(map[string]HostInfo, len(curr)) + for _, m := range curr { + currMembers[m.GetAddress()] = m + } + hashring.members.keys = currMembers + newMembers, changed := hashring.compareMembers(new) + assert.Equal(t, hasDiff, changed) + assert.Equal(t, len(new), len(newMembers)) + for _, m := range new { + _, ok := newMembers[m.GetAddress()] + assert.True(t, ok) + } +} + +func Test_ring_compareMembers(t *testing.T) { + tests := []struct { - name string - curr []HostInfo - new []HostInfo - expectedChange ChangedEvent + curr []HostInfo + new []HostInfo + hasDiff bool }{ - { - name: "empty and one added", - curr: []HostInfo{}, - new: []HostInfo{NewHostInfo("a")}, - expectedChange: ChangedEvent{HostsAdded: []string{"a"}}, - }, - { - name: "non-empty and added", - curr: []HostInfo{NewHostInfo("a")}, - new: []HostInfo{NewHostInfo("a"), NewHostInfo("b")}, - expectedChange: ChangedEvent{HostsAdded: []string{"b"}}, - }, - { - name: "empty and nothing has changed", - curr: []HostInfo{}, - new: []HostInfo{}, - expectedChange: ChangedEvent{}, - }, - { - name: "multiple hosts, but no change", - curr: []HostInfo{NewHostInfo("a"), NewHostInfo("b"), NewHostInfo("c")}, - new: []HostInfo{NewHostInfo("c"), NewHostInfo("b"), NewHostInfo("a")}, - expectedChange: ChangedEvent{}, - }, - - { - name: "multiple hosts, add/delete", - curr: []HostInfo{NewHostInfo("a"), NewHostInfo("b"), NewHostInfo("c")}, - new: []HostInfo{NewHostInfo("b"), NewHostInfo("e"), NewHostInfo("f")}, - expectedChange: ChangedEvent{HostsRemoved: []string{"a", "c"}, HostsAdded: []string{"e", "f"}}, - }, + {curr: []HostInfo{}, new: []HostInfo{NewHostInfo("a")}, hasDiff: true}, + {curr: []HostInfo{}, new: []HostInfo{NewHostInfo("a"), NewHostInfo("b")}, hasDiff: true}, + {curr: []HostInfo{NewHostInfo("a")}, new: []HostInfo{NewHostInfo("a"), NewHostInfo("b")}, hasDiff: true}, + {curr: []HostInfo{}, new: []HostInfo{}, hasDiff: false}, + {curr: []HostInfo{NewHostInfo("a")}, new: []HostInfo{NewHostInfo("a")}, hasDiff: false}, + // order doesn't matter. + {curr: []HostInfo{NewHostInfo("a"), NewHostInfo("b")}, new: []HostInfo{NewHostInfo("b"), NewHostInfo("a")}, hasDiff: false}, + // member has left the ring + {curr: []HostInfo{NewHostInfo("a"), NewHostInfo("b"), NewHostInfo("c")}, new: []HostInfo{NewHostInfo("b"), NewHostInfo("a")}, hasDiff: true}, + // ring becomes empty + {curr: []HostInfo{NewHostInfo("a"), NewHostInfo("b"), NewHostInfo("c")}, new: []HostInfo{}, hasDiff: true}, } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := &ring{} - currMembers := r.makeMembersMap(tt.curr) - r.members.keys = currMembers - - combinedChange := r.diffMembers(r.makeMembersMap(tt.new)) - assert.Equal(t, tt.expectedChange, combinedChange) - }) + testCompareMembers(t, tt.curr, tt.new, tt.hasDiff) } + } type hashringTestData struct { @@ -159,126 +142,16 @@ func (td *hashringTestData) startHashRing() { td.hashRing.Start() } -func (td *hashringTestData) bypassRefreshRatelimiter() { - td.hashRing.members.refreshed = time.Now().AddDate(0, 0, -1) -} - func TestFailedLookupWillAskProvider(t *testing.T) { td := newHashringTestData(t) - var wg sync.WaitGroup - wg.Add(2) td.mockPeerProvider.EXPECT().Subscribe(gomock.Any(), gomock.Any()).Times(1) - td.mockPeerProvider.EXPECT().GetMembers("test-service"). - Do(func(string) { - // we expect first call on hashring creation - // the second call should be initiated by failed Lookup - wg.Done() - }).Times(2) + td.mockPeerProvider.EXPECT().GetMembers("test-service").Times(1) td.startHashRing() _, err := td.hashRing.Lookup("a") - assert.Error(t, err) - - require.True(t, common.AwaitWaitGroup(&wg, maxTestDuration), "Failed Lookup should lead to refresh") -} - -func TestFailingToSubscribeIsFatal(t *testing.T) { - defer goleak.VerifyNone(t) - td := newHashringTestData(t) - - // we need to intercept logger calls, use mock - mockLogger := &log.MockLogger{} - td.hashRing.logger = mockLogger - - mockLogger.On("Fatal", mock.Anything, mock.Anything).Run( - func(arguments mock.Arguments) { - // we need to stop goroutine like log.Fatal() does with an entire program - runtime.Goexit() - }, - ).Times(1) - - td.mockPeerProvider.EXPECT(). - Subscribe(gomock.Any(), gomock.Any()). - Return(errors.New("can't subscribe")) - - // because we use runtime.Goexit() we need to call .Start in a separate goroutine - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - td.hashRing.Start() - }() - - require.True(t, common.AwaitWaitGroup(&wg, maxTestDuration), "must be finished - failed to subscribe") - require.True(t, mockLogger.AssertExpectations(t), "log.Fatal must be called") -} - -func TestHandleUpdatesNeverBlocks(t *testing.T) { - td := newHashringTestData(t) - - var wg sync.WaitGroup - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - td.hashRing.handleUpdates(ChangedEvent{}) - wg.Done() - }() - } - - require.True(t, common.AwaitWaitGroup(&wg, maxTestDuration), "handleUpdates should never block") -} -func TestHandlerSchedulesUpdates(t *testing.T) { - td := newHashringTestData(t) - - var wg sync.WaitGroup - td.mockPeerProvider.EXPECT().Subscribe(gomock.Any(), gomock.Any()).Times(1) - td.mockPeerProvider.EXPECT().GetMembers("test-service").DoAndReturn(func(service string) ([]HostInfo, error) { - wg.Done() - fmt.Println("GetMembers called") - return randomHostInfo(5), nil - }).Times(2) - td.mockPeerProvider.EXPECT().WhoAmI().AnyTimes() - - wg.Add(1) // we expect 1st GetMembers to be called during hashring start - td.startHashRing() - require.True(t, common.AwaitWaitGroup(&wg, maxTestDuration), "GetMembers must be called") - - wg.Add(1) // another call to GetMembers should happen because of handleUpdate - td.bypassRefreshRatelimiter() - td.hashRing.handleUpdates(ChangedEvent{}) - - require.True(t, common.AwaitWaitGroup(&wg, maxTestDuration), "GetMembers must be called again") -} - -func TestFailedRefreshLogsError(t *testing.T) { - td := newHashringTestData(t) - - var wg sync.WaitGroup - td.mockPeerProvider.EXPECT().Subscribe(gomock.Any(), gomock.Any()).Times(1) - td.mockPeerProvider.EXPECT().GetMembers("test-service").DoAndReturn(func(service string) ([]HostInfo, error) { - wg.Done() - return randomHostInfo(5), nil - }).Times(1) - td.mockPeerProvider.EXPECT().WhoAmI().AnyTimes() - - wg.Add(1) // we expect 1st GetMembers to be called during hashring start - td.startHashRing() - require.True(t, common.AwaitWaitGroup(&wg, maxTestDuration), "GetMembers must be called") - - td.mockPeerProvider.EXPECT().GetMembers("test-service").DoAndReturn(func(service string) ([]HostInfo, error) { - wg.Done() - return nil, errors.New("GetMembers failed") - }).Times(1) - - wg.Add(1) // another call to GetMembers should happen because of handleUpdate - td.bypassRefreshRatelimiter() - td.hashRing.handleUpdates(ChangedEvent{}) - - require.True(t, common.AwaitWaitGroup(&wg, maxTestDuration), "GetMembers must be called again (and fail)") - td.hashRing.Stop() - assert.Equal(t, 1, td.observedLogs.FilterMessageSnippet("failed to refresh ring").Len()) + assert.Error(t, err) } func TestRefreshUpdatesRingOnlyWhenRingHasChanged(t *testing.T) { @@ -286,13 +159,15 @@ func TestRefreshUpdatesRingOnlyWhenRingHasChanged(t *testing.T) { td.mockPeerProvider.EXPECT().Subscribe(gomock.Any(), gomock.Any()).Times(1) td.mockPeerProvider.EXPECT().GetMembers("test-service").Times(1).Return(randomHostInfo(3), nil) - td.mockPeerProvider.EXPECT().WhoAmI().AnyTimes() // Start will also call .refresh() td.startHashRing() updatedAt := td.hashRing.members.refreshed + td.hashRing.refresh() + refreshed, err := td.hashRing.refresh() - assert.NoError(t, td.hashRing.refresh()) + assert.NoError(t, err) + assert.False(t, refreshed) assert.Equal(t, updatedAt, td.hashRing.members.refreshed) } @@ -301,11 +176,16 @@ func TestRefreshWillNotifySubscribers(t *testing.T) { var hostsToReturn []HostInfo td.mockPeerProvider.EXPECT().Subscribe(gomock.Any(), gomock.Any()).Times(1) - td.mockPeerProvider.EXPECT().GetMembers("test-service").DoAndReturn(func(service string) ([]HostInfo, error) { + td.mockPeerProvider.EXPECT().GetMembers("test-service").Times(2).DoAndReturn(func(service string) ([]HostInfo, error) { hostsToReturn = randomHostInfo(5) return hostsToReturn, nil - }).Times(2) - td.mockPeerProvider.EXPECT().WhoAmI().AnyTimes() + }) + + changed := &ChangedEvent{ + HostsAdded: []string{"a"}, + HostsUpdated: []string{"b"}, + HostsRemoved: []string{"c"}, + } td.startHashRing() @@ -320,16 +200,14 @@ func TestRefreshWillNotifySubscribers(t *testing.T) { defer wg.Done() changedEvent := <-changeCh changedEvent2 := <-changeCh - assert.NotEmpty(t, changedEvent, "changed event should never be empty") - assert.NotEmpty(t, changedEvent2, "changed event should never be empty") + assert.Equal(t, changed, changedEvent) + assert.Equal(t, changed, changedEvent2) }() - td.bypassRefreshRatelimiter() - td.hashRing.signalSelf() - - // wait until both subscribers will get notification - require.True(t, common.AwaitWaitGroup(&wg, maxTestDuration)) - + // to bypass internal check + td.hashRing.members.refreshed = time.Now().AddDate(0, 0, -1) + td.hashRing.refreshChan <- changed + wg.Wait() // wait until both subscribers will get notification // Test if internal members are updated assert.ElementsMatch(t, td.hashRing.Members(), hostsToReturn, "members should contain just-added nodes") } @@ -340,11 +218,11 @@ func TestSubscribersAreNotifiedPeriodically(t *testing.T) { var hostsToReturn []HostInfo td.mockPeerProvider.EXPECT().Subscribe(gomock.Any(), gomock.Any()).Times(1) - td.mockPeerProvider.EXPECT().GetMembers("test-service").DoAndReturn(func(service string) ([]HostInfo, error) { + td.mockPeerProvider.EXPECT().GetMembers("test-service").Times(3).DoAndReturn(func(service string) ([]HostInfo, error) { // we have to change members since subscribers are only notified on change hostsToReturn = randomHostInfo(5) return hostsToReturn, nil - }).Times(2) + }) td.mockPeerProvider.EXPECT().WhoAmI().AnyTimes() td.startHashRing() @@ -357,13 +235,13 @@ func TestSubscribersAreNotifiedPeriodically(t *testing.T) { go func() { defer wg.Done() event := <-changeCh - assert.NotEmpty(t, event, "changed event should never be empty") + assert.Empty(t, event, "event should be empty when periodical update happens") }() td.mockTimeSource.BlockUntil(1) // we should wait until ticker(defaultRefreshInterval) is created td.mockTimeSource.Advance(defaultRefreshInterval) // and only then to advance time - require.True(t, common.AwaitWaitGroup(&wg, maxTestDuration)) // wait until subscriber will get notification + wg.Wait() // wait until subscriber will get notification // Test if internal members are updated assert.ElementsMatch(t, td.hashRing.Members(), hostsToReturn, "members should contain just-added nodes") @@ -398,6 +276,7 @@ func TestUnsubcribeDeletes(t *testing.T) { assert.Equal(t, 1, len(td.hashRing.subscribers.keys)) assert.NoError(t, td.hashRing.Unsubscribe("testservice1")) assert.Equal(t, 0, len(td.hashRing.subscribers.keys)) + } func TestMemberCountReturnsNumber(t *testing.T) { @@ -417,9 +296,10 @@ func TestMemberCountReturnsNumber(t *testing.T) { func TestErrorIsPropagatedWhenProviderFails(t *testing.T) { td := newHashringTestData(t) - td.mockPeerProvider.EXPECT().GetMembers(gomock.Any()).Return(nil, errors.New("provider failure")) + td.mockPeerProvider.EXPECT().GetMembers(gomock.Any()).Return(nil, errors.New("error")) - assert.ErrorContains(t, td.hashRing.refresh(), "provider failure") + _, err := td.hashRing.refresh() + assert.Error(t, err) } func TestStopWillStopProvider(t *testing.T) { @@ -439,7 +319,6 @@ func TestLookupAndRefreshRaceCondition(t *testing.T) { td.mockPeerProvider.EXPECT().GetMembers("test-service").AnyTimes().DoAndReturn(func(service string) ([]HostInfo, error) { return randomHostInfo(5), nil }) - td.mockPeerProvider.EXPECT().WhoAmI().AnyTimes() td.startHashRing() wg.Add(2) @@ -451,8 +330,10 @@ func TestLookupAndRefreshRaceCondition(t *testing.T) { }() go func() { for i := 0; i < 50; i++ { - td.bypassRefreshRatelimiter() - assert.NoError(t, td.hashRing.refresh()) + // to bypass internal check + td.hashRing.members.refreshed = time.Now().AddDate(0, 0, -1) + _, err := td.hashRing.refresh() + assert.NoError(t, err) } wg.Done() }() @@ -505,11 +386,10 @@ func TestEmitHashringView(t *testing.T) { td.mockPeerProvider.EXPECT().GetMembers("test-service").DoAndReturn(func(service string) ([]HostInfo, error) { return testInput.hosts, testInput.lookuperr }) + td.mockPeerProvider.EXPECT().WhoAmI().DoAndReturn(func() (HostInfo, error) { return testInput.selfInfo, testInput.selfErr - }).AnyTimes() - - require.NoError(t, td.hashRing.refresh()) + }) assert.Equal(t, testInput.expectedResult, td.hashRing.emitHashIdentifier()) }) } diff --git a/common/membership/peerprovider_mock.go b/common/membership/peerprovider_mock.go index dc6c6d53c3f..d7b48deb405 100644 --- a/common/membership/peerprovider_mock.go +++ b/common/membership/peerprovider_mock.go @@ -109,17 +109,17 @@ func (mr *MockPeerProviderMockRecorder) Stop() *gomock.Call { } // Subscribe mocks base method. -func (m *MockPeerProvider) Subscribe(name string, handler func(ChangedEvent)) error { +func (m *MockPeerProvider) Subscribe(name string, notifyChannel chan<- *ChangedEvent) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Subscribe", name, handler) + ret := m.ctrl.Call(m, "Subscribe", name, notifyChannel) ret0, _ := ret[0].(error) return ret0 } // Subscribe indicates an expected call of Subscribe. -func (mr *MockPeerProviderMockRecorder) Subscribe(name, handler interface{}) *gomock.Call { +func (mr *MockPeerProviderMockRecorder) Subscribe(name, notifyChannel interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subscribe", reflect.TypeOf((*MockPeerProvider)(nil).Subscribe), name, handler) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subscribe", reflect.TypeOf((*MockPeerProvider)(nil).Subscribe), name, notifyChannel) } // WhoAmI mocks base method. diff --git a/common/membership/resolver.go b/common/membership/resolver.go index 96443c7d493..86e237012ff 100644 --- a/common/membership/resolver.go +++ b/common/membership/resolver.go @@ -233,7 +233,3 @@ func (rpo *MultiringResolver) MemberCount(service string) (int, error) { } return ring.MemberCount(), nil } - -func (ce *ChangedEvent) Empty() bool { - return len(ce.HostsAdded) == 0 && len(ce.HostsUpdated) == 0 && len(ce.HostsRemoved) == 0 -} diff --git a/common/membership/resolver_test.go b/common/membership/resolver_test.go index 810a5045190..794f621e507 100644 --- a/common/membership/resolver_test.go +++ b/common/membership/resolver_test.go @@ -65,7 +65,6 @@ func TestMethodsAreRoutedToARing(t *testing.T) { } pp.EXPECT().GetMembers("test-worker").Return(hosts, nil).Times(1) - pp.EXPECT().WhoAmI().AnyTimes() r, err := a.getRing("test-worker") r.refresh() diff --git a/common/peerprovider/ringpopprovider/provider.go b/common/peerprovider/ringpopprovider/provider.go index 050dda3c546..1a21f1ecd5c 100644 --- a/common/peerprovider/ringpopprovider/provider.go +++ b/common/peerprovider/ringpopprovider/provider.go @@ -52,7 +52,7 @@ type ( logger log.Logger portmap membership.PortMap mu sync.RWMutex - subscribers map[string]func(membership.ChangedEvent) + subscribers map[string]chan<- *membership.ChangedEvent } ) @@ -118,7 +118,7 @@ func NewRingpopProvider( logger: logger, portmap: portMap, ringpop: rp, - subscribers: map[string]func(membership.ChangedEvent){}, + subscribers: map[string]chan<- *membership.ChangedEvent{}, } } @@ -165,22 +165,26 @@ func (r *Provider) HandleEvent(event events.Event) { return } - change := membership.ChangedEvent{ + r.logger.Info("Received a ringpop ring changed event") + // Marshall the event object into the required type + change := &membership.ChangedEvent{ HostsAdded: e.ServersAdded, HostsUpdated: e.ServersUpdated, HostsRemoved: e.ServersRemoved, } - r.logger.Info("Received a ringpop ring changed event", tag.MembershipChangeEvent(change)) - r.notifySubscribers(change) -} -func (r *Provider) notifySubscribers(event membership.ChangedEvent) { + // Notify subscribers r.mu.RLock() defer r.mu.RUnlock() - for _, handler := range r.subscribers { - handler(event) + for name, ch := range r.subscribers { + select { + case ch <- change: + default: + r.logger.Error("Failed to send listener notification, channel full", tag.Subscriber(name)) + } } + } func (r *Provider) SelfEvict() error { @@ -284,7 +288,7 @@ func (r *Provider) Stop() { } // Subscribe allows to be subscribed for ring changes -func (r *Provider) Subscribe(name string, handler func(membership.ChangedEvent)) error { +func (r *Provider) Subscribe(name string, notifyChannel chan<- *membership.ChangedEvent) error { r.mu.Lock() defer r.mu.Unlock() @@ -293,7 +297,7 @@ func (r *Provider) Subscribe(name string, handler func(membership.ChangedEvent)) return fmt.Errorf("%q already subscribed to ringpop provider", name) } - r.subscribers[name] = handler + r.subscribers[name] = notifyChannel return nil } diff --git a/common/peerprovider/ringpopprovider/provider_test.go b/common/peerprovider/ringpopprovider/provider_test.go index ba45742fa45..9f10ff3a7c4 100644 --- a/common/peerprovider/ringpopprovider/provider_test.go +++ b/common/peerprovider/ringpopprovider/provider_test.go @@ -27,9 +27,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/uber/ringpop-go/events" "github.com/uber/tchannel-go" "go.uber.org/goleak" @@ -37,8 +34,6 @@ import ( "github.com/uber/cadence/common/membership" ) -const testServiceName = "test-service" - type srvAndCh struct { service string ch *tchannel.Channel @@ -141,50 +136,6 @@ func TestRingpopProvider(t *testing.T) { } } -func TestSubscribeAndNotify(t *testing.T) { - provider := NewRingpopProvider(testServiceName, nil, nil, nil, testlogger.New(t)) - - ringpopEvent := events.RingChangedEvent{ - ServersAdded: []string{"aa", "bb", "cc"}, - ServersUpdated: []string{"dd"}, - ServersRemoved: []string{"ee", "ff"}, - } - expectedEvent := membership.ChangedEvent{ - HostsAdded: ringpopEvent.ServersAdded, - HostsUpdated: ringpopEvent.ServersUpdated, - HostsRemoved: ringpopEvent.ServersRemoved, - } - - var calls1, calls2 int - require.NoError(t, - provider.Subscribe("subscriber1", - func(ev membership.ChangedEvent) { - calls1++ - assert.Equal(t, ev, expectedEvent) - }, - )) - - require.NoError(t, - provider.Subscribe("subscriber2", - func(ev membership.ChangedEvent) { - calls2++ - assert.Equal(t, ev, expectedEvent) - }, - )) - - require.Error(t, - provider.Subscribe( - "subscriber2", - func(membership.ChangedEvent) { t.Error("Should never be called") }, - ), - "Subscribe doesn't allow duplicate names", - ) - - provider.HandleEvent(ringpopEvent) - assert.Equal(t, 1, calls1, "every subscriber must have been called once") - assert.Equal(t, 1, calls2, "every subscriber must have been called once") -} - func createAndListenChannels(serviceName string, n int) ([]*srvAndCh, func(), error) { var res []*srvAndCh cleanupFn := func(srvs []*srvAndCh) func() {