From 9f65c2493f1234ded73664dbb5de0edc77619573 Mon Sep 17 00:00:00 2001 From: Tim Gross Date: Tue, 1 Nov 2022 16:53:10 -0400 Subject: [PATCH] volumewatcher: prevent panic on nil volume (#15101) If a GC claim is written and then volume is deleted before the `volumewatcher` enters its run loop, we panic on the nil-pointer access. Simply doing a nil-check at the top of the loop reveals a race condition around shutting down the loop just as a new update is coming in. Have the parent `volumeswatcher` send an initial update on the channel before returning, so that we're still holding the lock. Update the watcher's `Stop` method to set the running state, which lets us avoid having a second context and makes stopping synchronous. This reduces the cases we have to handle in the run loop. Updated the tests now that we'll safely return from the goroutine and stop the runner in a larger set of cases. Ran the tests with the `-race` detection flag and fixed up any problems found here as well. --- .changelog/15101.txt | 3 + nomad/volumewatcher/volume_watcher.go | 40 ++---- nomad/volumewatcher/volumes_watcher.go | 8 ++ nomad/volumewatcher/volumes_watcher_test.go | 147 +++++++++++++------- 4 files changed, 113 insertions(+), 85 deletions(-) create mode 100644 .changelog/15101.txt diff --git a/.changelog/15101.txt b/.changelog/15101.txt new file mode 100644 index 000000000000..c76126f79181 --- /dev/null +++ b/.changelog/15101.txt @@ -0,0 +1,3 @@ +```release-note:bug +csi: Fixed race condition that can cause a panic when volume is garbage collected +``` diff --git a/nomad/volumewatcher/volume_watcher.go b/nomad/volumewatcher/volume_watcher.go index 82cd76fb8492..f0153e3a9f00 100644 --- a/nomad/volumewatcher/volume_watcher.go +++ b/nomad/volumewatcher/volume_watcher.go @@ -74,7 +74,6 @@ func (vw *volumeWatcher) Notify(v *structs.CSIVolume) { select { case vw.updateCh <- v: case <-vw.shutdownCtx.Done(): // prevent deadlock if we stopped - case <-vw.ctx.Done(): // prevent deadlock if we stopped } } @@ -83,17 +82,14 @@ func (vw *volumeWatcher) Start() { vw.wLock.Lock() defer vw.wLock.Unlock() vw.running = true - ctx, exitFn := context.WithCancel(vw.shutdownCtx) - vw.ctx = ctx - vw.exitFn = exitFn go vw.watch() } -// Stop stops watching the volume. This should be called whenever a -// volume's claims are fully reaped or the watcher is no longer needed. func (vw *volumeWatcher) Stop() { vw.logger.Trace("no more claims") - vw.exitFn() + vw.wLock.Lock() + defer vw.wLock.Unlock() + vw.running = false } func (vw *volumeWatcher) isRunning() bool { @@ -102,8 +98,6 @@ func (vw *volumeWatcher) isRunning() bool { select { case <-vw.shutdownCtx.Done(): return false - case <-vw.ctx.Done(): - return false default: return vw.running } @@ -113,12 +107,8 @@ func (vw *volumeWatcher) isRunning() bool { // Each pass steps the volume's claims through the various states of reaping // until the volume has no more claims eligible to be reaped. func (vw *volumeWatcher) watch() { - // always denormalize the volume and call reap when we first start - // the watcher so that we ensure we don't drop events that - // happened during leadership transitions and didn't get completed - // by the prior leader - vol := vw.getVolume(vw.v) - vw.volumeReap(vol) + defer vw.deleteFn() + defer vw.Stop() timer, stop := helper.NewSafeTimer(vw.quiescentTimeout) defer stop() @@ -129,31 +119,17 @@ func (vw *volumeWatcher) watch() { // context, so we can't stop the long-runner RPCs gracefully case <-vw.shutdownCtx.Done(): return - case <-vw.ctx.Done(): - return case vol := <-vw.updateCh: vol = vw.getVolume(vol) if vol == nil { - // We stop the goroutine whenever we have no more - // work, but only delete the watcher when the volume - // is gone to avoid racing the blocking query - vw.deleteFn() - vw.Stop() return } vw.volumeReap(vol) timer.Reset(vw.quiescentTimeout) case <-timer.C: - // Wait until the volume has "settled" before stopping - // this goroutine so that the race between shutdown and - // the parent goroutine sending on <-updateCh is pushed to - // after the window we most care about quick freeing of - // claims (and the GC job will clean up anything we miss) - vol = vw.getVolume(vol) - if vol == nil { - vw.deleteFn() - } - vw.Stop() + // Wait until the volume has "settled" before stopping this + // goroutine so that we can handle the burst of updates around + // freeing claims without having to spin it back up return } } diff --git a/nomad/volumewatcher/volumes_watcher.go b/nomad/volumewatcher/volumes_watcher.go index 5ae08e1ffb99..7f95ed659d45 100644 --- a/nomad/volumewatcher/volumes_watcher.go +++ b/nomad/volumewatcher/volumes_watcher.go @@ -188,6 +188,14 @@ func (w *Watcher) addLocked(v *structs.CSIVolume) (*volumeWatcher, error) { watcher := newVolumeWatcher(w, v) w.watchers[v.ID+v.Namespace] = watcher + + // Sending the first volume update here before we return ensures we've hit + // the run loop in the goroutine before freeing the lock. This prevents a + // race between shutting down the watcher and the blocking query. + // + // It also ensures that we don't drop events that happened during leadership + // transitions and didn't get completed by the prior leader + watcher.updateCh <- v return watcher, nil } diff --git a/nomad/volumewatcher/volumes_watcher_test.go b/nomad/volumewatcher/volumes_watcher_test.go index b4deda0c7b17..142b83dd4508 100644 --- a/nomad/volumewatcher/volumes_watcher_test.go +++ b/nomad/volumewatcher/volumes_watcher_test.go @@ -10,6 +10,7 @@ import ( "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" + "github.com/shoenig/test/must" "github.com/stretchr/testify/require" ) @@ -17,7 +18,6 @@ import ( // to happen during leader step-up/step-down func TestVolumeWatch_EnableDisable(t *testing.T) { ci.Parallel(t) - require := require.New(t) srv := &MockRPCServer{} srv.state = state.TestStateStore(t) @@ -36,7 +36,7 @@ func TestVolumeWatch_EnableDisable(t *testing.T) { index++ err := srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol}) - require.NoError(err) + require.NoError(t, err) // need to have just enough of a volume and claim in place so that // the watcher doesn't immediately stop and unload itself @@ -46,22 +46,23 @@ func TestVolumeWatch_EnableDisable(t *testing.T) { } index++ err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim) - require.NoError(err) - require.Eventually(func() bool { + require.NoError(t, err) + require.Eventually(t, func() bool { watcher.wlock.RLock() defer watcher.wlock.RUnlock() return 1 == len(watcher.watchers) }, time.Second, 10*time.Millisecond) watcher.SetEnabled(false, nil, "") - require.Equal(0, len(watcher.watchers)) + watcher.wlock.RLock() + defer watcher.wlock.RUnlock() + require.Equal(t, 0, len(watcher.watchers)) } // TestVolumeWatch_LeadershipTransition tests the correct behavior of // claim reaping across leader step-up/step-down func TestVolumeWatch_LeadershipTransition(t *testing.T) { ci.Parallel(t) - require := require.New(t) srv := &MockRPCServer{} srv.state = state.TestStateStore(t) @@ -79,25 +80,25 @@ func TestVolumeWatch_LeadershipTransition(t *testing.T) { index++ err := srv.State().UpsertAllocs(structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc}) - require.NoError(err) + require.NoError(t, err) watcher.SetEnabled(true, srv.State(), "") index++ err = srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol}) - require.NoError(err) + require.NoError(t, err) // we should get or start up a watcher when we get an update for // the volume from the state store - require.Eventually(func() bool { + require.Eventually(t, func() bool { watcher.wlock.RLock() defer watcher.wlock.RUnlock() return 1 == len(watcher.watchers) }, time.Second, 10*time.Millisecond) vol, _ = srv.State().CSIVolumeByID(nil, vol.Namespace, vol.ID) - require.Len(vol.PastClaims, 0, "expected to have 0 PastClaims") - require.Equal(srv.countCSIUnpublish, 0, "expected no CSI.Unpublish RPC calls") + require.Len(t, vol.PastClaims, 0, "expected to have 0 PastClaims") + require.Equal(t, srv.countCSIUnpublish, 0, "expected no CSI.Unpublish RPC calls") // trying to test a dropped watch is racy, so to reliably simulate // this condition, step-down the watcher first and then perform @@ -106,12 +107,14 @@ func TestVolumeWatch_LeadershipTransition(t *testing.T) { // step-down (this is sync) watcher.SetEnabled(false, nil, "") - require.Equal(0, len(watcher.watchers)) + watcher.wlock.RLock() + require.Equal(t, 0, len(watcher.watchers)) + watcher.wlock.RUnlock() // allocation is now invalid index++ err = srv.State().DeleteEval(index, []string{}, []string{alloc.ID}) - require.NoError(err) + require.NoError(t, err) // emit a GC so that we have a volume change that's dropped claim := &structs.CSIVolumeClaim{ @@ -122,7 +125,7 @@ func TestVolumeWatch_LeadershipTransition(t *testing.T) { } index++ err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim) - require.NoError(err) + require.NoError(t, err) // create a new watcher and enable it to simulate the leadership // transition @@ -130,23 +133,21 @@ func TestVolumeWatch_LeadershipTransition(t *testing.T) { watcher.quiescentTimeout = 100 * time.Millisecond watcher.SetEnabled(true, srv.State(), "") - require.Eventually(func() bool { + require.Eventually(t, func() bool { watcher.wlock.RLock() defer watcher.wlock.RUnlock() - return 1 == len(watcher.watchers) && - !watcher.watchers[vol.ID+vol.Namespace].isRunning() + return 0 == len(watcher.watchers) }, time.Second, 10*time.Millisecond) vol, _ = srv.State().CSIVolumeByID(nil, vol.Namespace, vol.ID) - require.Len(vol.PastClaims, 1, "expected to have 1 PastClaim") - require.Equal(srv.countCSIUnpublish, 1, "expected CSI.Unpublish RPC to be called") + require.Len(t, vol.PastClaims, 1, "expected to have 1 PastClaim") + require.Equal(t, srv.countCSIUnpublish, 1, "expected CSI.Unpublish RPC to be called") } // TestVolumeWatch_StartStop tests the start and stop of the watcher when // it receives notifcations and has completed its work func TestVolumeWatch_StartStop(t *testing.T) { ci.Parallel(t) - require := require.New(t) srv := &MockStatefulRPCServer{} srv.state = state.TestStateStore(t) @@ -155,7 +156,7 @@ func TestVolumeWatch_StartStop(t *testing.T) { watcher.quiescentTimeout = 100 * time.Millisecond watcher.SetEnabled(true, srv.State(), "") - require.Equal(0, len(watcher.watchers)) + require.Equal(t, 0, len(watcher.watchers)) plugin := mock.CSIPlugin() node := testNode(plugin, srv.State()) @@ -166,23 +167,22 @@ func TestVolumeWatch_StartStop(t *testing.T) { alloc2.ClientStatus = structs.AllocClientStatusRunning index++ err := srv.State().UpsertJob(structs.MsgTypeTestSetup, index, alloc1.Job) - require.NoError(err) + require.NoError(t, err) index++ err = srv.State().UpsertAllocs(structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc1, alloc2}) - require.NoError(err) + require.NoError(t, err) - // register a volume + // register a volume and an unused volume vol := testVolume(plugin, alloc1, node.ID) index++ err = srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol}) - require.NoError(err) + require.NoError(t, err) // assert we get a watcher; there are no claims so it should immediately stop - require.Eventually(func() bool { + require.Eventually(t, func() bool { watcher.wlock.RLock() defer watcher.wlock.RUnlock() - return 1 == len(watcher.watchers) && - !watcher.watchers[vol.ID+vol.Namespace].isRunning() + return 0 == len(watcher.watchers) }, time.Second*2, 10*time.Millisecond) // claim the volume for both allocs @@ -195,11 +195,11 @@ func TestVolumeWatch_StartStop(t *testing.T) { index++ err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim) - require.NoError(err) + require.NoError(t, err) claim.AllocationID = alloc2.ID index++ err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim) - require.NoError(err) + require.NoError(t, err) // reap the volume and assert nothing has happened claim = &structs.CSIVolumeClaim{ @@ -208,41 +208,88 @@ func TestVolumeWatch_StartStop(t *testing.T) { } index++ err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim) - require.NoError(err) + require.NoError(t, err) ws := memdb.NewWatchSet() vol, _ = srv.State().CSIVolumeByID(ws, vol.Namespace, vol.ID) - require.Equal(2, len(vol.ReadAllocs)) + require.Equal(t, 2, len(vol.ReadAllocs)) // alloc becomes terminal + alloc1 = alloc1.Copy() alloc1.ClientStatus = structs.AllocClientStatusComplete index++ err = srv.State().UpsertAllocs(structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc1}) - require.NoError(err) + require.NoError(t, err) index++ claim.State = structs.CSIVolumeClaimStateReadyToFree err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim) - require.NoError(err) + require.NoError(t, err) - // 1 claim has been released and watcher stops - require.Eventually(func() bool { - ws := memdb.NewWatchSet() - vol, _ := srv.State().CSIVolumeByID(ws, vol.Namespace, vol.ID) - return len(vol.ReadAllocs) == 1 && len(vol.PastClaims) == 0 + // watcher stops and 1 claim has been released + require.Eventually(t, func() bool { + watcher.wlock.RLock() + defer watcher.wlock.RUnlock() + return 0 == len(watcher.watchers) + }, time.Second*5, 10*time.Millisecond) + + vol, _ = srv.State().CSIVolumeByID(ws, vol.Namespace, vol.ID) + must.Eq(t, 1, len(vol.ReadAllocs)) + must.Eq(t, 0, len(vol.PastClaims)) +} + +// TestVolumeWatch_Delete tests the stop of the watcher when it receives +// notifications around a deleted volume +func TestVolumeWatch_Delete(t *testing.T) { + ci.Parallel(t) + + srv := &MockStatefulRPCServer{} + srv.state = state.TestStateStore(t) + index := uint64(100) + watcher := NewVolumesWatcher(testlog.HCLogger(t), srv, "") + watcher.quiescentTimeout = 100 * time.Millisecond + + watcher.SetEnabled(true, srv.State(), "") + must.Eq(t, 0, len(watcher.watchers)) + + // register an unused volume + plugin := mock.CSIPlugin() + vol := mock.CSIVolume(plugin) + index++ + must.NoError(t, srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol})) + + // assert we get a watcher; there are no claims so it should immediately stop + require.Eventually(t, func() bool { + watcher.wlock.RLock() + defer watcher.wlock.RUnlock() + return 0 == len(watcher.watchers) }, time.Second*2, 10*time.Millisecond) - require.Eventually(func() bool { + // write a GC claim to the volume and then immediately delete, to + // potentially hit the race condition between updates and deletes + index++ + must.NoError(t, srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, + &structs.CSIVolumeClaim{ + Mode: structs.CSIVolumeClaimGC, + State: structs.CSIVolumeClaimStateReadyToFree, + })) + + index++ + must.NoError(t, srv.State().CSIVolumeDeregister( + index, vol.Namespace, []string{vol.ID}, false)) + + // the watcher should not be running + require.Eventually(t, func() bool { watcher.wlock.RLock() defer watcher.wlock.RUnlock() - return !watcher.watchers[vol.ID+vol.Namespace].isRunning() + return 0 == len(watcher.watchers) }, time.Second*5, 10*time.Millisecond) + } // TestVolumeWatch_RegisterDeregister tests the start and stop of // watchers around registration func TestVolumeWatch_RegisterDeregister(t *testing.T) { ci.Parallel(t) - require := require.New(t) srv := &MockStatefulRPCServer{} srv.state = state.TestStateStore(t) @@ -253,7 +300,7 @@ func TestVolumeWatch_RegisterDeregister(t *testing.T) { watcher.quiescentTimeout = 10 * time.Millisecond watcher.SetEnabled(true, srv.State(), "") - require.Equal(0, len(watcher.watchers)) + require.Equal(t, 0, len(watcher.watchers)) plugin := mock.CSIPlugin() alloc := mock.Alloc() @@ -263,18 +310,12 @@ func TestVolumeWatch_RegisterDeregister(t *testing.T) { vol := mock.CSIVolume(plugin) index++ err := srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol}) - require.NoError(err) + require.NoError(t, err) - // watcher should be started but immediately stopped - require.Eventually(func() bool { + // watcher should stop + require.Eventually(t, func() bool { watcher.wlock.RLock() defer watcher.wlock.RUnlock() - return 1 == len(watcher.watchers) + return 0 == len(watcher.watchers) }, time.Second, 10*time.Millisecond) - - require.Eventually(func() bool { - watcher.wlock.RLock() - defer watcher.wlock.RUnlock() - return !watcher.watchers[vol.ID+vol.Namespace].isRunning() - }, 1*time.Second, 10*time.Millisecond) }