From 775de0d1c2d31f4f0e23620d404e28a552abe9fe Mon Sep 17 00:00:00 2001 From: Tim Gross Date: Thu, 30 Apr 2020 09:13:00 -0400 Subject: [PATCH] csi: move volume claim release into volumewatcher (#7794) This changeset adds a subsystem to run on the leader, similar to the deployment watcher or node drainer. The `Watcher` performs a blocking query on updates to the `CSIVolumes` table and triggers reaping of volume claims. This will avoid tying up scheduling workers by immediately sending volume claim workloads into their own loop, rather than blocking the scheduling workers in the core GC job doing things like talking to CSI controllers The volume watcher is enabled on leader step-up and disabled on leader step-down. The volume claim GC mechanism now makes an empty claim RPC for the volume to trigger an index bump. That in turn unblocks the blocking query in the volume watcher so it can assess which claims can be released for a volume. --- nomad/core_sched.go | 212 +---------- nomad/core_sched_test.go | 267 -------------- nomad/fsm.go | 34 +- nomad/interfaces.go | 11 - nomad/job_endpoint.go | 8 +- nomad/leader.go | 6 + nomad/node_endpoint.go | 2 +- nomad/node_endpoint_test.go | 2 +- nomad/server.go | 31 ++ nomad/state/state_store.go | 11 +- nomad/structs/csi.go | 4 + nomad/structs/structs.go | 1 + nomad/volumewatcher/batcher.go | 125 +++++++ nomad/volumewatcher/batcher_test.go | 85 +++++ nomad/volumewatcher/interfaces.go | 28 ++ nomad/volumewatcher/interfaces_test.go | 148 ++++++++ nomad/volumewatcher/volume_watcher.go | 382 ++++++++++++++++++++ nomad/volumewatcher/volume_watcher_test.go | 294 +++++++++++++++ nomad/volumewatcher/volumes_watcher.go | 232 ++++++++++++ nomad/volumewatcher/volumes_watcher_test.go | 310 ++++++++++++++++ nomad/volumewatcher_shim.go | 31 ++ 21 files changed, 1721 insertions(+), 503 deletions(-) delete mode 100644 nomad/interfaces.go create mode 100644 nomad/volumewatcher/batcher.go create mode 100644 nomad/volumewatcher/batcher_test.go create mode 100644 nomad/volumewatcher/interfaces.go create mode 100644 nomad/volumewatcher/interfaces_test.go create mode 100644 nomad/volumewatcher/volume_watcher.go create mode 100644 nomad/volumewatcher/volume_watcher_test.go create mode 100644 nomad/volumewatcher/volumes_watcher.go create mode 100644 nomad/volumewatcher/volumes_watcher_test.go create mode 100644 nomad/volumewatcher_shim.go diff --git a/nomad/core_sched.go b/nomad/core_sched.go index 0b1d7e62420f..08f65e5bf823 100644 --- a/nomad/core_sched.go +++ b/nomad/core_sched.go @@ -8,9 +8,7 @@ import ( log "github.com/hashicorp/go-hclog" memdb "github.com/hashicorp/go-memdb" - multierror "github.com/hashicorp/go-multierror" version "github.com/hashicorp/go-version" - cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/scheduler" @@ -711,212 +709,30 @@ func allocGCEligible(a *structs.Allocation, job *structs.Job, gcTime time.Time, return timeDiff > interval.Nanoseconds() } +// TODO: we need a periodic trigger to iterate over all the volumes and split +// them up into separate work items, same as we do for jobs. + // csiVolumeClaimGC is used to garbage collect CSI volume claims func (c *CoreScheduler) csiVolumeClaimGC(eval *structs.Evaluation) error { - c.logger.Trace("garbage collecting unclaimed CSI volume claims") + c.logger.Trace("garbage collecting unclaimed CSI volume claims", "eval.JobID", eval.JobID) // Volume ID smuggled in with the eval's own JobID evalVolID := strings.Split(eval.JobID, ":") - if len(evalVolID) != 3 { + + // COMPAT(1.0): 0.11.0 shipped with 3 fields. tighten this check to len == 2 + if len(evalVolID) < 2 { c.logger.Error("volume gc called without volID") return nil } volID := evalVolID[1] - runningAllocs := evalVolID[2] == "purge" - return volumeClaimReap(c.srv, volID, eval.Namespace, - c.srv.config.Region, eval.LeaderACL, runningAllocs) -} - -func volumeClaimReap(srv RPCServer, volID, namespace, region, leaderACL string, runningAllocs bool) error { - - ws := memdb.NewWatchSet() - - vol, err := srv.State().CSIVolumeByID(ws, namespace, volID) - if err != nil { - return err - } - if vol == nil { - return nil - } - vol, err = srv.State().CSIVolumeDenormalize(ws, vol) - if err != nil { - return err - } - - plug, err := srv.State().CSIPluginByID(ws, vol.PluginID) - if err != nil { - return err - } - - nodeClaims := collectClaimsToGCImpl(vol, runningAllocs) - - var result *multierror.Error - for _, claim := range vol.PastClaims { - nodeClaims, err = volumeClaimReapImpl(srv, - &volumeClaimReapArgs{ - vol: vol, - plug: plug, - claim: claim, - namespace: namespace, - region: region, - leaderACL: leaderACL, - nodeClaims: nodeClaims, - }, - ) - if err != nil { - result = multierror.Append(result, err) - continue - } + req := &structs.CSIVolumeClaimRequest{ + VolumeID: volID, + Claim: structs.CSIVolumeClaimRelease, } - return result.ErrorOrNil() - -} + req.Namespace = eval.Namespace + req.Region = c.srv.config.Region -func collectClaimsToGCImpl(vol *structs.CSIVolume, runningAllocs bool) map[string]int { - nodeClaims := map[string]int{} // node IDs -> count - - collectFunc := func(allocs map[string]*structs.Allocation, - claims map[string]*structs.CSIVolumeClaim) { - - for allocID, alloc := range allocs { - claim, ok := claims[allocID] - if !ok { - // COMPAT(1.0): the CSIVolumeClaim fields were added - // after 0.11.1, so claims made before that may be - // missing this value. note that we'll have non-nil - // allocs here because we called denormalize on the - // value. - claim = &structs.CSIVolumeClaim{ - AllocationID: allocID, - NodeID: alloc.NodeID, - State: structs.CSIVolumeClaimStateTaken, - } - } - nodeClaims[claim.NodeID]++ - if runningAllocs || alloc.Terminated() { - // only overwrite the PastClaim if this is new, - // so that we can track state between subsequent calls - if _, exists := vol.PastClaims[claim.AllocationID]; !exists { - claim.State = structs.CSIVolumeClaimStateTaken - vol.PastClaims[claim.AllocationID] = claim - } - } - } - } - - collectFunc(vol.WriteAllocs, vol.WriteClaims) - collectFunc(vol.ReadAllocs, vol.ReadClaims) - return nodeClaims -} - -type volumeClaimReapArgs struct { - vol *structs.CSIVolume - plug *structs.CSIPlugin - claim *structs.CSIVolumeClaim - region string - namespace string - leaderACL string - nodeClaims map[string]int // node IDs -> count -} - -func volumeClaimReapImpl(srv RPCServer, args *volumeClaimReapArgs) (map[string]int, error) { - vol := args.vol - claim := args.claim - - var err error - var nReq *cstructs.ClientCSINodeDetachVolumeRequest - - checkpoint := func(claimState structs.CSIVolumeClaimState) error { - req := &structs.CSIVolumeClaimRequest{ - VolumeID: vol.ID, - AllocationID: claim.AllocationID, - Claim: structs.CSIVolumeClaimRelease, - WriteRequest: structs.WriteRequest{ - Region: args.region, - Namespace: args.namespace, - AuthToken: args.leaderACL, - }, - } - return srv.RPC("CSIVolume.Claim", req, &structs.CSIVolumeClaimResponse{}) - } - - // previous checkpoints may have set the past claim state already. - // in practice we should never see CSIVolumeClaimStateControllerDetached - // but having an option for the state makes it easy to add a checkpoint - // in a backwards compatible way if we need one later - switch claim.State { - case structs.CSIVolumeClaimStateNodeDetached: - goto NODE_DETACHED - case structs.CSIVolumeClaimStateControllerDetached: - goto RELEASE_CLAIM - case structs.CSIVolumeClaimStateReadyToFree: - goto RELEASE_CLAIM - } - - // (1) NodePublish / NodeUnstage must be completed before controller - // operations or releasing the claim. - nReq = &cstructs.ClientCSINodeDetachVolumeRequest{ - PluginID: args.plug.ID, - VolumeID: vol.ID, - ExternalID: vol.RemoteID(), - AllocID: claim.AllocationID, - NodeID: claim.NodeID, - AttachmentMode: vol.AttachmentMode, - AccessMode: vol.AccessMode, - ReadOnly: claim.Mode == structs.CSIVolumeClaimRead, - } - err = srv.RPC("ClientCSI.NodeDetachVolume", nReq, - &cstructs.ClientCSINodeDetachVolumeResponse{}) - if err != nil { - return args.nodeClaims, err - } - err = checkpoint(structs.CSIVolumeClaimStateNodeDetached) - if err != nil { - return args.nodeClaims, err - } - -NODE_DETACHED: - args.nodeClaims[claim.NodeID]-- - - // (2) we only emit the controller unpublish if no other allocs - // on the node need it, but we also only want to make this - // call at most once per node - if vol.ControllerRequired && args.nodeClaims[claim.NodeID] < 1 { - - // we need to get the CSI Node ID, which is not the same as - // the Nomad Node ID - ws := memdb.NewWatchSet() - targetNode, err := srv.State().NodeByID(ws, claim.NodeID) - if err != nil { - return args.nodeClaims, err - } - if targetNode == nil { - return args.nodeClaims, fmt.Errorf("%s: %s", - structs.ErrUnknownNodePrefix, claim.NodeID) - } - targetCSIInfo, ok := targetNode.CSINodePlugins[args.plug.ID] - if !ok { - return args.nodeClaims, fmt.Errorf("Failed to find NodeInfo for node: %s", targetNode.ID) - } - - cReq := &cstructs.ClientCSIControllerDetachVolumeRequest{ - VolumeID: vol.RemoteID(), - ClientCSINodeID: targetCSIInfo.NodeInfo.ID, - } - cReq.PluginID = args.plug.ID - err = srv.RPC("ClientCSI.ControllerDetachVolume", cReq, - &cstructs.ClientCSIControllerDetachVolumeResponse{}) - if err != nil { - return args.nodeClaims, err - } - } - -RELEASE_CLAIM: - // (3) release the claim from the state store, allowing it to be rescheduled - err = checkpoint(structs.CSIVolumeClaimStateReadyToFree) - if err != nil { - return args.nodeClaims, err - } - return args.nodeClaims, nil + err := c.srv.RPC("CSIVolume.Claim", req, &structs.CSIVolumeClaimResponse{}) + return err } diff --git a/nomad/core_sched_test.go b/nomad/core_sched_test.go index 819e0908ddbc..70b500a82bcf 100644 --- a/nomad/core_sched_test.go +++ b/nomad/core_sched_test.go @@ -6,10 +6,8 @@ import ( "time" memdb "github.com/hashicorp/go-memdb" - cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/nomad/mock" - "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/testutil" "github.com/stretchr/testify/assert" @@ -2195,268 +2193,3 @@ func TestAllocation_GCEligible(t *testing.T) { alloc.ClientStatus = structs.AllocClientStatusComplete require.True(allocGCEligible(alloc, nil, time.Now(), 1000)) } - -func TestCSI_GCVolumeClaims_Collection(t *testing.T) { - t.Parallel() - srv, shutdownSrv := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) - defer shutdownSrv() - testutil.WaitForLeader(t, srv.RPC) - - state := srv.fsm.State() - ws := memdb.NewWatchSet() - index := uint64(100) - - // Create a client node, plugin, and volume - node := mock.Node() - node.Attributes["nomad.version"] = "0.11.0" // client RPCs not supported on early version - node.CSINodePlugins = map[string]*structs.CSIInfo{ - "csi-plugin-example": { - PluginID: "csi-plugin-example", - Healthy: true, - RequiresControllerPlugin: true, - NodeInfo: &structs.CSINodeInfo{}, - }, - } - node.CSIControllerPlugins = map[string]*structs.CSIInfo{ - "csi-plugin-example": { - PluginID: "csi-plugin-example", - Healthy: true, - RequiresControllerPlugin: true, - ControllerInfo: &structs.CSIControllerInfo{ - SupportsReadOnlyAttach: true, - SupportsAttachDetach: true, - SupportsListVolumes: true, - SupportsListVolumesAttachedNodes: false, - }, - }, - } - err := state.UpsertNode(99, node) - require.NoError(t, err) - volId0 := uuid.Generate() - ns := structs.DefaultNamespace - vols := []*structs.CSIVolume{{ - ID: volId0, - Namespace: ns, - PluginID: "csi-plugin-example", - AccessMode: structs.CSIVolumeAccessModeMultiNodeSingleWriter, - AttachmentMode: structs.CSIVolumeAttachmentModeFilesystem, - }} - - err = state.CSIVolumeRegister(index, vols) - index++ - require.NoError(t, err) - vol, err := state.CSIVolumeByID(ws, ns, volId0) - - require.NoError(t, err) - require.True(t, vol.ControllerRequired) - require.Len(t, vol.ReadAllocs, 0) - require.Len(t, vol.WriteAllocs, 0) - - // Create a job with 2 allocations - job := mock.Job() - job.TaskGroups[0].Volumes = map[string]*structs.VolumeRequest{ - "_": { - Name: "someVolume", - Type: structs.VolumeTypeCSI, - Source: volId0, - ReadOnly: false, - }, - } - err = state.UpsertJob(index, job) - index++ - require.NoError(t, err) - - alloc1 := mock.Alloc() - alloc1.JobID = job.ID - alloc1.NodeID = node.ID - err = state.UpsertJobSummary(index, mock.JobSummary(alloc1.JobID)) - index++ - require.NoError(t, err) - alloc1.TaskGroup = job.TaskGroups[0].Name - - alloc2 := mock.Alloc() - alloc2.JobID = job.ID - alloc2.NodeID = node.ID - err = state.UpsertJobSummary(index, mock.JobSummary(alloc2.JobID)) - index++ - require.NoError(t, err) - alloc2.TaskGroup = job.TaskGroups[0].Name - - err = state.UpsertAllocs(104, []*structs.Allocation{alloc1, alloc2}) - require.NoError(t, err) - - // Claim the volumes and verify the claims were set - err = state.CSIVolumeClaim(index, ns, volId0, &structs.CSIVolumeClaim{ - AllocationID: alloc1.ID, - NodeID: alloc1.NodeID, - Mode: structs.CSIVolumeClaimWrite, - }) - index++ - require.NoError(t, err) - - err = state.CSIVolumeClaim(index, ns, volId0, &structs.CSIVolumeClaim{ - AllocationID: alloc2.ID, - NodeID: alloc2.NodeID, - Mode: structs.CSIVolumeClaimRead, - }) - index++ - require.NoError(t, err) - - vol, err = state.CSIVolumeByID(ws, ns, volId0) - require.NoError(t, err) - require.Len(t, vol.ReadAllocs, 1) - require.Len(t, vol.WriteAllocs, 1) - - // Update both allocs as failed/terminated - alloc1.ClientStatus = structs.AllocClientStatusFailed - alloc2.ClientStatus = structs.AllocClientStatusFailed - err = state.UpdateAllocsFromClient(index, []*structs.Allocation{alloc1, alloc2}) - require.NoError(t, err) - - vol, err = state.CSIVolumeDenormalize(ws, vol) - require.NoError(t, err) - - nodeClaims := collectClaimsToGCImpl(vol, false) - require.Equal(t, nodeClaims[node.ID], 2) - require.Len(t, vol.PastClaims, 2) -} - -func TestCSI_GCVolumeClaims_Reap(t *testing.T) { - t.Parallel() - require := require.New(t) - - s, shutdownSrv := TestServer(t, func(c *Config) { c.NumSchedulers = 0 }) - defer shutdownSrv() - testutil.WaitForLeader(t, s.RPC) - - node := mock.Node() - plugin := mock.CSIPlugin() - vol := mock.CSIVolume(plugin) - alloc := mock.Alloc() - - cases := []struct { - Name string - ClaimsCount map[string]int - ControllerRequired bool - ExpectedErr string - ExpectedCount int - ExpectedClaimsCount int - ExpectedNodeDetachVolumeCount int - ExpectedControllerDetachVolumeCount int - ExpectedVolumeClaimCount int - srv *MockRPCServer - }{ - { - Name: "NodeDetachVolume fails", - ClaimsCount: map[string]int{node.ID: 1}, - ControllerRequired: true, - ExpectedErr: "node plugin missing", - ExpectedClaimsCount: 1, - ExpectedNodeDetachVolumeCount: 1, - srv: &MockRPCServer{ - state: s.State(), - nextCSINodeDetachVolumeError: fmt.Errorf("node plugin missing"), - }, - }, - { - Name: "ControllerDetachVolume no controllers", - ClaimsCount: map[string]int{node.ID: 1}, - ControllerRequired: true, - ExpectedErr: fmt.Sprintf("Unknown node: %s", node.ID), - ExpectedClaimsCount: 0, - ExpectedNodeDetachVolumeCount: 1, - ExpectedControllerDetachVolumeCount: 0, - ExpectedVolumeClaimCount: 1, - srv: &MockRPCServer{ - state: s.State(), - }, - }, - { - Name: "ControllerDetachVolume node-only", - ClaimsCount: map[string]int{node.ID: 1}, - ControllerRequired: false, - ExpectedClaimsCount: 0, - ExpectedNodeDetachVolumeCount: 1, - ExpectedControllerDetachVolumeCount: 0, - ExpectedVolumeClaimCount: 2, - srv: &MockRPCServer{ - state: s.State(), - }, - }, - } - - for _, tc := range cases { - t.Run(tc.Name, func(t *testing.T) { - vol.ControllerRequired = tc.ControllerRequired - claim := &structs.CSIVolumeClaim{ - AllocationID: alloc.ID, - NodeID: node.ID, - State: structs.CSIVolumeClaimStateTaken, - Mode: structs.CSIVolumeClaimRead, - } - nodeClaims, err := volumeClaimReapImpl(tc.srv, &volumeClaimReapArgs{ - vol: vol, - plug: plugin, - claim: claim, - region: "global", - namespace: "default", - leaderACL: "not-in-use", - nodeClaims: tc.ClaimsCount, - }) - if tc.ExpectedErr != "" { - require.EqualError(err, tc.ExpectedErr) - } else { - require.NoError(err) - } - require.Equal(tc.ExpectedClaimsCount, - nodeClaims[claim.NodeID], "expected claims remaining") - require.Equal(tc.ExpectedNodeDetachVolumeCount, - tc.srv.countCSINodeDetachVolume, "node detach RPC count") - require.Equal(tc.ExpectedControllerDetachVolumeCount, - tc.srv.countCSIControllerDetachVolume, "controller detach RPC count") - require.Equal(tc.ExpectedVolumeClaimCount, - tc.srv.countCSIVolumeClaim, "volume claim RPC count") - }) - } -} - -type MockRPCServer struct { - state *state.StateStore - - // mock responses for ClientCSI.NodeDetachVolume - nextCSINodeDetachVolumeResponse *cstructs.ClientCSINodeDetachVolumeResponse - nextCSINodeDetachVolumeError error - countCSINodeDetachVolume int - - // mock responses for ClientCSI.ControllerDetachVolume - nextCSIControllerDetachVolumeResponse *cstructs.ClientCSIControllerDetachVolumeResponse - nextCSIControllerDetachVolumeError error - countCSIControllerDetachVolume int - - // mock responses for CSI.VolumeClaim - nextCSIVolumeClaimResponse *structs.CSIVolumeClaimResponse - nextCSIVolumeClaimError error - countCSIVolumeClaim int -} - -func (srv *MockRPCServer) RPC(method string, args interface{}, reply interface{}) error { - switch method { - case "ClientCSI.NodeDetachVolume": - reply = srv.nextCSINodeDetachVolumeResponse - srv.countCSINodeDetachVolume++ - return srv.nextCSINodeDetachVolumeError - case "ClientCSI.ControllerDetachVolume": - reply = srv.nextCSIControllerDetachVolumeResponse - srv.countCSIControllerDetachVolume++ - return srv.nextCSIControllerDetachVolumeError - case "CSIVolume.Claim": - reply = srv.nextCSIVolumeClaimResponse - srv.countCSIVolumeClaim++ - return srv.nextCSIVolumeClaimError - default: - return fmt.Errorf("unexpected method %q passed to mock", method) - } - -} - -func (srv *MockRPCServer) State() *state.StateStore { return srv.state } diff --git a/nomad/fsm.go b/nomad/fsm.go index 9ec1ef086510..b9f412393dd7 100644 --- a/nomad/fsm.go +++ b/nomad/fsm.go @@ -270,6 +270,8 @@ func (n *nomadFSM) Apply(log *raft.Log) interface{} { return n.applyCSIVolumeDeregister(buf[1:], log.Index) case structs.CSIVolumeClaimRequestType: return n.applyCSIVolumeClaim(buf[1:], log.Index) + case structs.CSIVolumeClaimBatchRequestType: + return n.applyCSIVolumeBatchClaim(buf[1:], log.Index) case structs.ScalingEventRegisterRequestType: return n.applyUpsertScalingEvent(buf[1:], log.Index) } @@ -1156,33 +1158,35 @@ func (n *nomadFSM) applyCSIVolumeDeregister(buf []byte, index uint64) interface{ return nil } -func (n *nomadFSM) applyCSIVolumeClaim(buf []byte, index uint64) interface{} { - var req structs.CSIVolumeClaimRequest - if err := structs.Decode(buf, &req); err != nil { +func (n *nomadFSM) applyCSIVolumeBatchClaim(buf []byte, index uint64) interface{} { + var batch *structs.CSIVolumeClaimBatchRequest + if err := structs.Decode(buf, &batch); err != nil { panic(fmt.Errorf("failed to decode request: %v", err)) } - defer metrics.MeasureSince([]string{"nomad", "fsm", "apply_csi_volume_claim"}, time.Now()) + defer metrics.MeasureSince([]string{"nomad", "fsm", "apply_csi_volume_batch_claim"}, time.Now()) - ws := memdb.NewWatchSet() - alloc, err := n.state.AllocByID(ws, req.AllocationID) - if err != nil { - n.logger.Error("AllocByID failed", "error", err) - return err - } - if alloc == nil { - n.logger.Error("AllocByID failed to find alloc", "alloc_id", req.AllocationID) + for _, req := range batch.Claims { + err := n.state.CSIVolumeClaim(index, req.RequestNamespace(), + req.VolumeID, req.ToClaim()) if err != nil { - return err + n.logger.Error("CSIVolumeClaim for batch failed", "error", err) + return err // note: fails the remaining batch } + } + return nil +} - return structs.ErrUnknownAllocationPrefix +func (n *nomadFSM) applyCSIVolumeClaim(buf []byte, index uint64) interface{} { + var req structs.CSIVolumeClaimRequest + if err := structs.Decode(buf, &req); err != nil { + panic(fmt.Errorf("failed to decode request: %v", err)) } + defer metrics.MeasureSince([]string{"nomad", "fsm", "apply_csi_volume_claim"}, time.Now()) if err := n.state.CSIVolumeClaim(index, req.RequestNamespace(), req.VolumeID, req.ToClaim()); err != nil { n.logger.Error("CSIVolumeClaim failed", "error", err) return err } - return nil } diff --git a/nomad/interfaces.go b/nomad/interfaces.go deleted file mode 100644 index 4dc266d8b808..000000000000 --- a/nomad/interfaces.go +++ /dev/null @@ -1,11 +0,0 @@ -package nomad - -import "github.com/hashicorp/nomad/nomad/state" - -// RPCServer is a minimal interface of the Server, intended as -// an aid for testing logic surrounding server-to-server or -// server-to-client RPC calls -type RPCServer interface { - RPC(method string, args interface{}, reply interface{}) error - State() *state.StateStore -} diff --git a/nomad/job_endpoint.go b/nomad/job_endpoint.go index 7d4e2cce905d..a564efb33412 100644 --- a/nomad/job_endpoint.go +++ b/nomad/job_endpoint.go @@ -737,19 +737,13 @@ func (j *Job) Deregister(args *structs.JobDeregisterRequest, reply *structs.JobD for _, vol := range volumesToGC { // we have to build this eval by hand rather than calling srv.CoreJob // here because we need to use the volume's namespace - - runningAllocs := ":ok" - if args.Purge { - runningAllocs = ":purge" - } - eval := &structs.Evaluation{ ID: uuid.Generate(), Namespace: job.Namespace, Priority: structs.CoreJobPriority, Type: structs.JobTypeCore, TriggeredBy: structs.EvalTriggerAllocStop, - JobID: structs.CoreJobCSIVolumeClaimGC + ":" + vol.Source + runningAllocs, + JobID: structs.CoreJobCSIVolumeClaimGC + ":" + vol.Source, LeaderACL: j.srv.getLeaderAcl(), Status: structs.EvalStatusPending, CreateTime: now, diff --git a/nomad/leader.go b/nomad/leader.go index b43d4abd2f53..29550dc79925 100644 --- a/nomad/leader.go +++ b/nomad/leader.go @@ -241,6 +241,9 @@ func (s *Server) establishLeadership(stopCh chan struct{}) error { // Enable the NodeDrainer s.nodeDrainer.SetEnabled(true, s.State()) + // Enable the volume watcher, since we are now the leader + s.volumeWatcher.SetEnabled(true, s.State()) + // Restore the eval broker state if err := s.restoreEvals(); err != nil { return err @@ -870,6 +873,9 @@ func (s *Server) revokeLeadership() error { // Disable the node drainer s.nodeDrainer.SetEnabled(false, nil) + // Disable the volume watcher + s.volumeWatcher.SetEnabled(false, nil) + // Disable any enterprise systems required. if err := s.revokeEnterpriseLeadership(); err != nil { return err diff --git a/nomad/node_endpoint.go b/nomad/node_endpoint.go index fcfbcfcc2326..7308c00aa63c 100644 --- a/nomad/node_endpoint.go +++ b/nomad/node_endpoint.go @@ -1149,7 +1149,7 @@ func (n *Node) UpdateAlloc(args *structs.AllocUpdateRequest, reply *structs.Gene Priority: structs.CoreJobPriority, Type: structs.JobTypeCore, TriggeredBy: structs.EvalTriggerAllocStop, - JobID: structs.CoreJobCSIVolumeClaimGC + ":" + volAndNamespace[0] + ":no", + JobID: structs.CoreJobCSIVolumeClaimGC + ":" + volAndNamespace[0], LeaderACL: n.srv.getLeaderAcl(), Status: structs.EvalStatusPending, CreateTime: now.UTC().UnixNano(), diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index c1d54ebc8491..e687614409a6 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -2414,7 +2414,7 @@ func TestClientEndpoint_UpdateAlloc_UnclaimVolumes(t *testing.T) { // Verify the eval for the claim GC was emitted // Lookup the evaluations - eval, err := state.EvalsByJob(ws, job.Namespace, structs.CoreJobCSIVolumeClaimGC+":"+volId0+":no") + eval, err := state.EvalsByJob(ws, job.Namespace, structs.CoreJobCSIVolumeClaimGC+":"+volId0) require.NotNil(t, eval) require.Nil(t, err) } diff --git a/nomad/server.go b/nomad/server.go index d691faeda801..8a1353f985d9 100644 --- a/nomad/server.go +++ b/nomad/server.go @@ -35,6 +35,7 @@ import ( "github.com/hashicorp/nomad/nomad/state" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs/config" + "github.com/hashicorp/nomad/nomad/volumewatcher" "github.com/hashicorp/nomad/scheduler" "github.com/hashicorp/raft" raftboltdb "github.com/hashicorp/raft-boltdb" @@ -186,6 +187,9 @@ type Server struct { // nodeDrainer is used to drain allocations from nodes. nodeDrainer *drainer.NodeDrainer + // volumeWatcher is used to release volume claims + volumeWatcher *volumewatcher.Watcher + // evalBroker is used to manage the in-progress evaluations // that are waiting to be brokered to a sub-scheduler evalBroker *EvalBroker @@ -399,6 +403,12 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, consulACLs consu return nil, fmt.Errorf("failed to create deployment watcher: %v", err) } + // Setup the volume watcher + if err := s.setupVolumeWatcher(); err != nil { + s.logger.Error("failed to create volume watcher", "error", err) + return nil, fmt.Errorf("failed to create volume watcher: %v", err) + } + // Setup the node drainer. s.setupNodeDrainer() @@ -993,6 +1003,27 @@ func (s *Server) setupDeploymentWatcher() error { return nil } +// setupVolumeWatcher creates a volume watcher that consumes the RPC +// endpoints for state information and makes transitions via Raft through a +// shim that provides the appropriate methods. +func (s *Server) setupVolumeWatcher() error { + + // Create the raft shim type to restrict the set of raft methods that can be + // made + raftShim := &volumeWatcherRaftShim{ + apply: s.raftApply, + } + + // Create the volume watcher + s.volumeWatcher = volumewatcher.NewVolumesWatcher( + s.logger, raftShim, + s.staticEndpoints.ClientCSI, + volumewatcher.LimitStateQueriesPerSecond, + volumewatcher.CrossVolumeUpdateBatchDuration) + + return nil +} + // setupNodeDrainer creates a node drainer which will be enabled when a server // becomes a leader. func (s *Server) setupNodeDrainer() { diff --git a/nomad/state/state_store.go b/nomad/state/state_store.go index 00d418a7249f..623ea55c795c 100644 --- a/nomad/state/state_store.go +++ b/nomad/state/state_store.go @@ -2068,9 +2068,14 @@ func (s *StateStore) CSIVolumeClaim(index uint64, namespace, id string, claim *s return err } - err = volume.Claim(claim, alloc) - if err != nil { - return err + // in the case of a job deregistration, there will be no allocation ID + // for the claim but we still want to write an updated index to the volume + // so that volume reaping is triggered + if claim.AllocationID != "" { + err = volume.Claim(claim, alloc) + if err != nil { + return err + } } volume.ModifyIndex = index diff --git a/nomad/structs/csi.go b/nomad/structs/csi.go index bea3439ea685..5428f89ef691 100644 --- a/nomad/structs/csi.go +++ b/nomad/structs/csi.go @@ -575,6 +575,10 @@ const ( CSIVolumeClaimRelease ) +type CSIVolumeClaimBatchRequest struct { + Claims []CSIVolumeClaimRequest +} + type CSIVolumeClaimRequest struct { VolumeID string AllocationID string diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index ce36d15fb41e..8f3eb060f76f 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -90,6 +90,7 @@ const ( CSIVolumeRegisterRequestType CSIVolumeDeregisterRequestType CSIVolumeClaimRequestType + CSIVolumeClaimBatchRequestType ScalingEventRegisterRequestType ) diff --git a/nomad/volumewatcher/batcher.go b/nomad/volumewatcher/batcher.go new file mode 100644 index 000000000000..a67ef1bb8d50 --- /dev/null +++ b/nomad/volumewatcher/batcher.go @@ -0,0 +1,125 @@ +package volumewatcher + +import ( + "context" + "time" + + "github.com/hashicorp/nomad/nomad/structs" +) + +// VolumeUpdateBatcher is used to batch the updates for volume claims +type VolumeUpdateBatcher struct { + // batch is the batching duration + batch time.Duration + + // raft is used to actually commit the updates + raft VolumeRaftEndpoints + + // workCh is used to pass evaluations to the daemon process + workCh chan *updateWrapper + + // ctx is used to exit the daemon batcher + ctx context.Context +} + +// NewVolumeUpdateBatcher returns an VolumeUpdateBatcher that uses the +// passed raft endpoints to create the updates to volume claims, and +// exits the batcher when the passed exit channel is closed. +func NewVolumeUpdateBatcher(batchDuration time.Duration, raft VolumeRaftEndpoints, ctx context.Context) *VolumeUpdateBatcher { + b := &VolumeUpdateBatcher{ + batch: batchDuration, + raft: raft, + ctx: ctx, + workCh: make(chan *updateWrapper, 10), + } + + go b.batcher() + return b +} + +// CreateUpdate batches the volume claim update and returns a future +// that tracks the completion of the request. +func (b *VolumeUpdateBatcher) CreateUpdate(claims []structs.CSIVolumeClaimRequest) *BatchFuture { + wrapper := &updateWrapper{ + claims: claims, + f: make(chan *BatchFuture, 1), + } + + b.workCh <- wrapper + return <-wrapper.f +} + +type updateWrapper struct { + claims []structs.CSIVolumeClaimRequest + f chan *BatchFuture +} + +// batcher is the long lived batcher goroutine +func (b *VolumeUpdateBatcher) batcher() { + var timerCh <-chan time.Time + claims := make(map[string]structs.CSIVolumeClaimRequest) + future := NewBatchFuture() + for { + select { + case <-b.ctx.Done(): + // note: we can't flush here because we're likely no + // longer the leader + return + case w := <-b.workCh: + if timerCh == nil { + timerCh = time.After(b.batch) + } + + // de-dupe and store the claim update, and attach the future + for _, upd := range w.claims { + claims[upd.VolumeID+upd.RequestNamespace()] = upd + } + w.f <- future + case <-timerCh: + // Capture the future and create a new one + f := future + future = NewBatchFuture() + + // Create the batch request + req := structs.CSIVolumeClaimBatchRequest{} + for _, claim := range claims { + req.Claims = append(req.Claims, claim) + } + + // Upsert the claims in a go routine + go f.Set(b.raft.UpsertVolumeClaims(&req)) + + // Reset the claims list and timer + claims = make(map[string]structs.CSIVolumeClaimRequest) + timerCh = nil + } + } +} + +// BatchFuture is a future that can be used to retrieve the index for +// the update or any error in the update process +type BatchFuture struct { + index uint64 + err error + waitCh chan struct{} +} + +// NewBatchFuture returns a new BatchFuture +func NewBatchFuture() *BatchFuture { + return &BatchFuture{ + waitCh: make(chan struct{}), + } +} + +// Set sets the results of the future, unblocking any client. +func (f *BatchFuture) Set(index uint64, err error) { + f.index = index + f.err = err + close(f.waitCh) +} + +// Results returns the creation index and any error. +func (f *BatchFuture) Results() (uint64, error) { + <-f.waitCh + return f.index, f.err +} diff --git a/nomad/volumewatcher/batcher_test.go b/nomad/volumewatcher/batcher_test.go new file mode 100644 index 000000000000..7f9915d193b7 --- /dev/null +++ b/nomad/volumewatcher/batcher_test.go @@ -0,0 +1,85 @@ +package volumewatcher + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/state" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/stretchr/testify/require" +) + +// TestVolumeWatch_Batcher tests the update batching logic +func TestVolumeWatch_Batcher(t *testing.T) { + t.Parallel() + require := require.New(t) + + ctx, exitFn := context.WithCancel(context.Background()) + defer exitFn() + + srv := &MockBatchingRPCServer{} + srv.state = state.TestStateStore(t) + srv.volumeUpdateBatcher = NewVolumeUpdateBatcher(CrossVolumeUpdateBatchDuration, srv, ctx) + + plugin := mock.CSIPlugin() + node := testNode(nil, plugin, srv.State()) + + // because we wait for the results to return from the batch for each + // Watcher.updateClaims, we can't test that we're batching except across + // multiple volume watchers. create 2 volumes and their watchers here. + alloc0 := mock.Alloc() + alloc0.ClientStatus = structs.AllocClientStatusComplete + vol0 := testVolume(nil, plugin, alloc0, node.ID) + w0 := &volumeWatcher{ + v: vol0, + rpc: srv, + state: srv.State(), + updateClaims: srv.UpdateClaims, + logger: testlog.HCLogger(t), + } + + alloc1 := mock.Alloc() + alloc1.ClientStatus = structs.AllocClientStatusComplete + vol1 := testVolume(nil, plugin, alloc1, node.ID) + w1 := &volumeWatcher{ + v: vol1, + rpc: srv, + state: srv.State(), + updateClaims: srv.UpdateClaims, + logger: testlog.HCLogger(t), + } + + srv.nextCSIControllerDetachError = fmt.Errorf("some controller plugin error") + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + w0.volumeReapImpl(vol0) + wg.Done() + }() + go func() { + w1.volumeReapImpl(vol1) + wg.Done() + }() + + wg.Wait() + + require.Equal(structs.CSIVolumeClaimStateNodeDetached, vol0.PastClaims[alloc0.ID].State) + require.Equal(structs.CSIVolumeClaimStateNodeDetached, vol1.PastClaims[alloc1.ID].State) + require.Equal(2, srv.countCSINodeDetachVolume) + require.Equal(2, srv.countCSIControllerDetachVolume) + require.Equal(2, srv.countUpdateClaims) + + // note: it's technically possible that the volumeReapImpl + // goroutines get de-scheduled and we don't write both updates in + // the same batch. but this seems really unlikely, so we're + // testing for both cases here so that if we start seeing a flake + // here in the future we have a clear cause for it. + require.GreaterOrEqual(srv.countUpsertVolumeClaims, 1) + require.Equal(1, srv.countUpsertVolumeClaims) +} diff --git a/nomad/volumewatcher/interfaces.go b/nomad/volumewatcher/interfaces.go new file mode 100644 index 000000000000..55d82c55b7ce --- /dev/null +++ b/nomad/volumewatcher/interfaces.go @@ -0,0 +1,28 @@ +package volumewatcher + +import ( + cstructs "github.com/hashicorp/nomad/client/structs" + "github.com/hashicorp/nomad/nomad/structs" +) + +// VolumeRaftEndpoints exposes the volume watcher to a set of functions +// to apply data transforms via Raft. +type VolumeRaftEndpoints interface { + + // UpsertVolumeClaims applys a batch of claims to raft + UpsertVolumeClaims(*structs.CSIVolumeClaimBatchRequest) (uint64, error) +} + +// ClientRPC is a minimal interface of the Server, intended as an aid +// for testing logic surrounding server-to-server or server-to-client +// RPC calls and to avoid circular references between the nomad +// package and the volumewatcher +type ClientRPC interface { + ControllerDetachVolume(args *cstructs.ClientCSIControllerDetachVolumeRequest, reply *cstructs.ClientCSIControllerDetachVolumeResponse) error + NodeDetachVolume(args *cstructs.ClientCSINodeDetachVolumeRequest, reply *cstructs.ClientCSINodeDetachVolumeResponse) error +} + +// claimUpdater is the function used to update claims on behalf of a volume +// (used to wrap batch updates so that we can test +// volumeWatcher methods synchronously without batching) +type updateClaimsFn func(claims []structs.CSIVolumeClaimRequest) (uint64, error) diff --git a/nomad/volumewatcher/interfaces_test.go b/nomad/volumewatcher/interfaces_test.go new file mode 100644 index 000000000000..068a76e52de1 --- /dev/null +++ b/nomad/volumewatcher/interfaces_test.go @@ -0,0 +1,148 @@ +package volumewatcher + +import ( + cstructs "github.com/hashicorp/nomad/client/structs" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/state" + "github.com/hashicorp/nomad/nomad/structs" +) + +// Create a client node with plugin info +func testNode(node *structs.Node, plugin *structs.CSIPlugin, s *state.StateStore) *structs.Node { + if node != nil { + return node + } + node = mock.Node() + node.Attributes["nomad.version"] = "0.11.0" // client RPCs not supported on early version + node.CSINodePlugins = map[string]*structs.CSIInfo{ + plugin.ID: { + PluginID: plugin.ID, + Healthy: true, + RequiresControllerPlugin: plugin.ControllerRequired, + NodeInfo: &structs.CSINodeInfo{}, + }, + } + if plugin.ControllerRequired { + node.CSIControllerPlugins = map[string]*structs.CSIInfo{ + plugin.ID: { + PluginID: plugin.ID, + Healthy: true, + RequiresControllerPlugin: true, + ControllerInfo: &structs.CSIControllerInfo{ + SupportsReadOnlyAttach: true, + SupportsAttachDetach: true, + SupportsListVolumes: true, + SupportsListVolumesAttachedNodes: false, + }, + }, + } + } else { + node.CSIControllerPlugins = map[string]*structs.CSIInfo{} + } + s.UpsertNode(99, node) + return node +} + +// Create a test volume with claim info +func testVolume(vol *structs.CSIVolume, plugin *structs.CSIPlugin, alloc *structs.Allocation, nodeID string) *structs.CSIVolume { + if vol != nil { + return vol + } + vol = mock.CSIVolume(plugin) + vol.ControllerRequired = plugin.ControllerRequired + + vol.ReadAllocs = map[string]*structs.Allocation{alloc.ID: alloc} + vol.ReadClaims = map[string]*structs.CSIVolumeClaim{ + alloc.ID: { + AllocationID: alloc.ID, + NodeID: nodeID, + Mode: structs.CSIVolumeClaimRead, + State: structs.CSIVolumeClaimStateTaken, + }, + } + return vol +} + +// COMPAT(1.0): the claim fields were added after 0.11.1; this +// mock and the associated test cases can be removed for 1.0 +func testOldVolume(vol *structs.CSIVolume, plugin *structs.CSIPlugin, alloc *structs.Allocation, nodeID string) *structs.CSIVolume { + if vol != nil { + return vol + } + vol = mock.CSIVolume(plugin) + vol.ControllerRequired = plugin.ControllerRequired + + vol.ReadAllocs = map[string]*structs.Allocation{alloc.ID: alloc} + return vol +} + +type MockRPCServer struct { + state *state.StateStore + + // mock responses for ClientCSI.NodeDetachVolume + nextCSINodeDetachResponse *cstructs.ClientCSINodeDetachVolumeResponse + nextCSINodeDetachError error + countCSINodeDetachVolume int + + // mock responses for ClientCSI.ControllerDetachVolume + nextCSIControllerDetachVolumeResponse *cstructs.ClientCSIControllerDetachVolumeResponse + nextCSIControllerDetachError error + countCSIControllerDetachVolume int + + countUpdateClaims int + countUpsertVolumeClaims int +} + +func (srv *MockRPCServer) ControllerDetachVolume(args *cstructs.ClientCSIControllerDetachVolumeRequest, reply *cstructs.ClientCSIControllerDetachVolumeResponse) error { + reply = srv.nextCSIControllerDetachVolumeResponse + srv.countCSIControllerDetachVolume++ + return srv.nextCSIControllerDetachError +} + +func (srv *MockRPCServer) NodeDetachVolume(args *cstructs.ClientCSINodeDetachVolumeRequest, reply *cstructs.ClientCSINodeDetachVolumeResponse) error { + reply = srv.nextCSINodeDetachResponse + srv.countCSINodeDetachVolume++ + return srv.nextCSINodeDetachError + +} + +func (srv *MockRPCServer) UpsertVolumeClaims(*structs.CSIVolumeClaimBatchRequest) (uint64, error) { + srv.countUpsertVolumeClaims++ + return 0, nil +} + +func (srv *MockRPCServer) State() *state.StateStore { return srv.state } + +func (srv *MockRPCServer) UpdateClaims(claims []structs.CSIVolumeClaimRequest) (uint64, error) { + srv.countUpdateClaims++ + return 0, nil +} + +type MockBatchingRPCServer struct { + MockRPCServer + volumeUpdateBatcher *VolumeUpdateBatcher +} + +func (srv *MockBatchingRPCServer) UpdateClaims(claims []structs.CSIVolumeClaimRequest) (uint64, error) { + srv.countUpdateClaims++ + return srv.volumeUpdateBatcher.CreateUpdate(claims).Results() +} + +type MockStatefulRPCServer struct { + MockRPCServer + volumeUpdateBatcher *VolumeUpdateBatcher +} + +func (srv *MockStatefulRPCServer) UpsertVolumeClaims(batch *structs.CSIVolumeClaimBatchRequest) (uint64, error) { + srv.countUpsertVolumeClaims++ + index, _ := srv.state.LatestIndex() + for _, req := range batch.Claims { + index++ + err := srv.state.CSIVolumeClaim(index, req.RequestNamespace(), + req.VolumeID, req.ToClaim()) + if err != nil { + return 0, err + } + } + return index, nil +} diff --git a/nomad/volumewatcher/volume_watcher.go b/nomad/volumewatcher/volume_watcher.go new file mode 100644 index 000000000000..6177ae10d20e --- /dev/null +++ b/nomad/volumewatcher/volume_watcher.go @@ -0,0 +1,382 @@ +package volumewatcher + +import ( + "context" + "fmt" + "sync" + + log "github.com/hashicorp/go-hclog" + memdb "github.com/hashicorp/go-memdb" + multierror "github.com/hashicorp/go-multierror" + cstructs "github.com/hashicorp/nomad/client/structs" + "github.com/hashicorp/nomad/nomad/state" + "github.com/hashicorp/nomad/nomad/structs" +) + +// volumeWatcher is used to watch a single volume and trigger the +// scheduler when allocation health transitions. +type volumeWatcher struct { + // v is the volume being watched + v *structs.CSIVolume + + // state is the state that is watched for state changes. + state *state.StateStore + + // updateClaims is the function used to apply claims to raft + updateClaims updateClaimsFn + + // server interface for CSI client RPCs + rpc ClientRPC + + logger log.Logger + shutdownCtx context.Context // parent context + ctx context.Context // own context + exitFn context.CancelFunc + + // updateCh is triggered when there is an updated volume + updateCh chan *structs.CSIVolume + + wLock sync.RWMutex + running bool +} + +// newVolumeWatcher returns a volume watcher that is used to watch +// volumes +func newVolumeWatcher(parent *Watcher, vol *structs.CSIVolume) *volumeWatcher { + + w := &volumeWatcher{ + updateCh: make(chan *structs.CSIVolume, 1), + updateClaims: parent.updateClaims, + v: vol, + state: parent.state, + rpc: parent.rpc, + logger: parent.logger.With("volume_id", vol.ID, "namespace", vol.Namespace), + shutdownCtx: parent.ctx, + } + + // Start the long lived watcher that scans for allocation updates + w.Start() + return w +} + +// Notify signals an update to the tracked volume. +func (vw *volumeWatcher) Notify(v *structs.CSIVolume) { + if !vw.isRunning() { + vw.Start() + } + select { + case vw.updateCh <- v: + case <-vw.shutdownCtx.Done(): // prevent deadlock if we stopped + case <-vw.ctx.Done(): // prevent deadlock if we stopped + } +} + +func (vw *volumeWatcher) Start() { + vw.logger.Trace("starting watcher", "id", vw.v.ID, "namespace", vw.v.Namespace) + vw.wLock.Lock() + defer vw.wLock.Unlock() + vw.running = true + ctx, exitFn := context.WithCancel(vw.shutdownCtx) + vw.ctx = ctx + vw.exitFn = exitFn + go vw.watch() +} + +// Stop stops watching the volume. This should be called whenever a +// volume's claims are fully reaped or the watcher is no longer needed. +func (vw *volumeWatcher) Stop() { + vw.logger.Trace("no more claims", "id", vw.v.ID, "namespace", vw.v.Namespace) + vw.exitFn() +} + +func (vw *volumeWatcher) isRunning() bool { + vw.wLock.RLock() + defer vw.wLock.RUnlock() + select { + case <-vw.shutdownCtx.Done(): + return false + case <-vw.ctx.Done(): + return false + default: + return vw.running + } +} + +// watch is the long-running function that watches for changes to a volume. +// Each pass steps the volume's claims through the various states of reaping +// until the volume has no more claims eligible to be reaped. +func (vw *volumeWatcher) watch() { + for { + select { + // TODO(tgross): currently server->client RPC have no cancellation + // context, so we can't stop the long-runner RPCs gracefully + case <-vw.shutdownCtx.Done(): + return + case <-vw.ctx.Done(): + return + case vol := <-vw.updateCh: + // while we won't make raft writes if we get a stale update, + // we can still fire extra CSI RPC calls if we don't check this + if vol == nil || vw.v == nil || vol.ModifyIndex >= vw.v.ModifyIndex { + vol = vw.getVolume(vol) + if vol == nil { + return + } + vw.volumeReap(vol) + } + } + } +} + +// getVolume returns the tracked volume, fully populated with the current +// state +func (vw *volumeWatcher) getVolume(vol *structs.CSIVolume) *structs.CSIVolume { + vw.wLock.RLock() + defer vw.wLock.RUnlock() + + var err error + ws := memdb.NewWatchSet() + + vol, err = vw.state.CSIVolumeDenormalizePlugins(ws, vol.Copy()) + if err != nil { + vw.logger.Error("could not query plugins for volume", "error", err) + return nil + } + + vol, err = vw.state.CSIVolumeDenormalize(ws, vol) + if err != nil { + vw.logger.Error("could not query allocs for volume", "error", err) + return nil + } + vw.v = vol + return vol +} + +// volumeReap collects errors for logging but doesn't return them +// to the main loop. +func (vw *volumeWatcher) volumeReap(vol *structs.CSIVolume) { + vw.logger.Trace("releasing unused volume claims", "id", vol.ID, "namespace", vol.Namespace) + err := vw.volumeReapImpl(vol) + if err != nil { + vw.logger.Error("error releasing volume claims", "error", err) + } + if vw.isUnclaimed(vol) { + vw.Stop() + } +} + +func (vw *volumeWatcher) isUnclaimed(vol *structs.CSIVolume) bool { + return len(vol.ReadClaims) == 0 && len(vol.WriteClaims) == 0 && len(vol.PastClaims) == 0 +} + +func (vw *volumeWatcher) volumeReapImpl(vol *structs.CSIVolume) error { + var result *multierror.Error + nodeClaims := map[string]int{} // node IDs -> count + jobs := map[string]bool{} // jobID -> stopped + + // if a job is purged, the subsequent alloc updates can't + // trigger a GC job because there's no job for them to query. + // Job.Deregister will send a claim release on all claims + // but the allocs will not yet be terminated. save the status + // for each job so that we don't requery in this pass + checkStopped := func(jobID string) bool { + namespace := vw.v.Namespace + isStopped, ok := jobs[jobID] + if !ok { + ws := memdb.NewWatchSet() + job, err := vw.state.JobByID(ws, namespace, jobID) + if err != nil { + isStopped = true + } + if job == nil || job.Stopped() { + isStopped = true + } + jobs[jobID] = isStopped + } + return isStopped + } + + collect := func(allocs map[string]*structs.Allocation, + claims map[string]*structs.CSIVolumeClaim) { + + for allocID, alloc := range allocs { + + if alloc == nil { + _, exists := vol.PastClaims[allocID] + if !exists { + vol.PastClaims[allocID] = &structs.CSIVolumeClaim{ + AllocationID: allocID, + State: structs.CSIVolumeClaimStateReadyToFree, + } + } + continue + } + + nodeClaims[alloc.NodeID]++ + + if alloc.Terminated() || checkStopped(alloc.JobID) { + // don't overwrite the PastClaim if we've seen it before, + // so that we can track state between subsequent calls + _, exists := vol.PastClaims[allocID] + if !exists { + claim, ok := claims[allocID] + if !ok { + claim = &structs.CSIVolumeClaim{ + AllocationID: allocID, + NodeID: alloc.NodeID, + } + } + claim.State = structs.CSIVolumeClaimStateTaken + vol.PastClaims[allocID] = claim + } + } + } + } + + collect(vol.ReadAllocs, vol.ReadClaims) + collect(vol.WriteAllocs, vol.WriteClaims) + + if len(vol.PastClaims) == 0 { + return nil + } + + for _, claim := range vol.PastClaims { + + var err error + + // previous checkpoints may have set the past claim state already. + // in practice we should never see CSIVolumeClaimStateControllerDetached + // but having an option for the state makes it easy to add a checkpoint + // in a backwards compatible way if we need one later + switch claim.State { + case structs.CSIVolumeClaimStateNodeDetached: + goto NODE_DETACHED + case structs.CSIVolumeClaimStateControllerDetached: + goto RELEASE_CLAIM + case structs.CSIVolumeClaimStateReadyToFree: + goto RELEASE_CLAIM + } + + err = vw.nodeDetach(vol, claim) + if err != nil { + result = multierror.Append(result, err) + break + } + + NODE_DETACHED: + nodeClaims[claim.NodeID]-- + err = vw.controllerDetach(vol, claim, nodeClaims) + if err != nil { + result = multierror.Append(result, err) + break + } + + RELEASE_CLAIM: + err = vw.checkpoint(vol, claim) + if err != nil { + result = multierror.Append(result, err) + break + } + // the checkpoint deletes from the state store, but this operates + // on our local copy which aids in testing + delete(vol.PastClaims, claim.AllocationID) + } + + return result.ErrorOrNil() + +} + +// nodeDetach makes the client NodePublish / NodeUnstage RPCs, which +// must be completed before controller operations or releasing the claim. +func (vw *volumeWatcher) nodeDetach(vol *structs.CSIVolume, claim *structs.CSIVolumeClaim) error { + vw.logger.Trace("detaching node", "id", vol.ID, "namespace", vol.Namespace) + nReq := &cstructs.ClientCSINodeDetachVolumeRequest{ + PluginID: vol.PluginID, + VolumeID: vol.ID, + ExternalID: vol.RemoteID(), + AllocID: claim.AllocationID, + NodeID: claim.NodeID, + AttachmentMode: vol.AttachmentMode, + AccessMode: vol.AccessMode, + ReadOnly: claim.Mode == structs.CSIVolumeClaimRead, + } + + err := vw.rpc.NodeDetachVolume(nReq, + &cstructs.ClientCSINodeDetachVolumeResponse{}) + if err != nil { + return fmt.Errorf("could not detach from node: %v", err) + } + claim.State = structs.CSIVolumeClaimStateNodeDetached + return vw.checkpoint(vol, claim) +} + +// controllerDetach makes the client RPC to the controller to +// unpublish the volume if a controller is required and no other +// allocs on the node need it +func (vw *volumeWatcher) controllerDetach(vol *structs.CSIVolume, claim *structs.CSIVolumeClaim, nodeClaims map[string]int) error { + if !vol.ControllerRequired || nodeClaims[claim.NodeID] > 1 { + claim.State = structs.CSIVolumeClaimStateReadyToFree + return nil + } + vw.logger.Trace("detaching controller", "id", vol.ID, "namespace", vol.Namespace) + // note: we need to get the CSI Node ID, which is not the same as + // the Nomad Node ID + ws := memdb.NewWatchSet() + targetNode, err := vw.state.NodeByID(ws, claim.NodeID) + if err != nil { + return err + } + if targetNode == nil { + return fmt.Errorf("%s: %s", structs.ErrUnknownNodePrefix, claim.NodeID) + } + targetCSIInfo, ok := targetNode.CSINodePlugins[vol.PluginID] + if !ok { + return fmt.Errorf("failed to find NodeInfo for node: %s", targetNode.ID) + } + + plug, err := vw.state.CSIPluginByID(ws, vol.PluginID) + if err != nil { + return fmt.Errorf("plugin lookup error: %s %v", vol.PluginID, err) + } + if plug == nil { + return fmt.Errorf("plugin lookup error: %s missing plugin", vol.PluginID) + } + + cReq := &cstructs.ClientCSIControllerDetachVolumeRequest{ + VolumeID: vol.RemoteID(), + ClientCSINodeID: targetCSIInfo.NodeInfo.ID, + } + cReq.PluginID = plug.ID + err = vw.rpc.ControllerDetachVolume(cReq, + &cstructs.ClientCSIControllerDetachVolumeResponse{}) + if err != nil { + return fmt.Errorf("could not detach from controller: %v", err) + } + claim.State = structs.CSIVolumeClaimStateReadyToFree + return nil +} + +func (vw *volumeWatcher) checkpoint(vol *structs.CSIVolume, claim *structs.CSIVolumeClaim) error { + vw.logger.Trace("checkpointing claim", "id", vol.ID, "namespace", vol.Namespace) + req := structs.CSIVolumeClaimRequest{ + VolumeID: vol.ID, + AllocationID: claim.AllocationID, + NodeID: claim.NodeID, + Claim: structs.CSIVolumeClaimRelease, + State: claim.State, + WriteRequest: structs.WriteRequest{ + Namespace: vol.Namespace, + // Region: vol.Region, // TODO(tgross) should volumes have regions? + }, + } + index, err := vw.updateClaims([]structs.CSIVolumeClaimRequest{req}) + if err == nil && index != 0 { + vw.wLock.Lock() + defer vw.wLock.Unlock() + vw.v.ModifyIndex = index + } + if err != nil { + return fmt.Errorf("could not checkpoint claim release: %v", err) + } + return nil +} diff --git a/nomad/volumewatcher/volume_watcher_test.go b/nomad/volumewatcher/volume_watcher_test.go new file mode 100644 index 000000000000..a2b5ab033503 --- /dev/null +++ b/nomad/volumewatcher/volume_watcher_test.go @@ -0,0 +1,294 @@ +package volumewatcher + +import ( + "context" + "fmt" + "testing" + + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/state" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/stretchr/testify/require" +) + +// TestVolumeWatch_OneReap tests one pass through the reaper +func TestVolumeWatch_OneReap(t *testing.T) { + t.Parallel() + require := require.New(t) + + cases := []struct { + Name string + Volume *structs.CSIVolume + Node *structs.Node + ControllerRequired bool + ExpectedErr string + ExpectedClaimsCount int + ExpectedNodeDetachCount int + ExpectedControllerDetachCount int + ExpectedUpdateClaimsCount int + srv *MockRPCServer + }{ + { + Name: "No terminal allocs", + Volume: mock.CSIVolume(mock.CSIPlugin()), + ControllerRequired: true, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + nextCSINodeDetachError: fmt.Errorf("should never see this"), + }, + }, + { + Name: "NodeDetachVolume fails", + ControllerRequired: true, + ExpectedErr: "some node plugin error", + ExpectedNodeDetachCount: 1, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + nextCSINodeDetachError: fmt.Errorf("some node plugin error"), + }, + }, + { + Name: "NodeDetachVolume node-only happy path", + ControllerRequired: false, + ExpectedNodeDetachCount: 1, + ExpectedUpdateClaimsCount: 2, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + }, + }, + { + Name: "ControllerDetachVolume no controllers available", + Node: mock.Node(), + ControllerRequired: true, + ExpectedErr: "Unknown node", + ExpectedNodeDetachCount: 1, + ExpectedUpdateClaimsCount: 1, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + }, + }, + { + Name: "ControllerDetachVolume controller error", + ControllerRequired: true, + ExpectedErr: "some controller error", + ExpectedNodeDetachCount: 1, + ExpectedControllerDetachCount: 1, + ExpectedUpdateClaimsCount: 1, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + nextCSIControllerDetachError: fmt.Errorf("some controller error"), + }, + }, + { + Name: "ControllerDetachVolume happy path", + ControllerRequired: true, + ExpectedNodeDetachCount: 1, + ExpectedControllerDetachCount: 1, + ExpectedUpdateClaimsCount: 2, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + }, + }, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + + plugin := mock.CSIPlugin() + plugin.ControllerRequired = tc.ControllerRequired + node := testNode(tc.Node, plugin, tc.srv.State()) + alloc := mock.Alloc() + alloc.NodeID = node.ID + alloc.ClientStatus = structs.AllocClientStatusComplete + vol := testVolume(tc.Volume, plugin, alloc, node.ID) + ctx, exitFn := context.WithCancel(context.Background()) + w := &volumeWatcher{ + v: vol, + rpc: tc.srv, + state: tc.srv.State(), + updateClaims: tc.srv.UpdateClaims, + ctx: ctx, + exitFn: exitFn, + logger: testlog.HCLogger(t), + } + + err := w.volumeReapImpl(vol) + if tc.ExpectedErr != "" { + require.Error(err, fmt.Sprintf("expected: %q", tc.ExpectedErr)) + require.Contains(err.Error(), tc.ExpectedErr) + } else { + require.NoError(err) + } + require.Equal(tc.ExpectedNodeDetachCount, + tc.srv.countCSINodeDetachVolume, "node detach RPC count") + require.Equal(tc.ExpectedControllerDetachCount, + tc.srv.countCSIControllerDetachVolume, "controller detach RPC count") + require.Equal(tc.ExpectedUpdateClaimsCount, + tc.srv.countUpdateClaims, "update claims count") + }) + } +} + +// TestVolumeWatch_OldVolume_OneReap tests one pass through the reaper +// COMPAT(1.0): the claim fields were added after 0.11.1; this test +// can be removed for 1.0 +func TestVolumeWatch_OldVolume_OneReap(t *testing.T) { + t.Parallel() + require := require.New(t) + + cases := []struct { + Name string + Volume *structs.CSIVolume + Node *structs.Node + ControllerRequired bool + ExpectedErr string + ExpectedClaimsCount int + ExpectedNodeDetachCount int + ExpectedControllerDetachCount int + ExpectedUpdateClaimsCount int + srv *MockRPCServer + }{ + { + Name: "No terminal allocs", + Volume: mock.CSIVolume(mock.CSIPlugin()), + ControllerRequired: true, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + nextCSINodeDetachError: fmt.Errorf("should never see this"), + }, + }, + { + Name: "NodeDetachVolume fails", + ControllerRequired: true, + ExpectedErr: "some node plugin error", + ExpectedNodeDetachCount: 1, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + nextCSINodeDetachError: fmt.Errorf("some node plugin error"), + }, + }, + { + Name: "NodeDetachVolume node-only happy path", + ControllerRequired: false, + ExpectedNodeDetachCount: 1, + ExpectedUpdateClaimsCount: 2, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + }, + }, + { + Name: "ControllerDetachVolume no controllers available", + Node: mock.Node(), + ControllerRequired: true, + ExpectedErr: "Unknown node", + ExpectedNodeDetachCount: 1, + ExpectedUpdateClaimsCount: 1, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + }, + }, + { + Name: "ControllerDetachVolume controller error", + ControllerRequired: true, + ExpectedErr: "some controller error", + ExpectedNodeDetachCount: 1, + ExpectedControllerDetachCount: 1, + ExpectedUpdateClaimsCount: 1, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + nextCSIControllerDetachError: fmt.Errorf("some controller error"), + }, + }, + { + Name: "ControllerDetachVolume happy path", + ControllerRequired: true, + ExpectedNodeDetachCount: 1, + ExpectedControllerDetachCount: 1, + ExpectedUpdateClaimsCount: 2, + srv: &MockRPCServer{ + state: state.TestStateStore(t), + }, + }, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + + plugin := mock.CSIPlugin() + plugin.ControllerRequired = tc.ControllerRequired + node := testNode(tc.Node, plugin, tc.srv.State()) + alloc := mock.Alloc() + alloc.ClientStatus = structs.AllocClientStatusComplete + alloc.NodeID = node.ID + vol := testOldVolume(tc.Volume, plugin, alloc, node.ID) + ctx, exitFn := context.WithCancel(context.Background()) + w := &volumeWatcher{ + v: vol, + rpc: tc.srv, + state: tc.srv.State(), + updateClaims: tc.srv.UpdateClaims, + ctx: ctx, + exitFn: exitFn, + logger: testlog.HCLogger(t), + } + + err := w.volumeReapImpl(vol) + if tc.ExpectedErr != "" { + require.Error(err, fmt.Sprintf("expected: %q", tc.ExpectedErr)) + require.Contains(err.Error(), tc.ExpectedErr) + } else { + require.NoError(err) + } + require.Equal(tc.ExpectedNodeDetachCount, + tc.srv.countCSINodeDetachVolume, "node detach RPC count") + require.Equal(tc.ExpectedControllerDetachCount, + tc.srv.countCSIControllerDetachVolume, "controller detach RPC count") + require.Equal(tc.ExpectedUpdateClaimsCount, + tc.srv.countUpdateClaims, "update claims count") + }) + } +} + +// TestVolumeWatch_OneReap tests multiple passes through the reaper, +// updating state after each one +func TestVolumeWatch_ReapStates(t *testing.T) { + t.Parallel() + require := require.New(t) + + srv := &MockRPCServer{state: state.TestStateStore(t)} + plugin := mock.CSIPlugin() + node := testNode(nil, plugin, srv.State()) + alloc := mock.Alloc() + alloc.ClientStatus = structs.AllocClientStatusComplete + vol := testVolume(nil, plugin, alloc, node.ID) + + w := &volumeWatcher{ + v: vol, + rpc: srv, + state: srv.State(), + updateClaims: srv.UpdateClaims, + logger: testlog.HCLogger(t), + } + + srv.nextCSINodeDetachError = fmt.Errorf("some node plugin error") + err := w.volumeReapImpl(vol) + require.Error(err) + require.Equal(structs.CSIVolumeClaimStateTaken, vol.PastClaims[alloc.ID].State) + require.Equal(1, srv.countCSINodeDetachVolume) + require.Equal(0, srv.countCSIControllerDetachVolume) + require.Equal(0, srv.countUpdateClaims) + + srv.nextCSINodeDetachError = nil + srv.nextCSIControllerDetachError = fmt.Errorf("some controller plugin error") + err = w.volumeReapImpl(vol) + require.Error(err) + require.Equal(structs.CSIVolumeClaimStateNodeDetached, vol.PastClaims[alloc.ID].State) + require.Equal(1, srv.countUpdateClaims) + + srv.nextCSIControllerDetachError = nil + err = w.volumeReapImpl(vol) + require.NoError(err) + require.Equal(0, len(vol.PastClaims)) + require.Equal(2, srv.countUpdateClaims) +} diff --git a/nomad/volumewatcher/volumes_watcher.go b/nomad/volumewatcher/volumes_watcher.go new file mode 100644 index 000000000000..63446c461ac1 --- /dev/null +++ b/nomad/volumewatcher/volumes_watcher.go @@ -0,0 +1,232 @@ +package volumewatcher + +import ( + "context" + "sync" + "time" + + log "github.com/hashicorp/go-hclog" + memdb "github.com/hashicorp/go-memdb" + "github.com/hashicorp/nomad/nomad/state" + "github.com/hashicorp/nomad/nomad/structs" + "golang.org/x/time/rate" +) + +const ( + // LimitStateQueriesPerSecond is the number of state queries allowed per + // second + LimitStateQueriesPerSecond = 100.0 + + // CrossVolumeUpdateBatchDuration is the duration in which volume + // claim updates are batched across all volume watchers before + // being committed to Raft. + CrossVolumeUpdateBatchDuration = 250 * time.Millisecond +) + +// Watcher is used to watch volumes and their allocations created +// by the scheduler and trigger the scheduler when allocation health +// transitions. +type Watcher struct { + enabled bool + logger log.Logger + + // queryLimiter is used to limit the rate of blocking queries + queryLimiter *rate.Limiter + + // updateBatchDuration is the duration in which volume + // claim updates are batched across all volume watchers + // before being committed to Raft. + updateBatchDuration time.Duration + + // raft contains the set of Raft endpoints that can be used by the + // volumes watcher + raft VolumeRaftEndpoints + + // rpc contains the set of Server methods that can be used by + // the volumes watcher for RPC + rpc ClientRPC + + // state is the state that is watched for state changes. + state *state.StateStore + + // watchers is the set of active watchers, one per volume + watchers map[string]*volumeWatcher + + // volumeUpdateBatcher is used to batch volume claim updates + volumeUpdateBatcher *VolumeUpdateBatcher + + // ctx and exitFn are used to cancel the watcher + ctx context.Context + exitFn context.CancelFunc + + wlock sync.RWMutex +} + +// NewVolumesWatcher returns a volumes watcher that is used to watch +// volumes and trigger the scheduler as needed. +func NewVolumesWatcher(logger log.Logger, + raft VolumeRaftEndpoints, rpc ClientRPC, stateQueriesPerSecond float64, + updateBatchDuration time.Duration) *Watcher { + + // the leader step-down calls SetEnabled(false) which is what + // cancels this context, rather than passing in its own shutdown + // context + ctx, exitFn := context.WithCancel(context.Background()) + + return &Watcher{ + raft: raft, + rpc: rpc, + queryLimiter: rate.NewLimiter(rate.Limit(stateQueriesPerSecond), 100), + updateBatchDuration: updateBatchDuration, + logger: logger.Named("volumes_watcher"), + ctx: ctx, + exitFn: exitFn, + } +} + +// SetEnabled is used to control if the watcher is enabled. The +// watcher should only be enabled on the active leader. When being +// enabled the state is passed in as it is no longer valid once a +// leader election has taken place. +func (w *Watcher) SetEnabled(enabled bool, state *state.StateStore) { + w.wlock.Lock() + defer w.wlock.Unlock() + + wasEnabled := w.enabled + w.enabled = enabled + + if state != nil { + w.state = state + } + + // Flush the state to create the necessary objects + w.flush() + + // If we are starting now, launch the watch daemon + if enabled && !wasEnabled { + go w.watchVolumes(w.ctx) + } +} + +// flush is used to clear the state of the watcher +func (w *Watcher) flush() { + // Stop all the watchers and clear it + for _, watcher := range w.watchers { + watcher.Stop() + } + + // Kill everything associated with the watcher + if w.exitFn != nil { + w.exitFn() + } + + w.watchers = make(map[string]*volumeWatcher, 32) + w.ctx, w.exitFn = context.WithCancel(context.Background()) + w.volumeUpdateBatcher = NewVolumeUpdateBatcher(w.updateBatchDuration, w.raft, w.ctx) +} + +// watchVolumes is the long lived go-routine that watches for volumes to +// add and remove watchers on. +func (w *Watcher) watchVolumes(ctx context.Context) { + vIndex := uint64(1) + for { + volumes, idx, err := w.getVolumes(ctx, vIndex) + if err != nil { + if err == context.Canceled { + return + } + w.logger.Error("failed to retrieve volumes", "error", err) + } + + vIndex = idx // last-seen index + for _, v := range volumes { + if err := w.add(v); err != nil { + w.logger.Error("failed to track volume", "volume_id", v.ID, "error", err) + } + + } + } +} + +// getVolumes retrieves all volumes blocking at the given index. +func (w *Watcher) getVolumes(ctx context.Context, minIndex uint64) ([]*structs.CSIVolume, uint64, error) { + resp, index, err := w.state.BlockingQuery(w.getVolumesImpl, minIndex, ctx) + if err != nil { + return nil, 0, err + } + + return resp.([]*structs.CSIVolume), index, nil +} + +// getVolumesImpl retrieves all volumes from the passed state store. +func (w *Watcher) getVolumesImpl(ws memdb.WatchSet, state *state.StateStore) (interface{}, uint64, error) { + + iter, err := state.CSIVolumes(ws) + if err != nil { + return nil, 0, err + } + + var volumes []*structs.CSIVolume + for { + raw := iter.Next() + if raw == nil { + break + } + volume := raw.(*structs.CSIVolume) + volumes = append(volumes, volume) + } + + // Use the last index that affected the volume table + index, err := state.Index("csi_volumes") + if err != nil { + return nil, 0, err + } + + return volumes, index, nil +} + +// add adds a volume to the watch list +func (w *Watcher) add(d *structs.CSIVolume) error { + w.wlock.Lock() + defer w.wlock.Unlock() + _, err := w.addLocked(d) + return err +} + +// addLocked adds a volume to the watch list and should only be called when +// locked. Creating the volumeWatcher starts a go routine to .watch() it +func (w *Watcher) addLocked(v *structs.CSIVolume) (*volumeWatcher, error) { + // Not enabled so no-op + if !w.enabled { + return nil, nil + } + + // Already watched so trigger an update for the volume + if watcher, ok := w.watchers[v.ID+v.Namespace]; ok { + watcher.Notify(v) + return nil, nil + } + + watcher := newVolumeWatcher(w, v) + w.watchers[v.ID+v.Namespace] = watcher + return watcher, nil +} + +// TODO: this is currently dead code; we'll call a public remove +// method on the Watcher once we have a periodic GC job +// remove stops watching a volume and should only be called when locked. +func (w *Watcher) removeLocked(volID, namespace string) { + if !w.enabled { + return + } + if watcher, ok := w.watchers[volID+namespace]; ok { + watcher.Stop() + delete(w.watchers, volID+namespace) + } +} + +// updatesClaims sends the claims to the batch updater and waits for +// the results +func (w *Watcher) updateClaims(claims []structs.CSIVolumeClaimRequest) (uint64, error) { + return w.volumeUpdateBatcher.CreateUpdate(claims).Results() +} diff --git a/nomad/volumewatcher/volumes_watcher_test.go b/nomad/volumewatcher/volumes_watcher_test.go new file mode 100644 index 000000000000..b7ae7aea2c55 --- /dev/null +++ b/nomad/volumewatcher/volumes_watcher_test.go @@ -0,0 +1,310 @@ +package volumewatcher + +import ( + "context" + "testing" + "time" + + memdb "github.com/hashicorp/go-memdb" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/state" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/stretchr/testify/require" +) + +// TestVolumeWatch_EnableDisable tests the watcher registration logic that needs +// to happen during leader step-up/step-down +func TestVolumeWatch_EnableDisable(t *testing.T) { + t.Parallel() + require := require.New(t) + + srv := &MockRPCServer{} + srv.state = state.TestStateStore(t) + index := uint64(100) + + watcher := NewVolumesWatcher(testlog.HCLogger(t), + srv, srv, + LimitStateQueriesPerSecond, + CrossVolumeUpdateBatchDuration) + + watcher.SetEnabled(true, srv.State()) + + plugin := mock.CSIPlugin() + node := testNode(nil, plugin, srv.State()) + alloc := mock.Alloc() + alloc.ClientStatus = structs.AllocClientStatusComplete + vol := testVolume(nil, plugin, alloc, node.ID) + + index++ + err := srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol}) + require.NoError(err) + + claim := &structs.CSIVolumeClaim{Mode: structs.CSIVolumeClaimRelease} + index++ + err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim) + require.NoError(err) + require.Eventually(func() bool { + return 1 == len(watcher.watchers) + }, time.Second, 10*time.Millisecond) + + watcher.SetEnabled(false, srv.State()) + require.Equal(0, len(watcher.watchers)) +} + +// TestVolumeWatch_Checkpoint tests the checkpointing of progress across +// leader leader step-up/step-down +func TestVolumeWatch_Checkpoint(t *testing.T) { + t.Parallel() + require := require.New(t) + + srv := &MockRPCServer{} + srv.state = state.TestStateStore(t) + index := uint64(100) + + watcher := NewVolumesWatcher(testlog.HCLogger(t), + srv, srv, + LimitStateQueriesPerSecond, + CrossVolumeUpdateBatchDuration) + + plugin := mock.CSIPlugin() + node := testNode(nil, plugin, srv.State()) + alloc := mock.Alloc() + alloc.ClientStatus = structs.AllocClientStatusComplete + vol := testVolume(nil, plugin, alloc, node.ID) + + watcher.SetEnabled(true, srv.State()) + + index++ + err := srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol}) + require.NoError(err) + + // we should get or start up a watcher when we get an update for + // the volume from the state store + require.Eventually(func() bool { + return 1 == len(watcher.watchers) + }, time.Second, 10*time.Millisecond) + + // step-down (this is sync, but step-up is async) + watcher.SetEnabled(false, srv.State()) + require.Equal(0, len(watcher.watchers)) + + // step-up again + watcher.SetEnabled(true, srv.State()) + require.Eventually(func() bool { + return 1 == len(watcher.watchers) + }, time.Second, 10*time.Millisecond) + + require.True(watcher.watchers[vol.ID+vol.Namespace].isRunning()) +} + +// TestVolumeWatch_StartStop tests the start and stop of the watcher when +// it receives notifcations and has completed its work +func TestVolumeWatch_StartStop(t *testing.T) { + t.Parallel() + require := require.New(t) + + ctx, exitFn := context.WithCancel(context.Background()) + defer exitFn() + + srv := &MockStatefulRPCServer{} + srv.state = state.TestStateStore(t) + index := uint64(100) + srv.volumeUpdateBatcher = NewVolumeUpdateBatcher( + CrossVolumeUpdateBatchDuration, srv, ctx) + + watcher := NewVolumesWatcher(testlog.HCLogger(t), + srv, srv, + LimitStateQueriesPerSecond, + CrossVolumeUpdateBatchDuration) + + watcher.SetEnabled(true, srv.State()) + require.Equal(0, len(watcher.watchers)) + + plugin := mock.CSIPlugin() + node := testNode(nil, plugin, srv.State()) + alloc := mock.Alloc() + alloc.ClientStatus = structs.AllocClientStatusRunning + alloc2 := mock.Alloc() + alloc2.Job = alloc.Job + alloc2.ClientStatus = structs.AllocClientStatusRunning + index++ + err := srv.State().UpsertJob(index, alloc.Job) + require.NoError(err) + index++ + err = srv.State().UpsertAllocs(index, []*structs.Allocation{alloc, alloc2}) + require.NoError(err) + + // register a volume + vol := testVolume(nil, plugin, alloc, node.ID) + index++ + err = srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol}) + require.NoError(err) + + // assert we get a running watcher + require.Eventually(func() bool { + return 1 == len(watcher.watchers) + }, time.Second, 10*time.Millisecond) + require.True(watcher.watchers[vol.ID+vol.Namespace].isRunning()) + + // claim the volume for both allocs + claim := &structs.CSIVolumeClaim{ + AllocationID: alloc.ID, + NodeID: node.ID, + Mode: structs.CSIVolumeClaimRead, + } + index++ + err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim) + require.NoError(err) + claim.AllocationID = alloc2.ID + index++ + err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim) + require.NoError(err) + + // reap the volume and assert nothing has happened + claim = &structs.CSIVolumeClaim{ + AllocationID: alloc.ID, + NodeID: node.ID, + Mode: structs.CSIVolumeClaimRelease, + } + index++ + err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim) + require.NoError(err) + require.True(watcher.watchers[vol.ID+vol.Namespace].isRunning()) + + // alloc becomes terminal + alloc.ClientStatus = structs.AllocClientStatusComplete + index++ + err = srv.State().UpsertAllocs(index, []*structs.Allocation{alloc}) + require.NoError(err) + index++ + err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim) + require.NoError(err) + + // 1 claim has been released but watcher is still running + require.Eventually(func() bool { + ws := memdb.NewWatchSet() + vol, _ := srv.State().CSIVolumeByID(ws, vol.Namespace, vol.ID) + return len(vol.ReadAllocs) == 1 && len(vol.PastClaims) == 0 + }, time.Second*2, 10*time.Millisecond) + + require.True(watcher.watchers[vol.ID+vol.Namespace].isRunning()) + + // the watcher will have incremented the index so we need to make sure + // our inserts will trigger new events + index, _ = srv.State().LatestIndex() + + // remaining alloc's job is stopped (alloc is not marked terminal) + alloc2.Job.Stop = true + index++ + err = srv.State().UpsertJob(index, alloc2.Job) + require.NoError(err) + + // job deregistration write a claim with no allocations or nodes + claim = &structs.CSIVolumeClaim{ + Mode: structs.CSIVolumeClaimRelease, + } + index++ + err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim) + require.NoError(err) + + // all claims have been released and watcher is stopped + require.Eventually(func() bool { + ws := memdb.NewWatchSet() + vol, _ := srv.State().CSIVolumeByID(ws, vol.Namespace, vol.ID) + return len(vol.ReadAllocs) == 1 && len(vol.PastClaims) == 0 + }, time.Second*2, 10*time.Millisecond) + + require.Eventually(func() bool { + return !watcher.watchers[vol.ID+vol.Namespace].isRunning() + }, time.Second*1, 10*time.Millisecond) + + // the watcher will have incremented the index so we need to make sure + // our inserts will trigger new events + index, _ = srv.State().LatestIndex() + + // create a new claim + alloc3 := mock.Alloc() + alloc3.ClientStatus = structs.AllocClientStatusRunning + index++ + err = srv.State().UpsertAllocs(index, []*structs.Allocation{alloc3}) + require.NoError(err) + claim3 := &structs.CSIVolumeClaim{ + AllocationID: alloc3.ID, + NodeID: node.ID, + Mode: structs.CSIVolumeClaimRelease, + } + index++ + err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim3) + require.NoError(err) + + // a stopped watcher should restore itself on notification + require.Eventually(func() bool { + return watcher.watchers[vol.ID+vol.Namespace].isRunning() + }, time.Second*1, 10*time.Millisecond) +} + +// TestVolumeWatch_RegisterDeregister tests the start and stop of +// watchers around registration +func TestVolumeWatch_RegisterDeregister(t *testing.T) { + t.Parallel() + require := require.New(t) + + ctx, exitFn := context.WithCancel(context.Background()) + defer exitFn() + + srv := &MockStatefulRPCServer{} + srv.state = state.TestStateStore(t) + srv.volumeUpdateBatcher = NewVolumeUpdateBatcher( + CrossVolumeUpdateBatchDuration, srv, ctx) + + index := uint64(100) + + watcher := NewVolumesWatcher(testlog.HCLogger(t), + srv, srv, + LimitStateQueriesPerSecond, + CrossVolumeUpdateBatchDuration) + + watcher.SetEnabled(true, srv.State()) + require.Equal(0, len(watcher.watchers)) + + plugin := mock.CSIPlugin() + node := testNode(nil, plugin, srv.State()) + alloc := mock.Alloc() + alloc.ClientStatus = structs.AllocClientStatusComplete + + // register a volume + vol := testVolume(nil, plugin, alloc, node.ID) + index++ + err := srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol}) + require.NoError(err) + + require.Eventually(func() bool { + return 1 == len(watcher.watchers) + }, time.Second, 10*time.Millisecond) + + // reap the volume and assert we've cleaned up + w := watcher.watchers[vol.ID+vol.Namespace] + w.Notify(vol) + + require.Eventually(func() bool { + ws := memdb.NewWatchSet() + vol, _ := srv.State().CSIVolumeByID(ws, vol.Namespace, vol.ID) + return len(vol.ReadAllocs) == 0 && len(vol.PastClaims) == 0 + }, time.Second*2, 10*time.Millisecond) + + require.Eventually(func() bool { + return !watcher.watchers[vol.ID+vol.Namespace].isRunning() + }, time.Second*1, 10*time.Millisecond) + + require.Equal(1, srv.countCSINodeDetachVolume, "node detach RPC count") + require.Equal(1, srv.countCSIControllerDetachVolume, "controller detach RPC count") + require.Equal(2, srv.countUpsertVolumeClaims, "upsert claims count") + + // deregistering the volume doesn't cause an update that triggers + // a watcher; we'll clean up this watcher in a GC later + err = srv.State().CSIVolumeDeregister(index, vol.Namespace, []string{vol.ID}) + require.NoError(err) + require.Equal(1, len(watcher.watchers)) + require.False(watcher.watchers[vol.ID+vol.Namespace].isRunning()) +} diff --git a/nomad/volumewatcher_shim.go b/nomad/volumewatcher_shim.go new file mode 100644 index 000000000000..5148d7f5ba9d --- /dev/null +++ b/nomad/volumewatcher_shim.go @@ -0,0 +1,31 @@ +package nomad + +import ( + "github.com/hashicorp/nomad/nomad/structs" +) + +// volumeWatcherRaftShim is the shim that provides the state watching +// methods. These should be set by the server and passed to the volume +// watcher. +type volumeWatcherRaftShim struct { + // apply is used to apply a message to Raft + apply raftApplyFn +} + +// convertApplyErrors parses the results of a raftApply and returns the index at +// which it was applied and any error that occurred. Raft Apply returns two +// separate errors, Raft library errors and user returned errors from the FSM. +// This helper, joins the errors by inspecting the applyResponse for an error. +func (shim *volumeWatcherRaftShim) convertApplyErrors(applyResp interface{}, index uint64, err error) (uint64, error) { + if applyResp != nil { + if fsmErr, ok := applyResp.(error); ok && fsmErr != nil { + return index, fsmErr + } + } + return index, err +} + +func (shim *volumeWatcherRaftShim) UpsertVolumeClaims(req *structs.CSIVolumeClaimBatchRequest) (uint64, error) { + fsmErrIntf, index, raftErr := shim.apply(structs.CSIVolumeClaimBatchRequestType, req) + return shim.convertApplyErrors(fsmErrIntf, index, raftErr) +}