diff --git a/src/aggregator/aggregator/placement_mgr.go b/src/aggregator/aggregator/placement_mgr.go index 2355f77dc9..79d4ff3d37 100644 --- a/src/aggregator/aggregator/placement_mgr.go +++ b/src/aggregator/aggregator/placement_mgr.go @@ -221,21 +221,16 @@ func (mgr *placementManager) placementWithLock() (placement.ActiveStagedPlacemen return nil, nil, errPlacementManagerNotOpenOrClosed } - // NB(xichen): avoid using defer here because this is called on the write path - // for every incoming metric and defered return and func execution is expensive. - stagedPlacement, onStagedPlacementDoneFn, err := mgr.placementWatcher.ActiveStagedPlacement() + stagedPlacement, err := mgr.placementWatcher.ActiveStagedPlacement() if err != nil { mgr.metrics.activeStagedPlacementErrors.Inc(1) return nil, nil, err } - placement, onPlacementDoneFn, err := stagedPlacement.ActivePlacement() + placement, err := stagedPlacement.ActivePlacement() if err != nil { - onStagedPlacementDoneFn() mgr.metrics.activePlacementErrors.Inc(1) return nil, nil, err } - onPlacementDoneFn() - onStagedPlacementDoneFn() return stagedPlacement, placement, nil } diff --git a/src/aggregator/client/tcp_client.go b/src/aggregator/client/tcp_client.go index b2484e055e..31df577893 100644 --- a/src/aggregator/client/tcp_client.go +++ b/src/aggregator/client/tcp_client.go @@ -229,20 +229,18 @@ func (c *TCPClient) WriteForwarded( // ActivePlacement returns a copy of the currently active placement and its version. func (c *TCPClient) ActivePlacement() (placement.Placement, int, error) { - stagedPlacement, onStagedPlacementDoneFn, err := c.placementWatcher.ActiveStagedPlacement() + stagedPlacement, err := c.placementWatcher.ActiveStagedPlacement() if err != nil { return nil, 0, err } - defer onStagedPlacementDoneFn() if stagedPlacement == nil { return nil, 0, errNilPlacement } - placement, onPlacementDoneFn, err := stagedPlacement.ActivePlacement() + placement, err := stagedPlacement.ActivePlacement() if err != nil { return nil, 0, err } - defer onPlacementDoneFn() return placement.Clone(), stagedPlacement.Version(), nil } @@ -250,11 +248,10 @@ func (c *TCPClient) ActivePlacement() (placement.Placement, int, error) { // ActivePlacementVersion returns a copy of the currently active placement version. It is a far less expensive call // than ActivePlacement, as it does not clone the placement. func (c *TCPClient) ActivePlacementVersion() (int, error) { - stagedPlacement, onStagedPlacementDoneFn, err := c.placementWatcher.ActiveStagedPlacement() + stagedPlacement, err := c.placementWatcher.ActiveStagedPlacement() if err != nil { return 0, err } - defer onStagedPlacementDoneFn() if stagedPlacement == nil { return 0, errNilPlacement } @@ -281,17 +278,15 @@ func (c *TCPClient) write( timeNanos int64, payload payloadUnion, ) error { - stagedPlacement, onStagedPlacementDoneFn, err := c.placementWatcher.ActiveStagedPlacement() + stagedPlacement, err := c.placementWatcher.ActiveStagedPlacement() if err != nil { return err } if stagedPlacement == nil { - onStagedPlacementDoneFn() return errNilPlacement } - placement, onPlacementDoneFn, err := stagedPlacement.ActivePlacement() + placement, err := stagedPlacement.ActivePlacement() if err != nil { - onStagedPlacementDoneFn() return err } var ( @@ -327,8 +322,6 @@ func (c *TCPClient) write( c.metrics.dropped.Inc(1) } - onPlacementDoneFn() - onStagedPlacementDoneFn() return multiErr.FinalError() } diff --git a/src/aggregator/client/tcp_client_test.go b/src/aggregator/client/tcp_client_test.go index 3990e7ada3..37ddd34ce6 100644 --- a/src/aggregator/client/tcp_client_test.go +++ b/src/aggregator/client/tcp_client_test.go @@ -222,7 +222,7 @@ func TestTCPClientWriteUntimedMetricActiveStagedPlacementError(t *testing.T) { errActiveStagedPlacementError := errors.New("error active staged placement") watcher := placement.NewMockStagedPlacementWatcher(ctrl) watcher.EXPECT().ActiveStagedPlacement(). - Return(nil, nil, errActiveStagedPlacementError). + Return(nil, errActiveStagedPlacementError). MinTimes(1) c := mustNewTestTCPClient(t, testOptions()) c.placementWatcher = watcher @@ -247,7 +247,7 @@ func TestTCPClientWriteUntimedMetricActiveStagedPlacementNil(t *testing.T) { watcher := placement.NewMockStagedPlacementWatcher(ctrl) watcher.EXPECT().ActiveStagedPlacement(). - Return(nil, func() {}, nil). + Return(nil, nil). MinTimes(1) c := mustNewTestTCPClient(t, testOptions()) c.placementWatcher = watcher @@ -272,9 +272,9 @@ func TestTCPClientWriteUntimedMetricActivePlacementError(t *testing.T) { errActivePlacementError := errors.New("error active placement") stagedPlacement := placement.NewMockActiveStagedPlacement(ctrl) - stagedPlacement.EXPECT().ActivePlacement().Return(nil, nil, errActivePlacementError).MinTimes(1) + stagedPlacement.EXPECT().ActivePlacement().Return(nil, errActivePlacementError).MinTimes(1) watcher := placement.NewMockStagedPlacementWatcher(ctrl) - watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, func() {}, nil).MinTimes(1) + watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, nil).MinTimes(1) c := mustNewTestTCPClient(t, testOptions()) c.placementWatcher = watcher @@ -316,9 +316,9 @@ func TestTCPClientWriteUntimedMetricSuccess(t *testing.T) { }). MinTimes(1) stagedPlacement := placement.NewMockActiveStagedPlacement(ctrl) - stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, func() {}, nil).MinTimes(1) + stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, nil).MinTimes(1) watcher := placement.NewMockStagedPlacementWatcher(ctrl) - watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, func() {}, nil).MinTimes(1) + watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, nil).MinTimes(1) c := mustNewTestTCPClient(t, testOptions()) c.nowFn = func() time.Time { return time.Unix(0, testNowNanos) } c.writerMgr = writerMgr @@ -381,9 +381,9 @@ func TestTCPClientWriteUntimedMetricPartialError(t *testing.T) { }). MinTimes(1) stagedPlacement := placement.NewMockActiveStagedPlacement(ctrl) - stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, func() {}, nil).MinTimes(1) + stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, nil).MinTimes(1) watcher := placement.NewMockStagedPlacementWatcher(ctrl) - watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, func() {}, nil).MinTimes(1) + watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, nil).MinTimes(1) c := mustNewTestTCPClient(t, testOptions()) c.nowFn = func() time.Time { return time.Unix(0, testNowNanos) } c.writerMgr = writerMgr @@ -408,9 +408,9 @@ func TestTCPClientWriteUntimedMetricBeforeShardCutover(t *testing.T) { var instancesRes []placement.Instance stagedPlacement := placement.NewMockActiveStagedPlacement(ctrl) - stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, func() {}, nil).MinTimes(1) + stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, nil).MinTimes(1) watcher := placement.NewMockStagedPlacementWatcher(ctrl) - watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, func() {}, nil).MinTimes(1) + watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, nil).MinTimes(1) c := mustNewTestTCPClient(t, testOptions()) c.shardCutoverWarmupDuration = time.Second c.nowFn = func() time.Time { return time.Unix(0, testCutoverNanos-1).Add(-time.Second) } @@ -428,9 +428,9 @@ func TestTCPClientWriteUntimedMetricAfterShardCutoff(t *testing.T) { var instancesRes []placement.Instance stagedPlacement := placement.NewMockActiveStagedPlacement(ctrl) - stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, func() {}, nil).MinTimes(1) + stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, nil).MinTimes(1) watcher := placement.NewMockStagedPlacementWatcher(ctrl) - watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, func() {}, nil).MinTimes(1) + watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, nil).MinTimes(1) c := mustNewTestTCPClient(t, testOptions()) c.shardCutoffLingerDuration = time.Second c.nowFn = func() time.Time { return time.Unix(0, testCutoffNanos+1).Add(time.Second) } @@ -466,9 +466,9 @@ func TestTCPClientWriteTimedMetricSuccess(t *testing.T) { }). MinTimes(1) stagedPlacement := placement.NewMockActiveStagedPlacement(ctrl) - stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, func() {}, nil).MinTimes(1) + stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, nil).MinTimes(1) watcher := placement.NewMockStagedPlacementWatcher(ctrl) - watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, func() {}, nil).MinTimes(1) + watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, nil).MinTimes(1) c := mustNewTestTCPClient(t, testOptions()) c.nowFn = func() time.Time { return time.Unix(0, testNowNanos) } c.writerMgr = writerMgr @@ -517,9 +517,9 @@ func TestTCPClientWriteTimedMetricPartialError(t *testing.T) { }). MinTimes(1) stagedPlacement := placement.NewMockActiveStagedPlacement(ctrl) - stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, func() {}, nil).MinTimes(1) + stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, nil).MinTimes(1) watcher := placement.NewMockStagedPlacementWatcher(ctrl) - watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, func() {}, nil).MinTimes(1) + watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, nil).MinTimes(1) c := mustNewTestTCPClient(t, testOptions()) c.nowFn = func() time.Time { return time.Unix(0, testNowNanos) } c.writerMgr = writerMgr @@ -564,9 +564,9 @@ func TestTCPClientWriteForwardedMetricSuccess(t *testing.T) { }). MinTimes(1) stagedPlacement := placement.NewMockActiveStagedPlacement(ctrl) - stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, func() {}, nil).MinTimes(1) + stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, nil).MinTimes(1) watcher := placement.NewMockStagedPlacementWatcher(ctrl) - watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, func() {}, nil).MinTimes(1) + watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, nil).MinTimes(1) c := mustNewTestTCPClient(t, testOptions()) c.nowFn = func() time.Time { return time.Unix(0, testNowNanos) } c.writerMgr = writerMgr @@ -615,9 +615,9 @@ func TestTCPClientWriteForwardedMetricPartialError(t *testing.T) { }). MinTimes(1) stagedPlacement := placement.NewMockActiveStagedPlacement(ctrl) - stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, func() {}, nil).MinTimes(1) + stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, nil).MinTimes(1) watcher := placement.NewMockStagedPlacementWatcher(ctrl) - watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, func() {}, nil).MinTimes(1) + watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, nil).MinTimes(1) c := mustNewTestTCPClient(t, testOptions()) c.nowFn = func() time.Time { return time.Unix(0, testNowNanos) } c.writerMgr = writerMgr @@ -662,9 +662,9 @@ func TestTCPClientWritePassthroughMetricSuccess(t *testing.T) { }). MinTimes(1) stagedPlacement := placement.NewMockActiveStagedPlacement(ctrl) - stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, func() {}, nil).MinTimes(1) + stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, nil).MinTimes(1) watcher := placement.NewMockStagedPlacementWatcher(ctrl) - watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, func() {}, nil).MinTimes(1) + watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, nil).MinTimes(1) c := mustNewTestTCPClient(t, testOptions()) c.nowFn = func() time.Time { return time.Unix(0, testNowNanos) } c.writerMgr = writerMgr @@ -713,9 +713,9 @@ func TestTCPClientWritePassthroughMetricPartialError(t *testing.T) { }). MinTimes(1) stagedPlacement := placement.NewMockActiveStagedPlacement(ctrl) - stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, func() {}, nil).MinTimes(1) + stagedPlacement.EXPECT().ActivePlacement().Return(testPlacement, nil).MinTimes(1) watcher := placement.NewMockStagedPlacementWatcher(ctrl) - watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, func() {}, nil).MinTimes(1) + watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, nil).MinTimes(1) c := mustNewTestTCPClient(t, testOptions()) c.nowFn = func() time.Time { return time.Unix(0, testNowNanos) } c.writerMgr = writerMgr @@ -820,25 +820,22 @@ func TestTCPClientActivePlacement(t *testing.T) { mockPl = placement.NewMockPlacement(ctrl) stagedPlacement = placement.NewMockActiveStagedPlacement(ctrl) watcher = placement.NewMockStagedPlacementWatcher(ctrl) - doneCalls int ) c.placementWatcher = watcher - watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, func() { doneCalls++ }, nil).Times(2) + watcher.EXPECT().ActiveStagedPlacement().Return(stagedPlacement, nil).Times(2) stagedPlacement.EXPECT().Version().Return(42).Times(2) - stagedPlacement.EXPECT().ActivePlacement().Return(mockPl, func() { doneCalls++ }, nil) + stagedPlacement.EXPECT().ActivePlacement().Return(mockPl, nil) mockPl.EXPECT().Clone().Return(emptyPl) pl, v, err := c.ActivePlacement() assert.NoError(t, err) assert.Equal(t, 42, v) - assert.Equal(t, 2, doneCalls) assert.Equal(t, emptyPl, pl) v, err = c.ActivePlacementVersion() assert.NoError(t, err) assert.Equal(t, 42, v) - assert.Equal(t, 3, doneCalls) } func TestTCPClientInitAndClose(t *testing.T) { diff --git a/src/cluster/placement/placement_mock.go b/src/cluster/placement/placement_mock.go index 1aff1e3b57..8c38d043e6 100644 --- a/src/cluster/placement/placement_mock.go +++ b/src/cluster/placement/placement_mock.go @@ -1,7 +1,7 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/m3db/m3/src/cluster/placement/types.go -// Copyright (c) 2020 Uber Technologies, Inc. +// Copyright (c) 2021 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -875,13 +875,12 @@ func (mr *MockStagedPlacementWatcherMockRecorder) Watch() *gomock.Call { } // ActiveStagedPlacement mocks base method -func (m *MockStagedPlacementWatcher) ActiveStagedPlacement() (ActiveStagedPlacement, DoneFn, error) { +func (m *MockStagedPlacementWatcher) ActiveStagedPlacement() (ActiveStagedPlacement, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ActiveStagedPlacement") ret0, _ := ret[0].(ActiveStagedPlacement) - ret1, _ := ret[1].(DoneFn) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret1, _ := ret[1].(error) + return ret0, ret1 } // ActiveStagedPlacement indicates an expected call of ActiveStagedPlacement @@ -1119,13 +1118,12 @@ func (m *MockActiveStagedPlacement) EXPECT() *MockActiveStagedPlacementMockRecor } // ActivePlacement mocks base method -func (m *MockActiveStagedPlacement) ActivePlacement() (Placement, DoneFn, error) { +func (m *MockActiveStagedPlacement) ActivePlacement() (Placement, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ActivePlacement") ret0, _ := ret[0].(Placement) - ret1, _ := ret[1].(DoneFn) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret1, _ := ret[1].(error) + return ret0, ret1 } // ActivePlacement indicates an expected call of ActivePlacement diff --git a/src/cluster/placement/staged_placement.go b/src/cluster/placement/staged_placement.go index 070562dfce..3eb4399e3d 100644 --- a/src/cluster/placement/staged_placement.go +++ b/src/cluster/placement/staged_placement.go @@ -23,7 +23,6 @@ package placement import ( "errors" "sort" - "sync" "go.uber.org/atomic" @@ -34,24 +33,23 @@ import ( var ( errNoApplicablePlacement = errors.New("no applicable placement found") errActiveStagedPlacementClosed = errors.New("active staged placement is closed") + errPlacementInvalidType = errors.New("corrupt placement") + + _noPlacements Placements ) type activeStagedPlacement struct { - sync.RWMutex - - placements Placements - version int + closed atomic.Bool + expiring atomic.Int32 nowFn clock.NowFn onPlacementsAddedFn OnPlacementsAddedFn onPlacementsRemovedFn OnPlacementsRemovedFn - - expiring atomic.Int32 - closed bool - doneFn DoneFn + placements atomic.Value + version int } func newActiveStagedPlacement( - placements []Placement, + placements Placements, version int, opts ActiveStagedPlacementOptions, ) *activeStagedPlacement { @@ -59,13 +57,12 @@ func newActiveStagedPlacement( opts = NewActiveStagedPlacementOptions() } p := &activeStagedPlacement{ - placements: placements, version: version, nowFn: opts.ClockOptions().NowFn(), onPlacementsAddedFn: opts.OnPlacementsAddedFn(), onPlacementsRemovedFn: opts.OnPlacementsRemovedFn(), } - p.doneFn = p.onPlacementDone + p.placements.Store(placements) if p.onPlacementsAddedFn != nil { p.onPlacementsAddedFn(placements) @@ -74,27 +71,18 @@ func newActiveStagedPlacement( return p } -func (p *activeStagedPlacement) ActivePlacement() (Placement, DoneFn, error) { - p.RLock() - placement, err := p.activePlacementWithLock(p.nowFn().UnixNano()) - if err != nil { - p.RUnlock() - return nil, nil, err - } - return placement, p.doneFn, nil -} - func (p *activeStagedPlacement) Close() error { - p.Lock() - defer p.Unlock() - - if p.closed { + if !p.closed.CAS(false, true) { return errActiveStagedPlacementClosed } if p.onPlacementsRemovedFn != nil { - p.onPlacementsRemovedFn(p.placements) + pl, ok := p.placements.Load().(Placements) + if ok { + p.onPlacementsRemovedFn(pl) + } } - p.placements = nil + p.placements.Store(_noPlacements) + return nil } @@ -102,18 +90,25 @@ func (p *activeStagedPlacement) Version() int { return p.version } -func (p *activeStagedPlacement) onPlacementDone() { p.RUnlock() } - -func (p *activeStagedPlacement) activePlacementWithLock(timeNanos int64) (Placement, error) { - if p.closed { +func (p *activeStagedPlacement) ActivePlacement() (Placement, error) { + if p.closed.Load() { return nil, errActiveStagedPlacementClosed } - idx := p.placements.ActiveIndex(timeNanos) + + // placements themselves are subject to mutability races, but even in historical design, there was no actual + // need to enforce it, as long as we ensure callback ordering, and that is still the case. + placements, ok := p.placements.Load().(Placements) + if !ok { + return nil, errPlacementInvalidType + } + + idx := placements.ActiveIndex(p.nowFn().UnixNano()) if idx < 0 { return nil, errNoApplicablePlacement } - placement := p.placements[idx] - // If the placement that's in effect is not the first placment, expire the stale ones. + + placement := placements[idx] + // If the placement that's in effect is not the first placement, expire the stale ones. if idx > 0 && p.expiring.CAS(0, 1) { go p.expire() } @@ -123,28 +118,27 @@ func (p *activeStagedPlacement) activePlacementWithLock(timeNanos int64) (Placem func (p *activeStagedPlacement) expire() { // NB(xichen): this improves readability at the slight cost of lambda capture // because this code path is triggered very infrequently. - cleanup := func() { - p.Unlock() - p.expiring.Store(0) + defer p.expiring.Store(0) + + if p.closed.Load() { + return } - p.Lock() - defer cleanup() - if p.closed { + placements, ok := p.placements.Load().(Placements) + if !ok { return } - idx := p.placements.ActiveIndex(p.nowFn().UnixNano()) + + idx := placements.ActiveIndex(p.nowFn().UnixNano()) if idx <= 0 { return } + if p.onPlacementsRemovedFn != nil { - p.onPlacementsRemovedFn(p.placements[:idx]) + p.onPlacementsRemovedFn(placements[:idx]) } - n := copy(p.placements[0:], p.placements[idx:]) - for i := n; i < len(p.placements); i++ { - p.placements[i] = nil - } - p.placements = p.placements[:n] + + p.placements.Store(placements[idx:]) } type stagedPlacement struct { diff --git a/src/cluster/placement/staged_placement_test.go b/src/cluster/placement/staged_placement_test.go index 06a3ce53dd..b7a78fc220 100644 --- a/src/cluster/placement/staged_placement_test.go +++ b/src/cluster/placement/staged_placement_test.go @@ -299,69 +299,72 @@ func TestNewActiveStagedPlacement(t *testing.T) { ) ap := newActiveStagedPlacement(testActivePlacements, 0, opts) require.Equal(t, len(testActivePlacements), len(allInstances)) - require.Equal(t, len(testActivePlacements), len(ap.placements)) + placements := ap.placements.Load().(Placements) //nolint:errcheck + require.Equal(t, len(testActivePlacements), len(placements)) for i := 0; i < len(testActivePlacements); i++ { require.Equal(t, allInstances[i], testActivePlacements[i].Instances()) - validateSnapshot(t, testActivePlacements[i], ap.placements[i]) + validateSnapshot(t, testActivePlacements[i], placements[i]) } } func TestActiveStagedPlacementActivePlacementClosed(t *testing.T) { + pl := append(Placements{}, testActivePlacements...) p := &activeStagedPlacement{ - placements: append([]Placement{}, testActivePlacements...), - nowFn: time.Now, - closed: true, + nowFn: time.Now, } - _, _, err := p.ActivePlacement() + p.placements.Store(pl) + p.closed.Store(true) + + _, err := p.ActivePlacement() require.Equal(t, errActiveStagedPlacementClosed, err) } func TestActiveStagedPlacementNoApplicablePlacementFound(t *testing.T) { + pl := append(Placements{}, testActivePlacements...) p := &activeStagedPlacement{ - placements: append([]Placement{}, testActivePlacements...), - nowFn: func() time.Time { return time.Unix(0, 0) }, + nowFn: func() time.Time { return time.Unix(0, 0) }, } - _, _, err := p.ActivePlacement() + p.placements.Store(pl) + _, err := p.ActivePlacement() require.Equal(t, errNoApplicablePlacement, err) } func TestActiveStagedPlacementActivePlacementFoundWithExpiry(t *testing.T) { var removedInstances [][]Instance + pl := append(Placements{}, testActivePlacements...) p := &activeStagedPlacement{ - placements: append([]Placement{}, testActivePlacements...), - nowFn: func() time.Time { return time.Unix(0, 99999) }, + nowFn: func() time.Time { return time.Unix(0, 99999) }, onPlacementsRemovedFn: func(placements []Placement) { for _, placement := range placements { removedInstances = append(removedInstances, placement.Instances()) } }, } - p.doneFn = p.onPlacementDone - placement, doneFn, err := p.ActivePlacement() + p.placements.Store(pl) + placement, err := p.ActivePlacement() require.NoError(t, err) require.Equal(t, testActivePlacements[1], placement) - doneFn() + var placements Placements for { - p.RLock() - numPlacements := len(p.placements) - p.RUnlock() + placements = p.placements.Load().(Placements) //nolint:errcheck + numPlacements := len(placements) + if numPlacements == 1 { break } time.Sleep(10 * time.Millisecond) } - validateSnapshot(t, testActivePlacements[1], p.placements[0]) + validateSnapshot(t, testActivePlacements[1], placements[0]) require.Equal(t, 1, len(removedInstances)) require.Equal(t, testActivePlacements[0].Instances(), removedInstances[0]) } func TestActiveStagedPlacementCloseAlreadyClosed(t *testing.T) { p := &activeStagedPlacement{ - placements: append([]Placement{}, testActivePlacements...), - nowFn: time.Now, - closed: true, + nowFn: time.Now, } + p.closed.Store(true) require.Equal(t, errActiveStagedPlacementClosed, p.Close()) } @@ -393,16 +396,17 @@ func TestActiveStagedPlacementCloseSuccess(t *testing.T) { func TestActiveStagedPlacementExpireAlreadyClosed(t *testing.T) { var removedInstances [][]Instance + pl := append(Placements{}, testActivePlacements...) p := &activeStagedPlacement{ - placements: append([]Placement{}, testActivePlacements...), - nowFn: func() time.Time { return time.Unix(0, 99999) }, - closed: true, + nowFn: func() time.Time { return time.Unix(0, 99999) }, onPlacementsRemovedFn: func(placements []Placement) { for _, placement := range placements { removedInstances = append(removedInstances, placement.Instances()) } }, } + p.placements.Store(pl) + p.closed.Store(true) p.expiring.Store(1) p.expire() require.Equal(t, int32(0), p.expiring.Load()) @@ -425,7 +429,7 @@ func testActiveStagedPlacementVersionWhileExpiring(t *testing.T) { ranCleanup atomic.Bool ) - p := newActiveStagedPlacement(append([]Placement{}, testActivePlacements...), 42, nil) + p := newActiveStagedPlacement(append(Placements{}, testActivePlacements...), 42, nil) p.nowFn = func() time.Time { return time.Unix(0, testActivePlacements[len(testActivePlacements)-1].CutoverNanos()+1) } @@ -445,10 +449,9 @@ func testActiveStagedPlacementVersionWhileExpiring(t *testing.T) { } }() - pl, doneFn, err := p.ActivePlacement() + pl, err := p.ActivePlacement() require.NoError(t, err) require.NotNil(t, pl) - require.NotNil(t, doneFn) // active placement is not the first in the list - expiration of past // placements must be triggered @@ -461,8 +464,6 @@ func testActiveStagedPlacementVersionWhileExpiring(t *testing.T) { t.Fatalf("test timed out, deadlock?") } - // release placement lock to unblock expiration process - doneFn() select { case <-doneCh: case <-time.After(time.Second): @@ -485,15 +486,16 @@ func testActiveStagedPlacementVersionWhileExpiring(t *testing.T) { func TestActiveStagedPlacementExpireAlreadyExpired(t *testing.T) { var removedInstances [][]Instance + pl := append(Placements{}, testActivePlacements...) p := &activeStagedPlacement{ - placements: append([]Placement{}, testActivePlacements...), - nowFn: func() time.Time { return time.Unix(0, 0) }, + nowFn: func() time.Time { return time.Unix(0, 0) }, onPlacementsRemovedFn: func(placements []Placement) { for _, placement := range placements { removedInstances = append(removedInstances, placement.Instances()) } }, } + p.placements.Store(pl) p.expiring.Store(1) p.expire() require.Equal(t, int32(0), p.expiring.Load()) @@ -502,21 +504,24 @@ func TestActiveStagedPlacementExpireAlreadyExpired(t *testing.T) { func TestActiveStagedPlacementExpireSuccess(t *testing.T) { var removedInstances [][]Instance + pl := append(Placements{}, testActivePlacements...) p := &activeStagedPlacement{ - placements: append([]Placement{}, testActivePlacements...), - nowFn: func() time.Time { return time.Unix(0, 99999) }, + nowFn: func() time.Time { return time.Unix(0, 99999) }, onPlacementsRemovedFn: func(placements []Placement) { for _, placement := range placements { removedInstances = append(removedInstances, placement.Instances()) } }, } + p.placements.Store(pl) p.expiring.Store(1) p.expire() require.Equal(t, int32(0), p.expiring.Load()) require.Equal(t, [][]Instance{testActivePlacements[0].Instances()}, removedInstances) - require.Equal(t, 1, len(p.placements)) - validateSnapshot(t, testActivePlacements[1], p.placements[0]) + + placements := p.placements.Load().(Placements) //nolint:errcheck + require.Equal(t, 1, len(placements)) + validateSnapshot(t, testActivePlacements[1], placements[0]) } func TestStagedPlacementNilProto(t *testing.T) { @@ -565,7 +570,8 @@ func TestStagedPlacementActiveStagedPlacement(t *testing.T) { {t: 99999, placements: pss.placements[1:]}, } { ap := pss.ActiveStagedPlacement(input.t) - require.Equal(t, input.placements, ap.(*activeStagedPlacement).placements) + placements := ap.(*activeStagedPlacement).placements.Load().(Placements) //nolint:errcheck + require.Equal(t, input.placements, placements) } } diff --git a/src/cluster/placement/staged_placement_watcher.go b/src/cluster/placement/staged_placement_watcher.go index 876934b54c..08b2ab5c55 100644 --- a/src/cluster/placement/staged_placement_watcher.go +++ b/src/cluster/placement/staged_placement_watcher.go @@ -24,6 +24,8 @@ import ( "errors" "sync" + "go.uber.org/atomic" + "github.com/m3db/m3/src/cluster/generated/proto/placementpb" "github.com/m3db/m3/src/cluster/kv" "github.com/m3db/m3/src/cluster/kv/util/runtime" @@ -34,26 +36,22 @@ var ( errNilValue = errors.New("nil value received") errPlacementWatcherIsNotWatching = errors.New("placement watcher is not watching") errPlacementWatcherIsWatching = errors.New("placement watcher is watching") -) - -type placementWatcherState int - -const ( - placementWatcherNotWatching placementWatcherState = iota - placementWatcherWatching + errPlacementWatcherCastError = errors.New("interface cast failed, unexpected placement type") ) type stagedPlacementWatcher struct { - sync.RWMutex - runtime.Value - + mtx sync.Mutex nowFn clock.NowFn + placement atomic.Value placementOpts ActiveStagedPlacementOptions - doneFn DoneFn + value runtime.Value + watching atomic.Bool +} - state placementWatcherState - proto *placementpb.PlacementSnapshots - placement ActiveStagedPlacement +// plValue is a wrapper for type-safe interface storage in atomic.Value, +// as the concrete type has to be the same for each .Store() call. +type plValue struct { + p ActiveStagedPlacement } // NewStagedPlacementWatcher creates a new staged placement watcher. @@ -61,9 +59,7 @@ func NewStagedPlacementWatcher(opts StagedPlacementWatcherOptions) StagedPlaceme watcher := &stagedPlacementWatcher{ nowFn: opts.ClockOptions().NowFn(), placementOpts: opts.ActiveStagedPlacementOptions(), - proto: &placementpb.PlacementSnapshots{}, } - watcher.doneFn = watcher.onActiveStagedPlacementDone valueOpts := runtime.NewOptions(). SetInstrumentOptions(opts.InstrumentOptions()). @@ -71,86 +67,81 @@ func NewStagedPlacementWatcher(opts StagedPlacementWatcherOptions) StagedPlaceme SetKVStore(opts.StagedPlacementStore()). SetUnmarshalFn(watcher.toStagedPlacement). SetProcessFn(watcher.process) - watcher.Value = runtime.NewValue(opts.StagedPlacementKey(), valueOpts) + watcher.value = runtime.NewValue(opts.StagedPlacementKey(), valueOpts) return watcher } func (t *stagedPlacementWatcher) Watch() error { - t.Lock() - if t.state != placementWatcherNotWatching { - t.Unlock() + if !t.watching.CAS(false, true) { return errPlacementWatcherIsWatching } - t.state = placementWatcherWatching - t.Unlock() - // NB(xichen): we watch the placementWatcher updates outside the lock because - // otherwise the initial update will trigger the process() callback, - // which attempts to acquire the same lock, causing a deadlock. - return t.Value.Watch() + return t.value.Watch() } -func (t *stagedPlacementWatcher) ActiveStagedPlacement() (ActiveStagedPlacement, DoneFn, error) { - t.RLock() - if t.state != placementWatcherWatching { - t.RUnlock() - return nil, nil, errPlacementWatcherIsNotWatching +func (t *stagedPlacementWatcher) ActiveStagedPlacement() (ActiveStagedPlacement, error) { + if !t.watching.Load() { + return nil, errPlacementWatcherIsNotWatching + } + + pl := t.placement.Load() + placement, ok := pl.(plValue) + + if !ok { + return nil, errPlacementWatcherCastError } - return t.placement, t.doneFn, nil + + return placement.p, nil } func (t *stagedPlacementWatcher) Unwatch() error { - t.Lock() - if t.state != placementWatcherWatching { - t.Unlock() + if !t.watching.CAS(true, false) { return errPlacementWatcherIsNotWatching } - t.state = placementWatcherNotWatching - if t.placement != nil { - t.placement.Close() + + pl := t.placement.Load() + placement, ok := pl.(plValue) + if ok && placement.p != nil { + placement.p.Close() //nolint:errcheck } - t.placement = nil - t.Unlock() - // NB(xichen): we unwatch the updates outside the lock to avoid deadlock - // due to placementWatcher contending for the runtime value lock and the - // runtime updating goroutine attempting to acquire placementWatcher lock. - t.Value.Unwatch() + t.value.Unwatch() return nil } -func (t *stagedPlacementWatcher) onActiveStagedPlacementDone() { t.RUnlock() } - func (t *stagedPlacementWatcher) toStagedPlacement(value kv.Value) (interface{}, error) { - t.Lock() - defer t.Unlock() - - if t.state != placementWatcherWatching { + if !t.watching.Load() { return nil, errPlacementWatcherIsNotWatching } if value == nil { return nil, errNilValue } - t.proto.Reset() - if err := value.Unmarshal(t.proto); err != nil { + + var proto placementpb.PlacementSnapshots + if err := value.Unmarshal(&proto); err != nil { return nil, err } version := value.Version() - return NewStagedPlacementFromProto(version, t.proto, t.placementOpts) + + return NewStagedPlacementFromProto(version, &proto, t.placementOpts) } func (t *stagedPlacementWatcher) process(value interface{}) error { - t.Lock() - defer t.Unlock() + t.mtx.Lock() // serialize value processing + defer t.mtx.Unlock() - if t.state != placementWatcherWatching { + if !t.watching.Load() { return errPlacementWatcherIsNotWatching } ps := value.(StagedPlacement) placement := ps.ActiveStagedPlacement(t.nowFn().UnixNano()) - if t.placement != nil { - t.placement.Close() + + pl := t.placement.Load() + oldPlacement, ok := pl.(plValue) + if ok && oldPlacement.p != nil { + oldPlacement.p.Close() //nolint:errcheck } - t.placement = placement + + t.placement.Store(plValue{p: placement}) return nil } diff --git a/src/cluster/placement/staged_placement_watcher_test.go b/src/cluster/placement/staged_placement_watcher_test.go index bed04b105f..9e567b2207 100644 --- a/src/cluster/placement/staged_placement_watcher_test.go +++ b/src/cluster/placement/staged_placement_watcher_test.go @@ -21,11 +21,13 @@ package placement import ( + "runtime" "testing" "time" "github.com/m3db/m3/src/cluster/kv" "github.com/m3db/m3/src/cluster/kv/mem" + "github.com/m3db/m3/src/x/clock" "github.com/stretchr/testify/require" ) @@ -36,44 +38,40 @@ const ( func TestStagedPlacementWatcherWatchAlreadyWatching(t *testing.T) { watcher, _ := testStagedPlacementWatcher(t) - watcher.state = placementWatcherWatching + watcher.watching.Store(true) require.Equal(t, errPlacementWatcherIsWatching, watcher.Watch()) } func TestStagedPlacementWatcherWatchSuccess(t *testing.T) { watcher, _ := testStagedPlacementWatcher(t) - watcher.state = placementWatcherNotWatching require.NoError(t, watcher.Watch()) } func TestStagedPlacementWatcherActiveStagedPlacementNotWatching(t *testing.T) { watcher, _ := testStagedPlacementWatcher(t) - watcher.state = placementWatcherNotWatching - _, _, err := watcher.ActiveStagedPlacement() + _, err := watcher.ActiveStagedPlacement() require.Equal(t, errPlacementWatcherIsNotWatching, err) } func TestStagedPlacementWatcherActiveStagedPlacementSuccess(t *testing.T) { watcher, _ := testStagedPlacementWatcher(t) - watcher.state = placementWatcherWatching - _, doneFn, err := watcher.ActiveStagedPlacement() + require.NoError(t, watcher.Watch()) + _, err := watcher.ActiveStagedPlacement() require.NoError(t, err) - doneFn() } func TestStagedPlacementWatcherUnwatchNotWatching(t *testing.T) { watcher, _ := testStagedPlacementWatcher(t) - watcher.state = placementWatcherNotWatching require.Equal(t, errPlacementWatcherIsNotWatching, watcher.Unwatch()) } func TestStagedPlacementWatcherUnwatchSuccess(t *testing.T) { watcher, _ := testStagedPlacementWatcher(t) - watcher.state = placementWatcherWatching + watcher.watching.Store(true) require.NoError(t, watcher.Unwatch()) - require.Equal(t, placementWatcherNotWatching, watcher.state) - require.Nil(t, watcher.placement) + require.False(t, watcher.watching.Load()) + require.Nil(t, watcher.placement.Load()) } func TestStagedPlacementWatcherToStagedPlacementNotWatching(t *testing.T) { @@ -84,7 +82,7 @@ func TestStagedPlacementWatcherToStagedPlacementNotWatching(t *testing.T) { func TestStagedPlacementWatcherToPlacementNilValue(t *testing.T) { watcher, _ := testStagedPlacementWatcher(t) - watcher.state = placementWatcherWatching + watcher.watching.Store(true) _, err := watcher.toStagedPlacement(nil) require.Equal(t, errNilValue, err) } @@ -97,7 +95,7 @@ func TestStagedPlacementWatcherToStagedPlacementUnmarshalError(t *testing.T) { func TestStagedPlacementWatcherToStagedPlacementSuccess(t *testing.T) { watcher, store := testStagedPlacementWatcher(t) - watcher.state = placementWatcherWatching + watcher.watching.Store(true) val, err := store.Get(testStagedPlacementKey) require.NoError(t, err) p, err := watcher.toStagedPlacement(val) @@ -129,11 +127,11 @@ func TestStagedPlacementWatcherProcessSuccess(t *testing.T) { pss, err := NewStagedPlacementFromProto(1, testStagedPlacementProto, opts) require.NoError(t, err) watcher, _ := testStagedPlacementWatcher(t) - watcher.state = placementWatcherWatching + watcher.watching.Store(true) watcher.nowFn = func() time.Time { return time.Unix(0, 99999) } - watcher.placement = &mockPlacement{ + watcher.placement.Store(plValue{p: &mockPlacement{ closeFn: func() error { numCloses++; return nil }, - } + }}) require.NoError(t, watcher.process(pss)) require.NotNil(t, watcher.placement) @@ -141,6 +139,53 @@ func TestStagedPlacementWatcherProcessSuccess(t *testing.T) { require.Equal(t, 1, numCloses) } +func BenchmarkStagedPlacementWatcherActiveStagedPlacement(b *testing.B) { + store := mem.NewStore() + _, err := store.SetIfNotExists(testStagedPlacementKey, testStagedPlacementProto) + if err != nil { + b.Fatal(err) + } + + watcherOpts := testStagedPlacementWatcherOptions().SetStagedPlacementStore(store) + watcher := NewStagedPlacementWatcher(watcherOpts) + + w, ok := watcher.(*stagedPlacementWatcher) + if !ok { + b.Fatal("type assertion failed") + } + w.watching.Store(true) + w.placement.Store(plValue{p: newActiveStagedPlacement( + testActivePlacements, + 0, + NewActiveStagedPlacementOptions().SetClockOptions( + clock.NewOptions().SetNowFn(func() time.Time { + return time.Unix(0, testActivePlacements[0].CutoverNanos()) + }), + ), + )}) + + var asp Placement + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + pl, err := watcher.ActiveStagedPlacement() + if err != nil { + b.Fatal(err) + } + + curpl, err := pl.ActivePlacement() + if err != nil { + b.Fatal(err) + } + + asp = curpl + } + }) + + runtime.KeepAlive(asp) +} + func testStagedPlacementWatcher(t *testing.T) (*stagedPlacementWatcher, kv.Store) { store := mem.NewStore() _, err := store.SetIfNotExists(testStagedPlacementKey, testStagedPlacementProto) @@ -164,8 +209,8 @@ type mockPlacement struct { closeFn closeFn } -func (mp *mockPlacement) ActivePlacement() (Placement, DoneFn, error) { - return nil, func() {}, nil +func (mp *mockPlacement) ActivePlacement() (Placement, error) { + return nil, nil } func (mp *mockPlacement) Close() error { return mp.closeFn() } diff --git a/src/cluster/placement/types.go b/src/cluster/placement/types.go index 29bef15dcf..486cbe2fe6 100644 --- a/src/cluster/placement/types.go +++ b/src/cluster/placement/types.go @@ -217,10 +217,9 @@ type StagedPlacementWatcher interface { // Watch starts watching the updates. Watch() error - // ActiveStagedPlacement returns the currently active staged placement, the - // callback function when the caller is done using the active staged placement, + // ActiveStagedPlacement returns the currently active staged placement // and any errors encountered. - ActiveStagedPlacement() (ActiveStagedPlacement, DoneFn, error) + ActiveStagedPlacement() (ActiveStagedPlacement, error) // Unwatch stops watching the updates. Unwatch() error @@ -269,7 +268,7 @@ type StagedPlacementWatcherOptions interface { type ActiveStagedPlacement interface { // ActivePlacement returns the currently active placement for a given time, the callback // function when the caller is done using the placement, and any errors encountered. - ActivePlacement() (Placement, DoneFn, error) + ActivePlacement() (Placement, error) // Version returns the version of the underlying staged placement. Version() int