diff --git a/.changelog/17996.txt b/.changelog/17996.txt
new file mode 100644
index 000000000000..f95501a253dc
--- /dev/null
+++ b/.changelog/17996.txt
@@ -0,0 +1,3 @@
+```release-note:bug
+csi: Fixed a bug in sending concurrent requests to CSI controller plugins by serializing them per plugin
+```
diff --git a/nomad/client_csi_endpoint.go b/nomad/client_csi_endpoint.go
index f0091435a8fc..773efa211717 100644
--- a/nomad/client_csi_endpoint.go
+++ b/nomad/client_csi_endpoint.go
@@ -5,7 +5,7 @@ package nomad
 
 import (
 	"fmt"
-	"math/rand"
+	"sort"
 	"strings"
 	"time"
 
@@ -262,9 +262,9 @@ func (a *ClientCSI) clientIDsForController(pluginID string) ([]string, error) {
 
 	ws := memdb.NewWatchSet()
 
-	// note: plugin IDs are not scoped to region/DC but volumes are.
-	// so any node we get for a controller is already in the same
-	// region/DC for the volume.
+	// note: plugin IDs are not scoped to region but volumes are. so any Nomad
+	// client we get for a controller is already in the same region for the
+	// volume.
 	plugin, err := snap.CSIPluginByID(ws, pluginID)
 	if err != nil {
 		return nil, fmt.Errorf("error getting plugin: %s, %v", pluginID, err)
@@ -273,13 +273,10 @@ func (a *ClientCSI) clientIDsForController(pluginID string) ([]string, error) {
 		return nil, fmt.Errorf("plugin missing: %s", pluginID)
 	}
 
-	// iterating maps is "random" but unspecified and isn't particularly
-	// random with small maps, so not well-suited for load balancing.
-	// so we shuffle the keys and iterate over them.
 	clientIDs := []string{}
 
 	for clientID, controller := range plugin.Controllers {
-		if !controller.IsController() {
+		if !controller.IsController() || !controller.Healthy {
 			// we don't have separate types for CSIInfo depending on
 			// whether it's a controller or node. this error shouldn't
 			// make it to production but is to aid developers during
@@ -295,9 +292,11 @@
 		return nil, fmt.Errorf("failed to find clients running controller plugin %q", pluginID)
 	}
 
-	rand.Shuffle(len(clientIDs), func(i, j int) {
-		clientIDs[i], clientIDs[j] = clientIDs[j], clientIDs[i]
-	})
+	// Many plugins don't handle concurrent requests as described in the spec,
+	// and have undocumented expectations of using k8s-specific sidecars to
+	// leader elect. Sort the client IDs so that we prefer sending requests to
+	// the same controller to hack around this.
+	sort.Strings(clientIDs)
 
 	return clientIDs, nil
 }
diff --git a/nomad/csi_endpoint.go b/nomad/csi_endpoint.go
index 0614b87b650d..54be55d7366b 100644
--- a/nomad/csi_endpoint.go
+++ b/nomad/csi_endpoint.go
@@ -4,6 +4,7 @@ package nomad
 
 import (
+	"context"
 	"fmt"
 	"net/http"
 	"strings"
 
@@ -549,7 +550,9 @@ func (v *CSIVolume) controllerPublishVolume(req *structs.CSIVolumeClaimRequest,
 	cReq.PluginID = plug.ID
 	cResp := &cstructs.ClientCSIControllerAttachVolumeResponse{}
-	err = v.srv.RPC(method, cReq, cResp)
+	err = v.serializedControllerRPC(plug.ID, func() error {
+		return v.srv.RPC(method, cReq, cResp)
+	})
 	if err != nil {
 		if strings.Contains(err.Error(), "FailedPrecondition") {
 			return fmt.Errorf("%v: %v", structs.ErrCSIClientRPCRetryable, err)
 		}
@@ -586,6 +589,57 @@ func (v *CSIVolume) volAndPluginLookup(namespace, volID string) (*structs.CSIPlu
 	return plug, vol, nil
 }
 
+// serializedControllerRPC ensures we're only sending a single controller RPC to
+// a given plugin if the RPC can cause conflicting state changes.
+//
+// The CSI specification says that we SHOULD send no more than one in-flight
+// request per *volume* at a time, with an allowance for losing state
+// (ex. leadership transitions) which the plugins SHOULD handle gracefully.
+//
+// In practice many CSI plugins rely on k8s-specific sidecars for serializing
+// storage provider API calls globally (ex. concurrently attaching EBS volumes
+// to an EC2 instance results in a race for device names). So we have to be much
+// more conservative about concurrency in Nomad than the spec allows.
+func (v *CSIVolume) serializedControllerRPC(pluginID string, fn func() error) error {
+
+	for {
+		v.srv.volumeControllerLock.Lock()
+		future := v.srv.volumeControllerFutures[pluginID]
+		if future == nil {
+			future, futureDone := context.WithCancel(v.srv.shutdownCtx)
+			v.srv.volumeControllerFutures[pluginID] = future
+			v.srv.volumeControllerLock.Unlock()
+
+			err := fn()
+
+			// close the future while holding the lock and not in a defer so
+			// that we can ensure we've cleared it from the map before allowing
+			// anyone else to take the lock and write a new one
+			v.srv.volumeControllerLock.Lock()
+			futureDone()
+			delete(v.srv.volumeControllerFutures, pluginID)
+			v.srv.volumeControllerLock.Unlock()
+
+			return err
+		} else {
+			v.srv.volumeControllerLock.Unlock()
+
+			select {
+			case <-future.Done():
+				continue
+			case <-v.srv.shutdownCh:
+				// The csi_hook publish workflow on the client will retry if it
+				// gets this error. On unpublish, we don't want to block client
+				// shutdown so we give up on error. The new leader's
+				// volumewatcher will iterate all the claims at startup to
+				// detect this and mop up any claims in the NodeDetached state
+				// (volume GC will run periodically as well)
+				return structs.ErrNoLeader
+			}
+		}
+	}
+}
+
 // allowCSIMount is called on Job register to check mount permission
 func allowCSIMount(aclObj *acl.ACL, namespace string) bool {
 	return aclObj.AllowPluginRead() &&
@@ -863,8 +917,11 @@ func (v *CSIVolume) controllerUnpublishVolume(vol *structs.CSIVolume, claim *str
 		Secrets: vol.Secrets,
 	}
 	req.PluginID = vol.PluginID
-	err = v.srv.RPC("ClientCSI.ControllerDetachVolume", req,
-		&cstructs.ClientCSIControllerDetachVolumeResponse{})
+
+	err = v.serializedControllerRPC(vol.PluginID, func() error {
+		return v.srv.RPC("ClientCSI.ControllerDetachVolume", req,
+			&cstructs.ClientCSIControllerDetachVolumeResponse{})
+	})
 	if err != nil {
 		return fmt.Errorf("could not detach from controller: %v", err)
 	}
@@ -1139,7 +1196,9 @@ func (v *CSIVolume) deleteVolume(vol *structs.CSIVolume, plugin *structs.CSIPlug
 	cReq.PluginID = plugin.ID
 	cResp := &cstructs.ClientCSIControllerDeleteVolumeResponse{}
 
-	return v.srv.RPC(method, cReq, cResp)
+	return v.serializedControllerRPC(plugin.ID, func() error {
+		return v.srv.RPC(method, cReq, cResp)
+	})
 }
 
 func (v *CSIVolume) ListExternal(args *structs.CSIVolumeExternalListRequest, reply *structs.CSIVolumeExternalListResponse) error {
@@ -1286,7 +1345,9 @@ func (v *CSIVolume) CreateSnapshot(args *structs.CSISnapshotCreateRequest, reply
 		}
 		cReq.PluginID = pluginID
 		cResp := &cstructs.ClientCSIControllerCreateSnapshotResponse{}
-		err = v.srv.RPC(method, cReq, cResp)
+		err = v.serializedControllerRPC(pluginID, func() error {
+			return v.srv.RPC(method, cReq, cResp)
+		})
 		if err != nil {
 			multierror.Append(&mErr, fmt.Errorf("could not create snapshot: %v", err))
 			continue
@@ -1360,7 +1421,9 @@ func (v *CSIVolume) DeleteSnapshot(args *structs.CSISnapshotDeleteRequest, reply
 		cReq := &cstructs.ClientCSIControllerDeleteSnapshotRequest{ID: snap.ID}
 		cReq.PluginID = plugin.ID
 		cResp := &cstructs.ClientCSIControllerDeleteSnapshotResponse{}
-		err = v.srv.RPC(method, cReq, cResp)
+		err = v.serializedControllerRPC(plugin.ID, func() error {
+			return v.srv.RPC(method, cReq, cResp)
+		})
 		if err != nil {
 			multierror.Append(&mErr, fmt.Errorf("could not delete %q: %v", snap.ID, err))
 		}
diff --git a/nomad/csi_endpoint_test.go b/nomad/csi_endpoint_test.go
index f2482b608d6e..08cbbad412c6 100644
--- a/nomad/csi_endpoint_test.go
+++ b/nomad/csi_endpoint_test.go
@@ -6,6 +6,7 @@ package nomad
 import (
 	"fmt"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
@@ -21,6 +22,7 @@ import (
 	cconfig "github.com/hashicorp/nomad/client/config"
 	cstructs "github.com/hashicorp/nomad/client/structs"
 	"github.com/hashicorp/nomad/helper/uuid"
+	"github.com/hashicorp/nomad/lib/lang"
 	"github.com/hashicorp/nomad/nomad/mock"
 	"github.com/hashicorp/nomad/nomad/state"
 	"github.com/hashicorp/nomad/nomad/structs"
@@ -1971,3 +1973,49 @@ func TestCSI_RPCVolumeAndPluginLookup(t *testing.T) {
 	require.Nil(t, vol)
 	require.EqualError(t, err, fmt.Sprintf("volume not found: %s", id2))
 }
+
+func TestCSI_SerializedControllerRPC(t *testing.T) {
+	ci.Parallel(t)
+
+	srv, shutdown := TestServer(t, func(c *Config) { c.NumSchedulers = 0 })
+	defer shutdown()
+	testutil.WaitForLeader(t, srv.RPC)
+
+	var wg sync.WaitGroup
+	wg.Add(3)
+
+	timeCh := make(chan lang.Pair[string, time.Duration])
+
+	testFn := func(pluginID string, dur time.Duration) {
+		defer wg.Done()
+		c := NewCSIVolumeEndpoint(srv, nil)
+		now := time.Now()
+		err := c.serializedControllerRPC(pluginID, func() error {
+			time.Sleep(dur)
+			return nil
+		})
+		elapsed := time.Since(now)
+		timeCh <- lang.Pair[string, time.Duration]{pluginID, elapsed}
+		must.NoError(t, err)
+	}
+
+	go testFn("plugin1", 50*time.Millisecond)
+	go testFn("plugin2", 50*time.Millisecond)
+	go testFn("plugin1", 50*time.Millisecond)
+
+	totals := map[string]time.Duration{}
+	for i := 0; i < 3; i++ {
+		pair := <-timeCh
+		totals[pair.First] += pair.Second
+	}
+
+	wg.Wait()
+
+	// plugin1 RPCs should block each other
+	must.GreaterEq(t, 150*time.Millisecond, totals["plugin1"])
+	must.Less(t, 200*time.Millisecond, totals["plugin1"])
+
+	// plugin1 RPCs should not block plugin2 RPCs
+	must.GreaterEq(t, 50*time.Millisecond, totals["plugin2"])
+	must.Less(t, 100*time.Millisecond, totals["plugin2"])
+}
diff --git a/nomad/server.go b/nomad/server.go
index 371639d06436..67ded0aa343b 100644
--- a/nomad/server.go
+++ b/nomad/server.go
@@ -218,6 +218,13 @@ type Server struct {
 	// volumeWatcher is used to release volume claims
 	volumeWatcher *volumewatcher.Watcher
 
+	// volumeControllerFutures is a map of plugin IDs to pending controller RPCs.
+	// If no RPC is pending for a given plugin, there is no entry in the map.
+	volumeControllerFutures map[string]context.Context
+
+	// volumeControllerLock synchronizes access to the volumeControllerFutures map
+	volumeControllerLock sync.Mutex
+
 	// keyringReplicator is used to replicate root encryption keys from the
 	// leader
 	keyringReplicator *KeyringReplicator
@@ -445,6 +452,7 @@ func NewServer(config *Config, consulCatalog consul.CatalogAPI, consulConfigEntr
 		s.logger.Error("failed to create volume watcher", "error", err)
 		return nil, fmt.Errorf("failed to create volume watcher: %v", err)
 	}
+	s.volumeControllerFutures = map[string]context.Context{}
 
 	// Start the eval broker notification system so any subscribers can get
 	// updates when the processes SetEnabled is triggered.
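For readers who want to study the serialization pattern this patch introduces outside of the Nomad server, below is a minimal, self-contained sketch of the same idea. Everything in it is illustrative and assumed rather than part of the patch: `rpcSerializer`, its `do` method, and the `futures` map are hypothetical stand-ins for the `volumeControllerFutures`/`volumeControllerLock` server fields and `serializedControllerRPC` above, and the sketch does not reproduce the `structs.ErrNoLeader` shutdown path (it returns the context's error instead).

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// rpcSerializer is an illustrative stand-in for the server state added in this
// patch: a mutex plus a map of plugin ID -> context acting as a "future" for
// the in-flight controller RPC to that plugin.
type rpcSerializer struct {
	shutdownCtx context.Context
	lock        sync.Mutex
	futures     map[string]context.Context
}

// do runs fn, allowing at most one fn in flight per pluginID; calls for
// different plugin IDs proceed concurrently.
func (s *rpcSerializer) do(pluginID string, fn func() error) error {
	for {
		s.lock.Lock()
		future := s.futures[pluginID]
		if future == nil {
			// no RPC in flight for this plugin: claim the slot and run fn
			future, futureDone := context.WithCancel(s.shutdownCtx)
			s.futures[pluginID] = future
			s.lock.Unlock()

			err := fn()

			// clear the map entry before cancelling the future, so a waiter
			// that wakes up never sees our stale entry
			s.lock.Lock()
			delete(s.futures, pluginID)
			s.lock.Unlock()
			futureDone()
			return err
		}
		s.lock.Unlock()

		select {
		case <-future.Done():
			continue // the in-flight RPC finished; loop and try to claim the slot
		case <-s.shutdownCtx.Done():
			return s.shutdownCtx.Err()
		}
	}
}

func main() {
	s := &rpcSerializer{
		shutdownCtx: context.Background(), // stands in for the server's shutdown context
		futures:     map[string]context.Context{},
	}

	start := time.Now()
	var wg sync.WaitGroup
	for _, pluginID := range []string{"plugin1", "plugin1", "plugin2"} {
		pluginID := pluginID
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = s.do(pluginID, func() error {
				time.Sleep(50 * time.Millisecond) // stand-in for a controller RPC
				return nil
			})
			fmt.Printf("%s finished after %v\n", pluginID, time.Since(start).Round(10*time.Millisecond))
		}()
	}
	wg.Wait()
}
```

Run directly, the two `plugin1` calls should finish after roughly 50ms and 100ms while the `plugin2` call finishes after roughly 50ms, mirroring what `TestCSI_SerializedControllerRPC` asserts. The essential ordering matches `serializedControllerRPC`: the per-plugin map entry is gone by the time any waiter can observe the future as done, so a goroutine waking from `future.Done()` loops back, re-reads the map under the lock, and either claims the slot or waits on a newer in-flight RPC. The sketch gets that by deleting the entry before cancelling; the patched code gets it by holding `volumeControllerLock` across both `futureDone()` and the `delete`.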