diff --git a/nomad/csi_endpoint.go b/nomad/csi_endpoint.go index a8695982be2f..36b2c9bf7aa5 100644 --- a/nomad/csi_endpoint.go +++ b/nomad/csi_endpoint.go @@ -647,6 +647,16 @@ func (v *CSIVolume) Unpublish(args *structs.CSIVolumeUnpublishRequest, reply *st claim := args.Claim + // we need to checkpoint when we first get the claim to ensure we've set the + // initial "past claim" state, otherwise a client that unpublishes (skipping + // the node unpublish b/c it's done that work) fail to get written if the + // controller unpublish fails. + vol = vol.Copy() + err = v.checkpointClaim(vol, claim) + if err != nil { + return err + } + // previous checkpoints may have set the past claim state already. // in practice we should never see CSIVolumeClaimStateControllerDetached // but having an option for the state makes it easy to add a checkpoint @@ -693,14 +703,18 @@ RELEASE_CLAIM: func (v *CSIVolume) nodeUnpublishVolume(vol *structs.CSIVolume, claim *structs.CSIVolumeClaim) error { v.logger.Trace("node unpublish", "vol", vol.ID) - store := v.srv.fsm.State() + // We need a new snapshot after each checkpoint + snap, err := v.srv.fsm.State().Snapshot() + if err != nil { + return err + } // If the node has been GC'd or is down, we can't send it a node // unpublish. We need to assume the node has unpublished at its // end. If it hasn't, any controller unpublish will potentially // hang or error and need to be retried. if claim.NodeID != "" { - node, err := store.NodeByID(memdb.NewWatchSet(), claim.NodeID) + node, err := snap.NodeByID(memdb.NewWatchSet(), claim.NodeID) if err != nil { return err } @@ -723,7 +737,7 @@ func (v *CSIVolume) nodeUnpublishVolume(vol *structs.CSIVolume, claim *structs.C // The RPC sent from the 'nomad node detach' command or GC won't have an // allocation ID set so we try to unpublish every terminal or invalid // alloc on the node, all of which will be in PastClaims after denormalizing - vol, err := store.CSIVolumeDenormalize(memdb.NewWatchSet(), vol) + vol, err = snap.CSIVolumeDenormalize(memdb.NewWatchSet(), vol) if err != nil { return err } @@ -793,10 +807,15 @@ func (v *CSIVolume) controllerUnpublishVolume(vol *structs.CSIVolume, claim *str return nil } - state := v.srv.fsm.State() + // We need a new snapshot after each checkpoint + snap, err := v.srv.fsm.State().Snapshot() + if err != nil { + return err + } + ws := memdb.NewWatchSet() - plugin, err := state.CSIPluginByID(ws, vol.PluginID) + plugin, err := snap.CSIPluginByID(ws, vol.PluginID) if err != nil { return fmt.Errorf("could not query plugin: %v", err) } else if plugin == nil { @@ -808,7 +827,7 @@ func (v *CSIVolume) controllerUnpublishVolume(vol *structs.CSIVolume, claim *str return nil } - vol, err = state.CSIVolumeDenormalize(ws, vol) + vol, err = snap.CSIVolumeDenormalize(ws, vol) if err != nil { return err } diff --git a/nomad/csi_endpoint_test.go b/nomad/csi_endpoint_test.go index 9f24fc12c3b4..4a4decf86857 100644 --- a/nomad/csi_endpoint_test.go +++ b/nomad/csi_endpoint_test.go @@ -575,7 +575,7 @@ func TestCSIVolumeEndpoint_Unpublish(t *testing.T) { // setup: create an alloc that will claim our volume alloc := mock.BatchAlloc() alloc.NodeID = tc.nodeID - alloc.ClientStatus = structs.AllocClientStatusFailed + alloc.ClientStatus = structs.AllocClientStatusRunning otherAlloc := mock.BatchAlloc() otherAlloc.NodeID = tc.otherNodeID @@ -585,7 +585,7 @@ func TestCSIVolumeEndpoint_Unpublish(t *testing.T) { must.NoError(t, state.UpsertAllocs(structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc, otherAlloc})) - // setup: claim the volume for our alloc + // setup: claim the volume for our to-be-failed alloc claim := &structs.CSIVolumeClaim{ AllocationID: alloc.ID, NodeID: tc.nodeID, @@ -623,10 +623,19 @@ func TestCSIVolumeEndpoint_Unpublish(t *testing.T) { }, } + alloc = alloc.Copy() + alloc.ClientStatus = structs.AllocClientStatusFailed + index++ + must.NoError(t, state.UpsertAllocs(structs.MsgTypeTestSetup, index, + []*structs.Allocation{alloc})) + err = msgpackrpc.CallWithCodec(codec, "CSIVolume.Unpublish", req, &structs.CSIVolumeUnpublishResponse{}) - vol, volErr := state.CSIVolumeByID(nil, ns, volID) + snap, snapErr := state.Snapshot() + must.NoError(t, snapErr) + + vol, volErr := snap.CSIVolumeByID(nil, ns, volID) must.NoError(t, volErr) must.NotNil(t, vol)