diff --git a/nomad/csi_endpoint.go b/nomad/csi_endpoint.go
index 510871981039..a8695982be2f 100644
--- a/nomad/csi_endpoint.go
+++ b/nomad/csi_endpoint.go
@@ -787,6 +787,7 @@ func (v *CSIVolume) nodeUnpublishVolumeImpl(vol *structs.CSIVolume, claim *struc
 // be called on a copy of the volume.
 func (v *CSIVolume) controllerUnpublishVolume(vol *structs.CSIVolume, claim *structs.CSIVolumeClaim) error {
 	v.logger.Trace("controller unpublish", "vol", vol.ID)
+
 	if !vol.ControllerRequired {
 		claim.State = structs.CSIVolumeClaimStateReadyToFree
 		return nil
 	}
@@ -801,26 +802,39 @@ func (v *CSIVolume) controllerUnpublishVolume(vol *structs.CSIVolume, claim *str
 	} else if plugin == nil {
 		return fmt.Errorf("no such plugin: %q", vol.PluginID)
 	}
+
 	if !plugin.HasControllerCapability(structs.CSIControllerSupportsAttachDetach) {
+		claim.State = structs.CSIVolumeClaimStateReadyToFree
 		return nil
 	}
 
-	// we only send a controller detach if a Nomad client no longer has
-	// any claim to the volume, so we need to check the status of claimed
-	// allocations
 	vol, err = state.CSIVolumeDenormalize(ws, vol)
 	if err != nil {
 		return err
 	}
-	for _, alloc := range vol.ReadAllocs {
-		if alloc != nil && alloc.NodeID == claim.NodeID && !alloc.TerminalStatus() {
+
+	// we only send a controller detach if a Nomad client no longer has any
+	// claim to the volume, so we need to check the status of any other claimed
+	// allocations
+	shouldCancel := func(alloc *structs.Allocation) bool {
+		if alloc != nil && alloc.ID != claim.AllocationID &&
+			alloc.NodeID == claim.NodeID && !alloc.TerminalStatus() {
 			claim.State = structs.CSIVolumeClaimStateReadyToFree
+			v.logger.Debug(
+				"controller unpublish canceled: another non-terminal alloc is on this node",
+				"vol", vol.ID, "alloc", alloc.ID)
+			return true
+		}
+		return false
+	}
+
+	for _, alloc := range vol.ReadAllocs {
+		if shouldCancel(alloc) {
 			return nil
 		}
 	}
 	for _, alloc := range vol.WriteAllocs {
-		if alloc != nil && alloc.NodeID == claim.NodeID &&
-			!alloc.TerminalStatus() {
-			claim.State = structs.CSIVolumeClaimStateReadyToFree
+		if shouldCancel(alloc) {
 			return nil
 		}
 	}
@@ -846,6 +860,8 @@ func (v *CSIVolume) controllerUnpublishVolume(vol *structs.CSIVolume, claim *str
 	if err != nil {
 		return fmt.Errorf("could not detach from controller: %v", err)
 	}
+
+	v.logger.Trace("controller detach complete", "vol", vol.ID)
 	claim.State = structs.CSIVolumeClaimStateReadyToFree
 	return v.checkpointClaim(vol, claim)
 }
diff --git a/nomad/csi_endpoint_test.go b/nomad/csi_endpoint_test.go
index 65e661d35e21..9f24fc12c3b4 100644
--- a/nomad/csi_endpoint_test.go
+++ b/nomad/csi_endpoint_test.go
@@ -7,6 +7,11 @@ import (
 	"time"
 
 	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
+	"github.com/shoenig/test"
+	"github.com/shoenig/test/must"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
 	"github.com/hashicorp/nomad/acl"
 	"github.com/hashicorp/nomad/ci"
 	"github.com/hashicorp/nomad/client"
@@ -17,7 +22,6 @@ import (
 	"github.com/hashicorp/nomad/nomad/state"
 	"github.com/hashicorp/nomad/nomad/structs"
 	"github.com/hashicorp/nomad/testutil"
-	"github.com/stretchr/testify/require"
 )
 
 func TestCSIVolumeEndpoint_Get(t *testing.T) {
@@ -499,12 +503,14 @@ func TestCSIVolumeEndpoint_Unpublish(t *testing.T) {
 		},
 	}
 	index++
-	require.NoError(t, state.UpsertNode(structs.MsgTypeTestSetup, index, node))
+	must.NoError(t, state.UpsertNode(structs.MsgTypeTestSetup, index, node))
 
 	type tc struct {
 		name           string
 		startingState  structs.CSIVolumeClaimState
+		endState       structs.CSIVolumeClaimState
 		nodeID         string
+		otherNodeID    string
 		expectedErrMsg string
 	}
 	testCases := []tc{
@@ -512,24 +518,37 @@ func TestCSIVolumeEndpoint_Unpublish(t *testing.T) {
 			name:          "success",
 			startingState: structs.CSIVolumeClaimStateControllerDetached,
 			nodeID:        node.ID,
+			otherNodeID:   uuid.Generate(),
+		},
+		{
+			name:          "non-terminal allocation on same node",
+			startingState: structs.CSIVolumeClaimStateNodeDetached,
+			nodeID:        node.ID,
+			otherNodeID:   node.ID,
 		},
 		{
 			name:           "unpublish previously detached node",
 			startingState:  structs.CSIVolumeClaimStateNodeDetached,
+			endState:       structs.CSIVolumeClaimStateNodeDetached,
 			expectedErrMsg: "could not detach from controller: controller detach volume: No path to node",
 			nodeID:         node.ID,
+			otherNodeID:    uuid.Generate(),
 		},
 		{
 			name:           "unpublish claim on garbage collected node",
 			startingState:  structs.CSIVolumeClaimStateTaken,
+			endState:       structs.CSIVolumeClaimStateNodeDetached,
 			expectedErrMsg: "could not detach from controller: controller detach volume: No path to node",
 			nodeID:         uuid.Generate(),
+			otherNodeID:    uuid.Generate(),
 		},
 		{
 			name:           "first unpublish",
 			startingState:  structs.CSIVolumeClaimStateTaken,
+			endState:       structs.CSIVolumeClaimStateNodeDetached,
 			expectedErrMsg: "could not detach from controller: controller detach volume: No path to node",
 			nodeID:         node.ID,
+			otherNodeID:    uuid.Generate(),
 		},
 	}
 
@@ -551,15 +570,20 @@ func TestCSIVolumeEndpoint_Unpublish(t *testing.T) {
 			index++
 			err = state.UpsertCSIVolume(index, []*structs.CSIVolume{vol})
-			require.NoError(t, err)
+			must.NoError(t, err)
 
 			// setup: create an alloc that will claim our volume
 			alloc := mock.BatchAlloc()
 			alloc.NodeID = tc.nodeID
 			alloc.ClientStatus = structs.AllocClientStatusFailed
+
+			otherAlloc := mock.BatchAlloc()
+			otherAlloc.NodeID = tc.otherNodeID
+			otherAlloc.ClientStatus = structs.AllocClientStatusRunning
+
 			index++
-			require.NoError(t, state.UpsertAllocs(structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc}))
+			must.NoError(t, state.UpsertAllocs(structs.MsgTypeTestSetup, index,
+				[]*structs.Allocation{alloc, otherAlloc}))
 
 			// setup: claim the volume for our alloc
 			claim := &structs.CSIVolumeClaim{
@@ -572,7 +596,20 @@ func TestCSIVolumeEndpoint_Unpublish(t *testing.T) {
 			index++
 			claim.State = structs.CSIVolumeClaimStateTaken
 			err = state.CSIVolumeClaim(index, ns, volID, claim)
-			require.NoError(t, err)
+			must.NoError(t, err)
+
+			// setup: claim the volume for our other alloc
+			otherClaim := &structs.CSIVolumeClaim{
+				AllocationID:   otherAlloc.ID,
+				NodeID:         tc.otherNodeID,
+				ExternalNodeID: "i-example",
+				Mode:           structs.CSIVolumeClaimRead,
+			}
+
+			index++
+			otherClaim.State = structs.CSIVolumeClaimStateTaken
+			err = state.CSIVolumeClaim(index, ns, volID, otherClaim)
+			must.NoError(t, err)
 
 			// test: unpublish and check the results
 			claim.State = tc.startingState
@@ -589,17 +626,23 @@ func TestCSIVolumeEndpoint_Unpublish(t *testing.T) {
 			err = msgpackrpc.CallWithCodec(codec, "CSIVolume.Unpublish", req,
 				&structs.CSIVolumeUnpublishResponse{})
 
+			vol, volErr := state.CSIVolumeByID(nil, ns, volID)
+			must.NoError(t, volErr)
+			must.NotNil(t, vol)
+
 			if tc.expectedErrMsg == "" {
-				require.NoError(t, err)
-				vol, err = state.CSIVolumeByID(nil, ns, volID)
-				require.NoError(t, err)
-				require.NotNil(t, vol)
-				require.Len(t, vol.ReadAllocs, 0)
+				must.NoError(t, err)
+				assert.Len(t, vol.ReadAllocs, 1)
 			} else {
-				require.Error(t, err)
-				require.True(t, strings.Contains(err.Error(), tc.expectedErrMsg),
-					"error message %q did not contain %q", err.Error(), tc.expectedErrMsg)
+				must.Error(t, err)
+				assert.Len(t, vol.ReadAllocs, 2)
+				test.True(t, strings.Contains(err.Error(), tc.expectedErrMsg),
+					test.Sprintf("error %v did not contain %q", err, tc.expectedErrMsg))
+				claim = vol.PastClaims[alloc.ID]
+				must.NotNil(t, claim)
+				test.Eq(t, tc.endState, claim.State)
 			}
+
 		})
 	}