diff --git a/nomad/csi_endpoint.go b/nomad/csi_endpoint.go index 89de26b616e0..60aeb545a614 100644 --- a/nomad/csi_endpoint.go +++ b/nomad/csi_endpoint.go @@ -231,7 +231,7 @@ func (v *CSIVolume) Get(args *structs.CSIVolumeGetRequest, reply *structs.CSIVol return v.srv.blockingRPC(&opts) } -func (v *CSIVolume) pluginValidateVolume(req *structs.CSIVolumeRegisterRequest, vol *structs.CSIVolume) (*structs.CSIPlugin, error) { +func (v *CSIVolume) pluginValidateVolume(vol *structs.CSIVolume) (*structs.CSIPlugin, error) { state := v.srv.fsm.State() plugin, err := state.CSIPluginByID(nil, vol.PluginID) @@ -242,6 +242,10 @@ func (v *CSIVolume) pluginValidateVolume(req *structs.CSIVolumeRegisterRequest, return nil, fmt.Errorf("no CSI plugin named: %s could be found", vol.PluginID) } + if plugin.ControllerRequired && plugin.ControllersHealthy < 1 { + return nil, fmt.Errorf("no healthy controllers for CSI plugin: %s", vol.PluginID) + } + vol.Provider = plugin.Provider vol.ProviderVersion = plugin.Version @@ -330,6 +334,11 @@ func (v *CSIVolume) Register(args *structs.CSIVolumeRegisterRequest, reply *stru return err } + plugin, err := v.pluginValidateVolume(vol) + if err != nil { + return err + } + // CSIVolume has many user-defined fields which are immutable // once set, and many fields that are controlled by Nomad and // are not user-settable. We merge onto a copy of the existing @@ -342,10 +351,6 @@ func (v *CSIVolume) Register(args *structs.CSIVolumeRegisterRequest, reply *stru existingVol = existingVol.Copy() // reconcile mutable fields - plugin, err := snap.CSIPluginByID(ws, existingVol.PluginID) - if err != nil { - return fmt.Errorf("unable to update volume: %s", err) - } if err = v.reconcileVolume(plugin, existingVol, vol); err != nil { return fmt.Errorf("unable to update volume: %s", err) } @@ -361,10 +366,6 @@ func (v *CSIVolume) Register(args *structs.CSIVolumeRegisterRequest, reply *stru } } - plugin, err := v.pluginValidateVolume(args, vol) - if err != nil { - return err - } if err := v.controllerValidateVolume(args, vol, plugin); err != nil { return err } @@ -1070,7 +1071,7 @@ func (v *CSIVolume) Create(args *structs.CSIVolumeCreateRequest, reply *structs. if err = vol.Validate(); err != nil { return err } - plugin, err := v.pluginValidateVolume(regArgs, vol) + plugin, err := v.pluginValidateVolume(vol) if err != nil { return err } @@ -1246,8 +1247,6 @@ func (v *CSIVolume) expandVolume(vol *structs.CSIVolume, plugin *structs.CSIPlug return nil } - // TODO: this can happen when the controller just hasn't been fully recovered yet... - // TODO: `register` and `create` do different errors during startup... if !plugin.HasControllerCapability(structs.CSIControllerSupportsExpand) { return errors.New("expand is not implemented by this controller plugin") } diff --git a/nomad/csi_endpoint_test.go b/nomad/csi_endpoint_test.go index 1b48d8ad4ae9..dce281428b33 100644 --- a/nomad/csi_endpoint_test.go +++ b/nomad/csi_endpoint_test.go @@ -127,6 +127,83 @@ func TestCSIVolumeEndpoint_Get_ACL(t *testing.T) { require.Equal(t, vols[0].ID, resp.Volume.ID) } +func TestCSIVolume_pluginValidateVolume(t *testing.T) { + // bare minimum server for this method + store := state.TestStateStore(t) + srv := &Server{ + fsm: &nomadFSM{state: store}, + } + // has our method under test + csiVolume := &CSIVolume{srv: srv} + // volume for which we will request a valid plugin + vol := &structs.CSIVolume{PluginID: "neat-plugin"} + + // plugin not found + got, err := csiVolume.pluginValidateVolume(vol) + must.Nil(t, got, must.Sprint("nonexistent plugin should be nil")) + must.ErrorContains(t, err, "no CSI plugin named") + + // we'll upsert this plugin after optionally modifying it + basePlug := &structs.CSIPlugin{ + ID: vol.PluginID, + // these should be set on the volume after success + Provider: "neat-provider", + Version: "v0", + // explicit zero values, because these modify behavior we care about + ControllerRequired: false, + ControllersHealthy: 0, + } + + cases := []struct { + name string + updatePlugin func(*structs.CSIPlugin) + expectErr string + }{ + { + name: "controller not required", + }, + { + name: "controller unhealthy", + updatePlugin: func(p *structs.CSIPlugin) { + p.ControllerRequired = true + }, + expectErr: "no healthy controllers", + }, + { + name: "controller healthy", + updatePlugin: func(p *structs.CSIPlugin) { + p.ControllerRequired = true + p.ControllersHealthy = 1 + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + vol := vol.Copy() + plug := basePlug.Copy() + + if tc.updatePlugin != nil { + tc.updatePlugin(plug) + } + must.NoError(t, store.UpsertCSIPlugin(1000, plug)) + + got, err := csiVolume.pluginValidateVolume(vol) + + if tc.expectErr == "" { + must.NoError(t, err) + must.NotNil(t, got, must.Sprint("plugin should not be nil")) + must.Eq(t, vol.Provider, plug.Provider) + must.Eq(t, vol.ProviderVersion, plug.Version) + } else { + must.Error(t, err, must.Sprint("expect error:", tc.expectErr)) + must.ErrorContains(t, err, tc.expectErr) + must.Nil(t, got, must.Sprint("plugin should be nil")) + } + }) + } +} + func TestCSIVolumeEndpoint_Register(t *testing.T) { ci.Parallel(t) srv, shutdown := TestServer(t, func(c *Config) {