diff --git a/client/allocrunner/taskrunner/plugin_supervisor_hook.go b/client/allocrunner/taskrunner/plugin_supervisor_hook.go
index 9adb8d53bf50..7cea2522ffd0 100644
--- a/client/allocrunner/taskrunner/plugin_supervisor_hook.go
+++ b/client/allocrunner/taskrunner/plugin_supervisor_hook.go
@@ -203,6 +203,16 @@ func (h *csiPluginSupervisorHook) ensureSupervisorLoop(ctx context.Context) {
 	socketPath := filepath.Join(h.mountPoint, structs.CSISocketName)
 	t := time.NewTimer(0)
 
+	// client starts out nil. supervisorLoopOnce stores the connection here,
+	// through the pointer it is given, the first time it manages to create
+	// one; registration and every later probe then reuse that connection.
+	var client csi.CSIPlugin
+	defer func() {
+		if client != nil {
+			client.Close()
+		}
+	}()
+
 	// Step 1: Wait for the plugin to initially become available.
 WAITFORREADY:
 	for {
@@ -210,9 +220,9 @@ WAITFORREADY:
 		case <-ctx.Done():
 			return
 		case <-t.C:
-			pluginHealthy, err := h.supervisorLoopOnce(ctx, socketPath)
+			pluginHealthy, err := h.supervisorLoopOnce(ctx, &client, socketPath)
 			if err != nil || !pluginHealthy {
-				h.logger.Debug("CSI Plugin not ready", "error", err)
+				h.logger.Debug("CSI plugin not ready", "error", err)
 
 				// Plugin is not yet returning healthy, because we want to optimise for
 				// quickly bringing a plugin online, we use a short timeout here.
@@ -232,9 +242,9 @@ WAITFORREADY:
 	}
 
 	// Step 2: Register the plugin with the catalog.
-	deregisterPluginFn, err := h.registerPlugin(socketPath)
+	deregisterPluginFn, err := h.registerPlugin(client, socketPath)
 	if err != nil {
-		h.logger.Error("CSI Plugin registration failed", "error", err)
+		h.logger.Error("CSI plugin registration failed", "error", err)
 		event := structs.NewTaskEvent(structs.TaskPluginUnhealthy)
 		event.SetMessage(fmt.Sprintf("failed to register plugin: %s, reason: %v", h.task.CSIPluginConfig.ID, err))
 		h.eventEmitter.EmitEvent(event)
@@ -249,9 +259,9 @@ WAITFORREADY:
 			deregisterPluginFn()
 			return
 		case <-t.C:
-			pluginHealthy, err := h.supervisorLoopOnce(ctx, socketPath)
+			pluginHealthy, err := h.supervisorLoopOnce(ctx, &client, socketPath)
 			if err != nil {
-				h.logger.Error("CSI Plugin fingerprinting failed", "error", err)
+				h.logger.Error("CSI plugin fingerprinting failed", "error", err)
 			}
 
 			// The plugin has transitioned to a healthy state. Emit an event.
@@ -281,16 +291,11 @@ WAITFORREADY:
 	}
 }
 
-func (h *csiPluginSupervisorHook) registerPlugin(socketPath string) (func(), error) {
-
+// registerPlugin queries the plugin through the caller's client, which it
+// neither creates nor closes, so the caller must pass a non-nil client.
+func (h *csiPluginSupervisorHook) registerPlugin(client csi.CSIPlugin, socketPath string) (func(), error) {
 	// At this point we know the plugin is ready and we can fingerprint it
 	// to get its vendor name and version
-	client, err := csi.NewClient(socketPath, h.logger.Named("csi_client").With("plugin.name", h.task.CSIPluginConfig.ID, "plugin.type", h.task.CSIPluginConfig.Type))
-	if err != nil {
-		return nil, fmt.Errorf("failed to create csi client: %v", err)
-	}
-	defer client.Close()
-
 	info, err := client.PluginInfo()
 	if err != nil {
 		return nil, fmt.Errorf("failed to probe plugin: %v", err)
@@ -354,17 +359,22 @@ func (h *csiPluginSupervisorHook) registerPlugin(socketPath string) (func(), err
 	}, nil
 }
 
-func (h *csiPluginSupervisorHook) supervisorLoopOnce(ctx context.Context, socketPath string) (bool, error) {
+// supervisorLoopOnce checks that the plugin socket exists and probes the
+// plugin once. If *client is nil it creates a client and stores it in
+// *client, so the caller can reuse it and close it exactly once.
+func (h *csiPluginSupervisorHook) supervisorLoopOnce(ctx context.Context, client *csi.CSIPlugin, socketPath string) (bool, error) {
 	_, err := os.Stat(socketPath)
 	if err != nil {
 		return false, fmt.Errorf("failed to stat socket: %v", err)
 	}
 
-	client, err := csi.NewClient(socketPath, h.logger.Named("csi_client").With("plugin.name", h.task.CSIPluginConfig.ID, "plugin.type", h.task.CSIPluginConfig.Type))
-	if err != nil {
-		return false, fmt.Errorf("failed to create csi client: %v", err)
-	}
-	defer client.Close()
+	if *client == nil {
+		c, err := h.newClient(socketPath)
+		if err != nil {
+			return false, fmt.Errorf("failed to create csi client: %v", err)
+		}
+		*client = c
+	}
 
-	healthy, err := client.PluginProbe(ctx)
+	healthy, err := (*client).PluginProbe(ctx)
 	if err != nil {
@@ -374,6 +384,14 @@ func (h *csiPluginSupervisorHook) supervisorLoopOnce(ctx context.Context, socket
 	return healthy, nil
 }
 
+// newClient creates a CSI client for socketPath whose logger is tagged with
+// the plugin's ID and type. The caller is responsible for closing it.
+func (h *csiPluginSupervisorHook) newClient(socketPath string) (csi.CSIPlugin, error) {
+	return csi.NewClient(socketPath, h.logger.Named("csi_client").With(
+		"plugin.name", h.task.CSIPluginConfig.ID,
+		"plugin.type", h.task.CSIPluginConfig.Type))
+}
+
 // Stop is called after the task has exited and will not be started
 // again. It is the only hook guaranteed to be executed whenever
 // TaskRunner.Run is called (and not gracefully shutting down).