diff --git a/client/allocrunner/csi_hook.go b/client/allocrunner/csi_hook.go index 5b35bc81bb15..ec79a6d9f92a 100644 --- a/client/allocrunner/csi_hook.go +++ b/client/allocrunner/csi_hook.go @@ -104,6 +104,7 @@ func (c *csiHook) claimVolumesFromAlloc() (map[string]*volumeAndRequest, error) req := &structs.CSIVolumeClaimRequest{ VolumeID: pair.request.Source, AllocationID: c.alloc.ID, + NodeID: c.alloc.NodeID, Claim: claimType, } req.Region = c.alloc.Job.Region diff --git a/client/pluginmanager/csimanager/instance_test.go b/client/pluginmanager/csimanager/instance_test.go index 6a8658df5abf..2f05ddc3f54b 100644 --- a/client/pluginmanager/csimanager/instance_test.go +++ b/client/pluginmanager/csimanager/instance_test.go @@ -2,7 +2,6 @@ package csimanager import ( "context" - "fmt" "sync" "testing" "time" @@ -47,7 +46,6 @@ func TestInstanceManager_Shutdown(t *testing.T) { im.shutdownCtxCancelFn = cancelFn im.shutdownCh = make(chan struct{}) im.updater = func(_ string, info *structs.CSIInfo) { - fmt.Println(info) lock.Lock() defer lock.Unlock() pluginHealth = info.Healthy diff --git a/nomad/core_sched.go b/nomad/core_sched.go index 708c8226feba..0b1d7e62420f 100644 --- a/nomad/core_sched.go +++ b/nomad/core_sched.go @@ -749,17 +749,15 @@ func volumeClaimReap(srv RPCServer, volID, namespace, region, leaderACL string, return err } - gcClaims, nodeClaims := collectClaimsToGCImpl(vol, runningAllocs) + nodeClaims := collectClaimsToGCImpl(vol, runningAllocs) var result *multierror.Error - for _, claim := range gcClaims { + for _, claim := range vol.PastClaims { nodeClaims, err = volumeClaimReapImpl(srv, &volumeClaimReapArgs{ vol: vol, plug: plug, - allocID: claim.allocID, - nodeID: claim.nodeID, - mode: claim.mode, + claim: claim, namespace: namespace, region: region, leaderACL: leaderACL, @@ -775,48 +773,47 @@ func volumeClaimReap(srv RPCServer, volID, namespace, region, leaderACL string, } -type gcClaimRequest struct { - allocID string - nodeID string - mode structs.CSIVolumeClaimMode -} - -func collectClaimsToGCImpl(vol *structs.CSIVolume, runningAllocs bool) ([]gcClaimRequest, map[string]int) { - gcAllocs := []gcClaimRequest{} +func collectClaimsToGCImpl(vol *structs.CSIVolume, runningAllocs bool) map[string]int { nodeClaims := map[string]int{} // node IDs -> count collectFunc := func(allocs map[string]*structs.Allocation, - mode structs.CSIVolumeClaimMode) { - for _, alloc := range allocs { - // we call denormalize on the volume above to populate - // Allocation pointers. But the alloc might have been - // garbage collected concurrently, so if the alloc is - // still nil we can safely skip it. - if alloc == nil { - continue + claims map[string]*structs.CSIVolumeClaim) { + + for allocID, alloc := range allocs { + claim, ok := claims[allocID] + if !ok { + // COMPAT(1.0): the CSIVolumeClaim fields were added + // after 0.11.1, so claims made before that may be + // missing this value. note that we'll have non-nil + // allocs here because we called denormalize on the + // value. + claim = &structs.CSIVolumeClaim{ + AllocationID: allocID, + NodeID: alloc.NodeID, + State: structs.CSIVolumeClaimStateTaken, + } } - nodeClaims[alloc.NodeID]++ + nodeClaims[claim.NodeID]++ if runningAllocs || alloc.Terminated() { - gcAllocs = append(gcAllocs, gcClaimRequest{ - allocID: alloc.ID, - nodeID: alloc.NodeID, - mode: mode, - }) + // only overwrite the PastClaim if this is new, + // so that we can track state between subsequent calls + if _, exists := vol.PastClaims[claim.AllocationID]; !exists { + claim.State = structs.CSIVolumeClaimStateTaken + vol.PastClaims[claim.AllocationID] = claim + } } } } - collectFunc(vol.WriteAllocs, structs.CSIVolumeClaimWrite) - collectFunc(vol.ReadAllocs, structs.CSIVolumeClaimRead) - return gcAllocs, nodeClaims + collectFunc(vol.WriteAllocs, vol.WriteClaims) + collectFunc(vol.ReadAllocs, vol.ReadClaims) + return nodeClaims } type volumeClaimReapArgs struct { vol *structs.CSIVolume plug *structs.CSIPlugin - allocID string - nodeID string - mode structs.CSIVolumeClaimMode + claim *structs.CSIVolumeClaim region string namespace string leaderACL string @@ -825,42 +822,78 @@ type volumeClaimReapArgs struct { func volumeClaimReapImpl(srv RPCServer, args *volumeClaimReapArgs) (map[string]int, error) { vol := args.vol - nodeID := args.nodeID + claim := args.claim + + var err error + var nReq *cstructs.ClientCSINodeDetachVolumeRequest + + checkpoint := func(claimState structs.CSIVolumeClaimState) error { + req := &structs.CSIVolumeClaimRequest{ + VolumeID: vol.ID, + AllocationID: claim.AllocationID, + Claim: structs.CSIVolumeClaimRelease, + WriteRequest: structs.WriteRequest{ + Region: args.region, + Namespace: args.namespace, + AuthToken: args.leaderACL, + }, + } + return srv.RPC("CSIVolume.Claim", req, &structs.CSIVolumeClaimResponse{}) + } + + // previous checkpoints may have set the past claim state already. + // in practice we should never see CSIVolumeClaimStateControllerDetached + // but having an option for the state makes it easy to add a checkpoint + // in a backwards compatible way if we need one later + switch claim.State { + case structs.CSIVolumeClaimStateNodeDetached: + goto NODE_DETACHED + case structs.CSIVolumeClaimStateControllerDetached: + goto RELEASE_CLAIM + case structs.CSIVolumeClaimStateReadyToFree: + goto RELEASE_CLAIM + } // (1) NodePublish / NodeUnstage must be completed before controller // operations or releasing the claim. - nReq := &cstructs.ClientCSINodeDetachVolumeRequest{ + nReq = &cstructs.ClientCSINodeDetachVolumeRequest{ PluginID: args.plug.ID, VolumeID: vol.ID, ExternalID: vol.RemoteID(), - AllocID: args.allocID, - NodeID: nodeID, + AllocID: claim.AllocationID, + NodeID: claim.NodeID, AttachmentMode: vol.AttachmentMode, AccessMode: vol.AccessMode, - ReadOnly: args.mode == structs.CSIVolumeClaimRead, + ReadOnly: claim.Mode == structs.CSIVolumeClaimRead, } - err := srv.RPC("ClientCSI.NodeDetachVolume", nReq, + err = srv.RPC("ClientCSI.NodeDetachVolume", nReq, &cstructs.ClientCSINodeDetachVolumeResponse{}) if err != nil { return args.nodeClaims, err } - args.nodeClaims[nodeID]-- + err = checkpoint(structs.CSIVolumeClaimStateNodeDetached) + if err != nil { + return args.nodeClaims, err + } + +NODE_DETACHED: + args.nodeClaims[claim.NodeID]-- // (2) we only emit the controller unpublish if no other allocs // on the node need it, but we also only want to make this // call at most once per node - if vol.ControllerRequired && args.nodeClaims[nodeID] < 1 { + if vol.ControllerRequired && args.nodeClaims[claim.NodeID] < 1 { // we need to get the CSI Node ID, which is not the same as // the Nomad Node ID ws := memdb.NewWatchSet() - targetNode, err := srv.State().NodeByID(ws, nodeID) + targetNode, err := srv.State().NodeByID(ws, claim.NodeID) if err != nil { return args.nodeClaims, err } if targetNode == nil { return args.nodeClaims, fmt.Errorf("%s: %s", - structs.ErrUnknownNodePrefix, nodeID) + structs.ErrUnknownNodePrefix, claim.NodeID) } targetCSIInfo, ok := targetNode.CSINodePlugins[args.plug.ID] if !ok { @@ -879,18 +912,9 @@ func volumeClaimReapImpl(srv RPCServer, args *volumeClaimReapArgs) (map[string]i } } +RELEASE_CLAIM: // (3) release the claim from the state store, allowing it to be rescheduled - req := &structs.CSIVolumeClaimRequest{ - VolumeID: vol.ID, - AllocationID: args.allocID, - Claim: structs.CSIVolumeClaimRelease, - WriteRequest: structs.WriteRequest{ - Region: args.region, - Namespace: args.namespace, - AuthToken: args.leaderACL, - }, - } - err = srv.RPC("CSIVolume.Claim", req, &structs.CSIVolumeClaimResponse{}) + err = checkpoint(structs.CSIVolumeClaimStateReadyToFree) if err != nil { return args.nodeClaims, err } diff --git a/nomad/core_sched_test.go b/nomad/core_sched_test.go index 890999192985..819e0908ddbc 100644 --- a/nomad/core_sched_test.go +++ b/nomad/core_sched_test.go @@ -2286,12 +2286,22 @@ func TestCSI_GCVolumeClaims_Collection(t *testing.T) { require.NoError(t, err) // Claim the volumes and verify the claims were set - err = state.CSIVolumeClaim(index, ns, volId0, alloc1, structs.CSIVolumeClaimWrite) + err = state.CSIVolumeClaim(index, ns, volId0, &structs.CSIVolumeClaim{ + AllocationID: alloc1.ID, + NodeID: alloc1.NodeID, + Mode: structs.CSIVolumeClaimWrite, + }) index++ require.NoError(t, err) - err = state.CSIVolumeClaim(index, ns, volId0, alloc2, structs.CSIVolumeClaimRead) + + err = state.CSIVolumeClaim(index, ns, volId0, &structs.CSIVolumeClaim{ + AllocationID: alloc2.ID, + NodeID: alloc2.NodeID, + Mode: structs.CSIVolumeClaimRead, + }) index++ require.NoError(t, err) + vol, err = state.CSIVolumeByID(ws, ns, volId0) require.NoError(t, err) require.Len(t, vol.ReadAllocs, 1) @@ -2306,9 +2316,9 @@ func TestCSI_GCVolumeClaims_Collection(t *testing.T) { vol, err = state.CSIVolumeDenormalize(ws, vol) require.NoError(t, err) - gcClaims, nodeClaims := collectClaimsToGCImpl(vol, false) + nodeClaims := collectClaimsToGCImpl(vol, false) require.Equal(t, nodeClaims[node.ID], 2) - require.Len(t, gcClaims, 2) + require.Len(t, vol.PastClaims, 2) } func TestCSI_GCVolumeClaims_Reap(t *testing.T) { @@ -2326,7 +2336,6 @@ func TestCSI_GCVolumeClaims_Reap(t *testing.T) { cases := []struct { Name string - Claim gcClaimRequest ClaimsCount map[string]int ControllerRequired bool ExpectedErr string @@ -2338,12 +2347,7 @@ func TestCSI_GCVolumeClaims_Reap(t *testing.T) { srv *MockRPCServer }{ { - Name: "NodeDetachVolume fails", - Claim: gcClaimRequest{ - allocID: alloc.ID, - nodeID: node.ID, - mode: structs.CSIVolumeClaimRead, - }, + Name: "NodeDetachVolume fails", ClaimsCount: map[string]int{node.ID: 1}, ControllerRequired: true, ExpectedErr: "node plugin missing", @@ -2355,36 +2359,26 @@ func TestCSI_GCVolumeClaims_Reap(t *testing.T) { }, }, { - Name: "ControllerDetachVolume no controllers", - Claim: gcClaimRequest{ - allocID: alloc.ID, - nodeID: node.ID, - mode: structs.CSIVolumeClaimRead, - }, - ClaimsCount: map[string]int{node.ID: 1}, - ControllerRequired: true, - ExpectedErr: fmt.Sprintf( - "Unknown node: %s", node.ID), + Name: "ControllerDetachVolume no controllers", + ClaimsCount: map[string]int{node.ID: 1}, + ControllerRequired: true, + ExpectedErr: fmt.Sprintf("Unknown node: %s", node.ID), ExpectedClaimsCount: 0, ExpectedNodeDetachVolumeCount: 1, ExpectedControllerDetachVolumeCount: 0, + ExpectedVolumeClaimCount: 1, srv: &MockRPCServer{ state: s.State(), }, }, { - Name: "ControllerDetachVolume node-only", - Claim: gcClaimRequest{ - allocID: alloc.ID, - nodeID: node.ID, - mode: structs.CSIVolumeClaimRead, - }, + Name: "ControllerDetachVolume node-only", ClaimsCount: map[string]int{node.ID: 1}, ControllerRequired: false, ExpectedClaimsCount: 0, ExpectedNodeDetachVolumeCount: 1, ExpectedControllerDetachVolumeCount: 0, - ExpectedVolumeClaimCount: 1, + ExpectedVolumeClaimCount: 2, srv: &MockRPCServer{ state: s.State(), }, @@ -2394,12 +2388,16 @@ func TestCSI_GCVolumeClaims_Reap(t *testing.T) { for _, tc := range cases { t.Run(tc.Name, func(t *testing.T) { vol.ControllerRequired = tc.ControllerRequired + claim := &structs.CSIVolumeClaim{ + AllocationID: alloc.ID, + NodeID: node.ID, + State: structs.CSIVolumeClaimStateTaken, + Mode: structs.CSIVolumeClaimRead, + } nodeClaims, err := volumeClaimReapImpl(tc.srv, &volumeClaimReapArgs{ vol: vol, plug: plugin, - allocID: tc.Claim.allocID, - nodeID: tc.Claim.nodeID, - mode: tc.Claim.mode, + claim: claim, region: "global", namespace: "default", leaderACL: "not-in-use", @@ -2411,7 +2409,7 @@ func TestCSI_GCVolumeClaims_Reap(t *testing.T) { require.NoError(err) } require.Equal(tc.ExpectedClaimsCount, - nodeClaims[tc.Claim.nodeID], "expected claims") + nodeClaims[claim.NodeID], "expected claims remaining") require.Equal(tc.ExpectedNodeDetachVolumeCount, tc.srv.countCSINodeDetachVolume, "node detach RPC count") require.Equal(tc.ExpectedControllerDetachVolumeCount, diff --git a/nomad/csi_endpoint.go b/nomad/csi_endpoint.go index 6efea4d2ba87..79e1a145f9b2 100644 --- a/nomad/csi_endpoint.go +++ b/nomad/csi_endpoint.go @@ -400,6 +400,7 @@ func (v *CSIVolume) controllerPublishVolume(req *structs.CSIVolumeClaimRequest, return nil } + // get Nomad's ID for the client node (not the storage provider's ID) targetNode, err := state.NodeByID(ws, alloc.NodeID) if err != nil { return err @@ -407,15 +408,19 @@ func (v *CSIVolume) controllerPublishVolume(req *structs.CSIVolumeClaimRequest, if targetNode == nil { return fmt.Errorf("%s: %s", structs.ErrUnknownNodePrefix, alloc.NodeID) } + + // get the the storage provider's ID for the client node (not + // Nomad's ID for the node) targetCSIInfo, ok := targetNode.CSINodePlugins[plug.ID] if !ok { return fmt.Errorf("Failed to find NodeInfo for node: %s", targetNode.ID) } + externalNodeID := targetCSIInfo.NodeInfo.ID method := "ClientCSI.ControllerAttachVolume" cReq := &cstructs.ClientCSIControllerAttachVolumeRequest{ VolumeID: vol.RemoteID(), - ClientCSINodeID: targetCSIInfo.NodeInfo.ID, + ClientCSINodeID: externalNodeID, AttachmentMode: vol.AttachmentMode, AccessMode: vol.AccessMode, ReadOnly: req.Claim == structs.CSIVolumeClaimRead, diff --git a/nomad/fsm.go b/nomad/fsm.go index 97384a4a7d43..9ec1ef086510 100644 --- a/nomad/fsm.go +++ b/nomad/fsm.go @@ -1178,7 +1178,7 @@ func (n *nomadFSM) applyCSIVolumeClaim(buf []byte, index uint64) interface{} { return structs.ErrUnknownAllocationPrefix } - if err := n.state.CSIVolumeClaim(index, req.RequestNamespace(), req.VolumeID, alloc, req.Claim); err != nil { + if err := n.state.CSIVolumeClaim(index, req.RequestNamespace(), req.VolumeID, req.ToClaim()); err != nil { n.logger.Error("CSIVolumeClaim failed", "error", err) return err } diff --git a/nomad/mock/mock.go b/nomad/mock/mock.go index d37dbdf5d1e7..ff25749c29db 100644 --- a/nomad/mock/mock.go +++ b/nomad/mock/mock.go @@ -1313,6 +1313,9 @@ func CSIVolume(plugin *structs.CSIPlugin) *structs.CSIVolume { MountOptions: &structs.CSIMountOptions{}, ReadAllocs: map[string]*structs.Allocation{}, WriteAllocs: map[string]*structs.Allocation{}, + ReadClaims: map[string]*structs.CSIVolumeClaim{}, + WriteClaims: map[string]*structs.CSIVolumeClaim{}, + PastClaims: map[string]*structs.CSIVolumeClaim{}, PluginID: plugin.ID, Provider: plugin.Provider, ProviderVersion: plugin.Version, diff --git a/nomad/node_endpoint_test.go b/nomad/node_endpoint_test.go index 9c1c5746b467..c1d54ebc8491 100644 --- a/nomad/node_endpoint_test.go +++ b/nomad/node_endpoint_test.go @@ -2381,9 +2381,17 @@ func TestClientEndpoint_UpdateAlloc_UnclaimVolumes(t *testing.T) { require.NoError(t, err) // Claim the volumes and verify the claims were set - err = state.CSIVolumeClaim(105, ns, volId0, alloc1, structs.CSIVolumeClaimWrite) + err = state.CSIVolumeClaim(105, ns, volId0, &structs.CSIVolumeClaim{ + AllocationID: alloc1.ID, + NodeID: alloc1.NodeID, + Mode: structs.CSIVolumeClaimWrite, + }) require.NoError(t, err) - err = state.CSIVolumeClaim(106, ns, volId0, alloc2, structs.CSIVolumeClaimRead) + err = state.CSIVolumeClaim(106, ns, volId0, &structs.CSIVolumeClaim{ + AllocationID: alloc2.ID, + NodeID: alloc2.NodeID, + Mode: structs.CSIVolumeClaimRead, + }) require.NoError(t, err) vol, err = state.CSIVolumeByID(ws, ns, volId0) require.NoError(t, err) diff --git a/nomad/state/state_store.go b/nomad/state/state_store.go index 0ff24b381670..22b9801ca708 100644 --- a/nomad/state/state_store.go +++ b/nomad/state/state_store.go @@ -2025,9 +2025,10 @@ func (s *StateStore) CSIVolumesByNamespace(ws memdb.WatchSet, namespace string) } // CSIVolumeClaim updates the volume's claim count and allocation list -func (s *StateStore) CSIVolumeClaim(index uint64, namespace, id string, alloc *structs.Allocation, claim structs.CSIVolumeClaimMode) error { +func (s *StateStore) CSIVolumeClaim(index uint64, namespace, id string, claim *structs.CSIVolumeClaim) error { txn := s.db.Txn(true) defer txn.Abort() + ws := memdb.NewWatchSet() row, err := txn.First("csi_volumes", "id", namespace, id) if err != nil { @@ -2042,7 +2043,21 @@ func (s *StateStore) CSIVolumeClaim(index uint64, namespace, id string, alloc *s return fmt.Errorf("volume row conversion error") } - ws := memdb.NewWatchSet() + var alloc *structs.Allocation + if claim.Mode != structs.CSIVolumeClaimRelease { + alloc, err = s.AllocByID(ws, claim.AllocationID) + if err != nil { + s.logger.Error("AllocByID failed", "error", err) + return fmt.Errorf(structs.ErrUnknownAllocationPrefix) + } + if alloc == nil { + s.logger.Error("AllocByID failed to find alloc", "alloc_id", claim.AllocationID) + if err != nil { + return fmt.Errorf(structs.ErrUnknownAllocationPrefix) + } + } + } + volume, err := s.CSIVolumeDenormalizePlugins(ws, orig.Copy()) if err != nil { return err diff --git a/nomad/state/state_store_test.go b/nomad/state/state_store_test.go index 2a16fec7318d..220596104428 100644 --- a/nomad/state/state_store_test.go +++ b/nomad/state/state_store_test.go @@ -2941,18 +2941,33 @@ func TestStateStore_CSIVolume(t *testing.T) { vs = slurp(iter) require.Equal(t, 1, len(vs)) + // Allocs + a0 := mock.Alloc() + a1 := mock.Alloc() + index++ + err = state.UpsertAllocs(index, []*structs.Allocation{a0, a1}) + require.NoError(t, err) + // Claims - a0 := &structs.Allocation{ID: uuid.Generate()} - a1 := &structs.Allocation{ID: uuid.Generate()} r := structs.CSIVolumeClaimRead w := structs.CSIVolumeClaimWrite u := structs.CSIVolumeClaimRelease + claim0 := &structs.CSIVolumeClaim{ + AllocationID: a0.ID, + NodeID: node.ID, + Mode: r, + } + claim1 := &structs.CSIVolumeClaim{ + AllocationID: a1.ID, + NodeID: node.ID, + Mode: w, + } index++ - err = state.CSIVolumeClaim(index, ns, vol0, a0, r) + err = state.CSIVolumeClaim(index, ns, vol0, claim0) require.NoError(t, err) index++ - err = state.CSIVolumeClaim(index, ns, vol0, a1, w) + err = state.CSIVolumeClaim(index, ns, vol0, claim1) require.NoError(t, err) ws = memdb.NewWatchSet() @@ -2961,7 +2976,8 @@ func TestStateStore_CSIVolume(t *testing.T) { vs = slurp(iter) require.False(t, vs[0].WriteFreeClaims()) - err = state.CSIVolumeClaim(2, ns, vol0, a0, u) + claim0.Mode = u + err = state.CSIVolumeClaim(2, ns, vol0, claim0) require.NoError(t, err) ws = memdb.NewWatchSet() iter, err = state.CSIVolumesByPluginID(ws, ns, "minnie") @@ -2980,10 +2996,11 @@ func TestStateStore_CSIVolume(t *testing.T) { // release claims to unblock deregister index++ - err = state.CSIVolumeClaim(index, ns, vol0, a0, u) + err = state.CSIVolumeClaim(index, ns, vol0, claim0) require.NoError(t, err) index++ - err = state.CSIVolumeClaim(index, ns, vol0, a1, u) + claim1.Mode = u + err = state.CSIVolumeClaim(index, ns, vol0, claim1) require.NoError(t, err) index++ diff --git a/nomad/structs/csi.go b/nomad/structs/csi.go index 9d55a5a844bf..bea3439ea685 100644 --- a/nomad/structs/csi.go +++ b/nomad/structs/csi.go @@ -185,6 +185,22 @@ func (v *CSIMountOptions) GoString() string { return v.String() } +type CSIVolumeClaim struct { + AllocationID string + NodeID string + Mode CSIVolumeClaimMode + State CSIVolumeClaimState +} + +type CSIVolumeClaimState int + +const ( + CSIVolumeClaimStateTaken CSIVolumeClaimState = iota + CSIVolumeClaimStateNodeDetached + CSIVolumeClaimStateControllerDetached + CSIVolumeClaimStateReadyToFree +) + // CSIVolume is the full representation of a CSI Volume type CSIVolume struct { // ID is a namespace unique URL safe identifier for the volume @@ -200,8 +216,12 @@ type CSIVolume struct { MountOptions *CSIMountOptions // Allocations, tracking claim status - ReadAllocs map[string]*Allocation - WriteAllocs map[string]*Allocation + ReadAllocs map[string]*Allocation // AllocID -> Allocation + WriteAllocs map[string]*Allocation // AllocID -> Allocation + + ReadClaims map[string]*CSIVolumeClaim // AllocID -> claim + WriteClaims map[string]*CSIVolumeClaim // AllocID -> claim + PastClaims map[string]*CSIVolumeClaim // AllocID -> claim // Schedulable is true if all the denormalized plugin health fields are true, and the // volume has not been marked for garbage collection @@ -262,6 +282,10 @@ func (v *CSIVolume) newStructs() { v.ReadAllocs = map[string]*Allocation{} v.WriteAllocs = map[string]*Allocation{} + + v.ReadClaims = map[string]*CSIVolumeClaim{} + v.WriteClaims = map[string]*CSIVolumeClaim{} + v.PastClaims = map[string]*CSIVolumeClaim{} } func (v *CSIVolume) RemoteID() string { @@ -350,30 +374,43 @@ func (v *CSIVolume) Copy() *CSIVolume { out.WriteAllocs[k] = v } + for k, v := range v.ReadClaims { + claim := *v + out.ReadClaims[k] = &claim + } + for k, v := range v.WriteClaims { + claim := *v + out.WriteClaims[k] = &claim + } + for k, v := range v.PastClaims { + claim := *v + out.PastClaims[k] = &claim + } + return out } // Claim updates the allocations and changes the volume state -func (v *CSIVolume) Claim(claim CSIVolumeClaimMode, alloc *Allocation) error { - switch claim { +func (v *CSIVolume) Claim(claim *CSIVolumeClaim, alloc *Allocation) error { + switch claim.Mode { case CSIVolumeClaimRead: - return v.ClaimRead(alloc) + return v.ClaimRead(claim, alloc) case CSIVolumeClaimWrite: - return v.ClaimWrite(alloc) + return v.ClaimWrite(claim, alloc) case CSIVolumeClaimRelease: - return v.ClaimRelease(alloc) + return v.ClaimRelease(claim) } return nil } // ClaimRead marks an allocation as using a volume read-only -func (v *CSIVolume) ClaimRead(alloc *Allocation) error { - if alloc == nil { - return fmt.Errorf("allocation missing") - } - if _, ok := v.ReadAllocs[alloc.ID]; ok { +func (v *CSIVolume) ClaimRead(claim *CSIVolumeClaim, alloc *Allocation) error { + if _, ok := v.ReadAllocs[claim.AllocationID]; ok { return nil } + if alloc == nil { + return fmt.Errorf("allocation missing: %s", claim.AllocationID) + } if !v.ReadSchedulable() { return fmt.Errorf("unschedulable") @@ -381,19 +418,24 @@ func (v *CSIVolume) ClaimRead(alloc *Allocation) error { // Allocations are copy on write, so we want to keep the id but don't need the // pointer. We'll get it from the db in denormalize. - v.ReadAllocs[alloc.ID] = nil - delete(v.WriteAllocs, alloc.ID) + v.ReadAllocs[claim.AllocationID] = nil + delete(v.WriteAllocs, claim.AllocationID) + + v.ReadClaims[claim.AllocationID] = claim + delete(v.WriteClaims, claim.AllocationID) + delete(v.PastClaims, claim.AllocationID) + return nil } // ClaimWrite marks an allocation as using a volume as a writer -func (v *CSIVolume) ClaimWrite(alloc *Allocation) error { - if alloc == nil { - return fmt.Errorf("allocation missing") - } - if _, ok := v.WriteAllocs[alloc.ID]; ok { +func (v *CSIVolume) ClaimWrite(claim *CSIVolumeClaim, alloc *Allocation) error { + if _, ok := v.WriteAllocs[claim.AllocationID]; ok { return nil } + if alloc == nil { + return fmt.Errorf("allocation missing: %s", claim.AllocationID) + } if !v.WriteSchedulable() { return fmt.Errorf("unschedulable") @@ -412,16 +454,27 @@ func (v *CSIVolume) ClaimWrite(alloc *Allocation) error { // pointer. We'll get it from the db in denormalize. v.WriteAllocs[alloc.ID] = nil delete(v.ReadAllocs, alloc.ID) + + v.WriteClaims[alloc.ID] = claim + delete(v.ReadClaims, alloc.ID) + delete(v.PastClaims, alloc.ID) + return nil } -// ClaimRelease is called when the allocation has terminated and already stopped using the volume -func (v *CSIVolume) ClaimRelease(alloc *Allocation) error { - if alloc == nil { - return fmt.Errorf("allocation missing") +// ClaimRelease is called when the allocation has terminated and +// already stopped using the volume +func (v *CSIVolume) ClaimRelease(claim *CSIVolumeClaim) error { + delete(v.ReadAllocs, claim.AllocationID) + delete(v.WriteAllocs, claim.AllocationID) + delete(v.ReadClaims, claim.AllocationID) + delete(v.WriteClaims, claim.AllocationID) + + if claim.State == CSIVolumeClaimStateReadyToFree { + delete(v.PastClaims, claim.AllocationID) + } else { + v.PastClaims[claim.AllocationID] = claim } - delete(v.ReadAllocs, alloc.ID) - delete(v.WriteAllocs, alloc.ID) return nil } @@ -525,10 +578,21 @@ const ( type CSIVolumeClaimRequest struct { VolumeID string AllocationID string + NodeID string Claim CSIVolumeClaimMode + State CSIVolumeClaimState WriteRequest } +func (req *CSIVolumeClaimRequest) ToClaim() *CSIVolumeClaim { + return &CSIVolumeClaim{ + AllocationID: req.AllocationID, + NodeID: req.NodeID, + Mode: req.Claim, + State: req.State, + } +} + type CSIVolumeClaimResponse struct { // Opaque static publish properties of the volume. SP MAY use this // field to ensure subsequent `NodeStageVolume` or `NodePublishVolume` diff --git a/nomad/structs/csi_test.go b/nomad/structs/csi_test.go index 7f74afb0dbfc..bc267ad7e3cc 100644 --- a/nomad/structs/csi_test.go +++ b/nomad/structs/csi_test.go @@ -12,17 +12,23 @@ func TestCSIVolumeClaim(t *testing.T) { vol.Schedulable = true alloc := &Allocation{ID: "a1", Namespace: "n", JobID: "j"} + claim := &CSIVolumeClaim{ + AllocationID: alloc.ID, + NodeID: "foo", + Mode: CSIVolumeClaimRead, + } - require.NoError(t, vol.ClaimRead(alloc)) + require.NoError(t, vol.ClaimRead(claim, alloc)) require.True(t, vol.ReadSchedulable()) require.True(t, vol.WriteSchedulable()) - require.NoError(t, vol.ClaimRead(alloc)) + require.NoError(t, vol.ClaimRead(claim, alloc)) - require.NoError(t, vol.ClaimWrite(alloc)) + claim.Mode = CSIVolumeClaimWrite + require.NoError(t, vol.ClaimWrite(claim, alloc)) require.True(t, vol.ReadSchedulable()) require.False(t, vol.WriteFreeClaims()) - vol.ClaimRelease(alloc) + vol.ClaimRelease(claim) require.True(t, vol.ReadSchedulable()) require.True(t, vol.WriteFreeClaims()) }