diff --git a/nomad/client_csi_endpoint.go b/nomad/client_csi_endpoint.go
index 8dfe85e5503e..92cf6589c9d9 100644
--- a/nomad/client_csi_endpoint.go
+++ b/nomad/client_csi_endpoint.go
@@ -3,6 +3,7 @@ package nomad
 import (
 	"fmt"
 	"math/rand"
+	"strings"
 	"time"
 
 	metrics "github.com/armon/go-metrics"
@@ -20,74 +21,101 @@ type ClientCSI struct {
 
 func (a *ClientCSI) ControllerAttachVolume(args *cstructs.ClientCSIControllerAttachVolumeRequest, reply *cstructs.ClientCSIControllerAttachVolumeResponse) error {
 	defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "attach_volume"}, time.Now())
-	// Get a Nomad client node for the controller
-	nodeID, err := a.nodeForController(args.PluginID, args.ControllerNodeID)
+
+	clientIDs, err := a.clientIDsForController(args.PluginID)
 	if err != nil {
-		return err
+		return fmt.Errorf("controller attach volume: %v", err)
 	}
-	args.ControllerNodeID = nodeID
 
-	// Get the connection to the client
-	state, ok := a.srv.getNodeConn(args.ControllerNodeID)
-	if !ok {
-		return findNodeConnAndForward(a.srv, args.ControllerNodeID, "ClientCSI.ControllerAttachVolume", args, reply)
-	}
+	for _, clientID := range clientIDs {
+		args.ControllerNodeID = clientID
+		state, ok := a.srv.getNodeConn(clientID)
+		if !ok {
+			return findNodeConnAndForward(a.srv,
+				clientID, "ClientCSI.ControllerAttachVolume", args, reply)
+		}
 
-	// Make the RPC
-	err = NodeRpc(state.Session, "CSI.ControllerAttachVolume", args, reply)
-	if err != nil {
+		err = NodeRpc(state.Session, "CSI.ControllerAttachVolume", args, reply)
+		if err == nil {
+			return nil
+		}
+		if a.isRetryable(err, clientID, args.PluginID) {
+			a.logger.Debug("failed to reach controller on client", "client", clientID, "error", err)
+			continue
+		}
 		return fmt.Errorf("controller attach volume: %v", err)
 	}
-	return nil
+	return fmt.Errorf("controller attach volume: %v", err)
 }
 
 func (a *ClientCSI) ControllerValidateVolume(args *cstructs.ClientCSIControllerValidateVolumeRequest, reply *cstructs.ClientCSIControllerValidateVolumeResponse) error {
 	defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "validate_volume"}, time.Now())
-	// Get a Nomad client node for the controller
-	nodeID, err := a.nodeForController(args.PluginID, args.ControllerNodeID)
+	clientIDs, err := a.clientIDsForController(args.PluginID)
 	if err != nil {
-		return err
+		return fmt.Errorf("validate volume: %v", err)
 	}
-	args.ControllerNodeID = nodeID
 
-	// Get the connection to the client
-	state, ok := a.srv.getNodeConn(args.ControllerNodeID)
-	if !ok {
-		return findNodeConnAndForward(a.srv, args.ControllerNodeID, "ClientCSI.ControllerValidateVolume", args, reply)
-	}
+	for _, clientID := range clientIDs {
+		args.ControllerNodeID = clientID
+		state, ok := a.srv.getNodeConn(clientID)
+		if !ok {
+			return findNodeConnAndForward(a.srv,
+				clientID, "ClientCSI.ControllerValidateVolume", args, reply)
+		}
 
-	// Make the RPC
-	err = NodeRpc(state.Session, "CSI.ControllerValidateVolume", args, reply)
-	if err != nil {
+		err = NodeRpc(state.Session, "CSI.ControllerValidateVolume", args, reply)
+		if err == nil {
+			return nil
+		}
+		if a.isRetryable(err, clientID, args.PluginID) {
+			a.logger.Debug("failed to reach controller on client", "client", clientID, "error", err)
+			continue
+		}
 		return fmt.Errorf("validate volume: %v", err)
 	}
-	return nil
+	return fmt.Errorf("validate volume: %v", err)
 }
 
 func (a *ClientCSI) ControllerDetachVolume(args *cstructs.ClientCSIControllerDetachVolumeRequest, reply *cstructs.ClientCSIControllerDetachVolumeResponse) error {
 	defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "detach_volume"}, time.Now())
-	// Get a Nomad client node for the controller
-	nodeID, err := a.nodeForController(args.PluginID, args.ControllerNodeID)
+	clientIDs, err := a.clientIDsForController(args.PluginID)
 	if err != nil {
-		return err
+		return fmt.Errorf("controller detach volume: %v", err)
 	}
-	args.ControllerNodeID = nodeID
 
-	// Get the connection to the client
-	state, ok := a.srv.getNodeConn(args.ControllerNodeID)
-	if !ok {
-		return findNodeConnAndForward(a.srv, args.ControllerNodeID, "ClientCSI.ControllerDetachVolume", args, reply)
-	}
+	for _, clientID := range clientIDs {
+		args.ControllerNodeID = clientID
+		state, ok := a.srv.getNodeConn(clientID)
+		if !ok {
+			return findNodeConnAndForward(a.srv,
+				clientID, "ClientCSI.ControllerDetachVolume", args, reply)
+		}
 
-	// Make the RPC
-	err = NodeRpc(state.Session, "CSI.ControllerDetachVolume", args, reply)
-	if err != nil {
+		err = NodeRpc(state.Session, "CSI.ControllerDetachVolume", args, reply)
+		if err == nil {
+			return nil
+		}
+		if a.isRetryable(err, clientID, args.PluginID) {
+			a.logger.Debug("failed to reach controller on client", "client", clientID, "error", err)
+			continue
+		}
 		return fmt.Errorf("controller detach volume: %v", err)
 	}
-	return nil
+	return fmt.Errorf("controller detach volume: %v", err)
+}
+
+// we can retry the same RPC on a different controller in the cases where the
+// client has stopped and been GC'd, or where the controller has stopped but
+// we don't have the fingerprint update yet
+func (a *ClientCSI) isRetryable(err error, clientID, pluginID string) bool {
+	// TODO(tgross): it would be nicer to use errors.Is here but we
+	// need to make sure we're using error wrapping to make that work
+	errMsg := err.Error()
+	return strings.Contains(errMsg, fmt.Sprintf("Unknown node: %s", clientID)) ||
+		strings.Contains(errMsg, "no plugins registered for type: csi-controller") ||
+		strings.Contains(errMsg, fmt.Sprintf("plugin %s for type controller not found", pluginID))
 }
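A note on the TODO in isRetryable: if the node RPC layer wrapped errors with `%w` instead of flattening them with `%v`, the string matching could become sentinel-error checks via `errors.Is`. A minimal sketch of that direction, with hypothetical sentinel names that are not part of this change:

```go
package nomad

import "errors"

// Hypothetical sentinels; today the RPC layer returns flattened strings,
// which is why isRetryable falls back to strings.Contains.
var (
	errUnknownNode  = errors.New("Unknown node")
	errNoCSIPlugins = errors.New("no plugins registered for type: csi-controller")
)

// isRetryableSketch matches a sentinel anywhere in the wrapped chain, but
// this only works if producers return these sentinels and intermediate
// callers wrap with fmt.Errorf("...: %w", err) rather than %v.
func isRetryableSketch(err error) bool {
	return errors.Is(err, errUnknownNode) || errors.Is(err, errNoCSIPlugins)
}
```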
 
 func (a *ClientCSI) NodeDetachVolume(args *cstructs.ClientCSINodeDetachVolumeRequest, reply *cstructs.ClientCSINodeDetachVolumeResponse) error {
@@ -119,29 +147,17 @@ func (a *ClientCSI) NodeDetachVolume(args *cstructs.ClientCSINodeDetachVolumeReq
 
 }
 
-// nodeForController validates that the Nomad client node ID for
-// a plugin exists and is new enough to support client RPC. If no node
-// ID is passed, select a random node ID for the controller to load-balance
-// long blocking RPCs across client nodes.
-func (a *ClientCSI) nodeForController(pluginID, nodeID string) (string, error) {
+// clientIDsForController returns a shuffled list of client IDs where the
+// controller plugin is expected to be running.
+func (a *ClientCSI) clientIDsForController(pluginID string) ([]string, error) {
 	snap, err := a.srv.State().Snapshot()
 	if err != nil {
-		return "", err
-	}
-
-	if nodeID != "" {
-		_, err = getNodeForRpc(snap, nodeID)
-		if err == nil {
-			return nodeID, nil
-		} else {
-			// we'll fall-through and select a node at random
-			a.logger.Trace("could not be used for client RPC", "node", nodeID, "error", err)
-		}
+		return nil, err
 	}
 
 	if pluginID == "" {
-		return "", fmt.Errorf("missing plugin ID")
+		return nil, fmt.Errorf("missing plugin ID")
 	}
 
 	ws := memdb.NewWatchSet()
@@ -151,43 +167,37 @@ func (a *ClientCSI) nodeForController(pluginID, nodeID string) (string, error) {
 	// region/DC for the volume.
 	plugin, err := snap.CSIPluginByID(ws, pluginID)
 	if err != nil {
-		return "", fmt.Errorf("error getting plugin: %s, %v", pluginID, err)
+		return nil, fmt.Errorf("error getting plugin: %s, %v", pluginID, err)
 	}
 	if plugin == nil {
-		return "", fmt.Errorf("plugin missing: %s %v", pluginID, err)
-	}
-	count := len(plugin.Controllers)
-	if count == 0 {
-		return "", fmt.Errorf("no controllers available for plugin %q", plugin.ID)
+		return nil, fmt.Errorf("plugin missing: %s", pluginID)
 	}
 
 	// iterating maps is "random" but unspecified and isn't particularly
 	// random with small maps, so not well-suited for load balancing.
 	// so we shuffle the keys and iterate over them.
-	clientIDs := make([]string, 0, count)
-	for clientID := range plugin.Controllers {
-		clientIDs = append(clientIDs, clientID)
-	}
-	rand.Shuffle(count, func(i, j int) {
-		clientIDs[i], clientIDs[j] = clientIDs[j], clientIDs[i]
-	})
+	clientIDs := []string{}
 
-	for _, clientID := range clientIDs {
-		controller := plugin.Controllers[clientID]
+	for clientID, controller := range plugin.Controllers {
 		if !controller.IsController() {
 			// we don't have separate types for CSIInfo depending on
 			// whether it's a controller or node. this error shouldn't
 			// make it to production but is to aid developers during
 			// development
-			err = fmt.Errorf("plugin is not a controller")
 			continue
 		}
 		_, err = getNodeForRpc(snap, clientID)
-		if err != nil {
-			continue
+		if err == nil {
+			clientIDs = append(clientIDs, clientID)
 		}
-		return clientID, nil
 	}
+	if len(clientIDs) == 0 {
+		return nil, fmt.Errorf("failed to find clients running controller plugin %q", pluginID)
+	}
+
+	rand.Shuffle(len(clientIDs), func(i, j int) {
+		clientIDs[i], clientIDs[j] = clientIDs[j], clientIDs[i]
+	})
 
-	return "", err
+	return clientIDs, nil
 }
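The comment in clientIDsForController is worth illustrating: Go's map iteration order is unspecified, and for small maps it is not uniform enough for load balancing, so the keys are collected and shuffled explicitly. A self-contained sketch of the same pattern:

```go
package main

import (
	"fmt"
	"math/rand"
)

// shuffledKeys collects a map's keys and randomizes their order with
// rand.Shuffle, a uniformity that bare map iteration does not guarantee.
func shuffledKeys(m map[string]struct{}) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	rand.Shuffle(len(keys), func(i, j int) {
		keys[i], keys[j] = keys[j], keys[i]
	})
	return keys
}

func main() {
	controllers := map[string]struct{}{"node-a": {}, "node-b": {}, "node-c": {}}
	fmt.Println(shuffledKeys(controllers)) // e.g. [node-c node-a node-b]
}
```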
diff --git a/nomad/client_csi_endpoint_test.go b/nomad/client_csi_endpoint_test.go
index a3345613e043..fbc6e45a778a 100644
--- a/nomad/client_csi_endpoint_test.go
+++ b/nomad/client_csi_endpoint_test.go
@@ -2,7 +2,9 @@ package nomad
 
 import (
 	"fmt"
+	"net/rpc"
 	"testing"
+	"time"
 
 	memdb "github.com/hashicorp/go-memdb"
 	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
@@ -19,137 +21,99 @@ import (
 
 func TestClientCSIController_AttachVolume_Local(t *testing.T) {
 	t.Parallel()
 	require := require.New(t)
-
-	// Start a server and client
-	s, cleanupS := TestServer(t, nil)
-	defer cleanupS()
-	codec := rpcClient(t, s)
-	testutil.WaitForLeader(t, s.RPC)
-
-	c, cleanupC := client.TestClient(t, func(c *config.Config) {
-		c.Servers = []string{s.config.RPCAddr.String()}
-	})
-	defer cleanupC()
-
-	waitForNodes(t, s, 1)
+	codec, cleanup := setupLocal(t)
+	defer cleanup()
 
 	req := &cstructs.ClientCSIControllerAttachVolumeRequest{
-		CSIControllerQuery: cstructs.CSIControllerQuery{ControllerNodeID: c.NodeID()},
+		CSIControllerQuery: cstructs.CSIControllerQuery{PluginID: "minnie"},
 	}
 
-	// Fetch the response
 	var resp structs.GenericResponse
 	err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerAttachVolume", req, &resp)
 	require.NotNil(err)
-	// Should receive an error from the client endpoint
-	require.Contains(err.Error(), "must specify plugin name to dispense")
+	require.Contains(err.Error(), "no plugins registered for type")
 }
 
 func TestClientCSIController_AttachVolume_Forwarded(t *testing.T) {
 	t.Parallel()
 	require := require.New(t)
-
-	// Start a server and client
-	s1, cleanupS1 := TestServer(t, func(c *Config) { c.BootstrapExpect = 2 })
-	defer cleanupS1()
-	s2, cleanupS2 := TestServer(t, func(c *Config) { c.BootstrapExpect = 2 })
-	defer cleanupS2()
-	TestJoin(t, s1, s2)
-	testutil.WaitForLeader(t, s1.RPC)
-	testutil.WaitForLeader(t, s2.RPC)
-	codec := rpcClient(t, s2)
-
-	c, cleanupC := client.TestClient(t, func(c *config.Config) {
-		c.Servers = []string{s2.config.RPCAddr.String()}
-		c.GCDiskUsageThreshold = 100.0
-	})
-	defer cleanupC()
-
-	waitForNodes(t, s2, 1)
-
-	// Force remove the connection locally in case it exists
-	s1.nodeConnsLock.Lock()
-	delete(s1.nodeConns, c.NodeID())
-	s1.nodeConnsLock.Unlock()
+	codec, cleanup := setupForward(t)
+	defer cleanup()
 
 	req := &cstructs.ClientCSIControllerAttachVolumeRequest{
-		CSIControllerQuery: cstructs.CSIControllerQuery{ControllerNodeID: c.NodeID()},
+		CSIControllerQuery: cstructs.CSIControllerQuery{PluginID: "minnie"},
 	}
 
-	// Fetch the response
 	var resp structs.GenericResponse
 	err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerAttachVolume", req, &resp)
 	require.NotNil(err)
-	// Should receive an error from the client endpoint
-	require.Contains(err.Error(), "must specify plugin name to dispense")
+	require.Contains(err.Error(), "no plugins registered for type")
 }
 
 func TestClientCSIController_DetachVolume_Local(t *testing.T) {
 	t.Parallel()
 	require := require.New(t)
+	codec, cleanup := setupLocal(t)
+	defer cleanup()
 
-	// Start a server and client
-	s, cleanupS := TestServer(t, nil)
-	defer cleanupS()
-	codec := rpcClient(t, s)
-	testutil.WaitForLeader(t, s.RPC)
+	req := &cstructs.ClientCSIControllerDetachVolumeRequest{
+		CSIControllerQuery: cstructs.CSIControllerQuery{PluginID: "minnie"},
+	}
 
-	c, cleanupC := client.TestClient(t, func(c *config.Config) {
-		c.Servers = []string{s.config.RPCAddr.String()}
-	})
-	defer cleanupC()
+	var resp structs.GenericResponse
+	err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerDetachVolume", req, &resp)
+	require.NotNil(err)
+	require.Contains(err.Error(), "no plugins registered for type")
+}
 
-	waitForNodes(t, s, 1)
+func TestClientCSIController_DetachVolume_Forwarded(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+	codec, cleanup := setupForward(t)
+	defer cleanup()
 
 	req := &cstructs.ClientCSIControllerDetachVolumeRequest{
-		CSIControllerQuery: cstructs.CSIControllerQuery{ControllerNodeID: c.NodeID()},
+		CSIControllerQuery: cstructs.CSIControllerQuery{PluginID: "minnie"},
 	}
 
-	// Fetch the response
 	var resp structs.GenericResponse
 	err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerDetachVolume", req, &resp)
 	require.NotNil(err)
-	// Should receive an error from the client endpoint
-	require.Contains(err.Error(), "must specify plugin name to dispense")
+	require.Contains(err.Error(), "no plugins registered for type")
 }
 
-func TestClientCSIController_DetachVolume_Forwarded(t *testing.T) {
+func TestClientCSIController_ValidateVolume_Local(t *testing.T) {
 	t.Parallel()
 	require := require.New(t)
+	codec, cleanup := setupLocal(t)
+	defer cleanup()
 
-	// Start a server and client
-	s1, cleanupS1 := TestServer(t, func(c *Config) { c.BootstrapExpect = 2 })
-	defer cleanupS1()
-	s2, cleanupS2 := TestServer(t, func(c *Config) { c.BootstrapExpect = 2 })
-	defer cleanupS2()
-	TestJoin(t, s1, s2)
-	testutil.WaitForLeader(t, s1.RPC)
-	testutil.WaitForLeader(t, s2.RPC)
-	codec := rpcClient(t, s2)
-
-	c, cleanupC := client.TestClient(t, func(c *config.Config) {
-		c.Servers = []string{s2.config.RPCAddr.String()}
-		c.GCDiskUsageThreshold = 100.0
-	})
-	defer cleanupC()
+	req := &cstructs.ClientCSIControllerValidateVolumeRequest{
+		VolumeID:           "test",
+		CSIControllerQuery: cstructs.CSIControllerQuery{PluginID: "minnie"},
+	}
 
-	waitForNodes(t, s2, 1)
+	var resp structs.GenericResponse
+	err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerValidateVolume", req, &resp)
+	require.NotNil(err)
+	require.Contains(err.Error(), "no plugins registered for type")
+}
 
-	// Force remove the connection locally in case it exists
-	s1.nodeConnsLock.Lock()
-	delete(s1.nodeConns, c.NodeID())
-	s1.nodeConnsLock.Unlock()
+func TestClientCSIController_ValidateVolume_Forwarded(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+	codec, cleanup := setupForward(t)
+	defer cleanup()
 
-	req := &cstructs.ClientCSIControllerDetachVolumeRequest{
-		CSIControllerQuery: cstructs.CSIControllerQuery{ControllerNodeID: c.NodeID()},
+	req := &cstructs.ClientCSIControllerValidateVolumeRequest{
+		VolumeID:           "test",
+		CSIControllerQuery: cstructs.CSIControllerQuery{PluginID: "minnie"},
 	}
 
-	// Fetch the response
 	var resp structs.GenericResponse
-	err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerDetachVolume", req, &resp)
+	err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerValidateVolume", req, &resp)
 	require.NotNil(err)
-	// Should receive an error from the client endpoint
-	require.Contains(err.Error(), "must specify plugin name to dispense")
+	require.Contains(err.Error(), "no plugins registered for type")
 }
 
 func TestClientCSI_NodeForControllerPlugin(t *testing.T) {
@@ -188,22 +152,143 @@ func TestClientCSI_NodeForControllerPlugin(t *testing.T) {
 	plugin, err := state.CSIPluginByID(ws, "minnie")
 	require.NoError(t, err)
 
-	nodeID, err := srv.staticEndpoints.ClientCSI.nodeForController(plugin.ID, "")
-
+	nodeIDs, err := srv.staticEndpoints.ClientCSI.clientIDsForController(plugin.ID)
+	require.NoError(t, err)
+	require.Equal(t, 1, len(nodeIDs))
 	// only node1 has both the controller and a recent Nomad version
-	require.Equal(t, nodeID, node1.ID)
+	require.Equal(t, nodeIDs[0], node1.ID)
+}
+
+// setupForward sets up a pair of servers, each with one client, and registers
+// a plugin to the clients. It returns an RPC codec for the leader and a
+// cleanup function.
+func setupForward(t *testing.T) (rpc.ClientCodec, func()) {
+
+	s1, cleanupS1 := TestServer(t, func(c *Config) { c.BootstrapExpect = 1 })
+
+	testutil.WaitForLeader(t, s1.RPC)
+	codec := rpcClient(t, s1)
+
+	c1, cleanupC1 := client.TestClient(t, func(c *config.Config) {
+		c.Servers = []string{s1.config.RPCAddr.String()}
+	})
+
+	// Wait for client initialization
+	select {
+	case <-c1.Ready():
+	case <-time.After(10 * time.Second):
+		cleanupS1()
+		cleanupC1()
+		t.Fatal("client timed out on initialize")
+	}
+
+	waitForNodes(t, s1, 1, 1)
+
+	s2, cleanupS2 := TestServer(t, func(c *Config) { c.BootstrapExpect = 2 })
+	TestJoin(t, s1, s2)
+
+	c2, cleanupC2 := client.TestClient(t, func(c *config.Config) {
+		c.Servers = []string{s2.config.RPCAddr.String()}
+	})
+	select {
+	case <-c2.Ready():
+	case <-time.After(10 * time.Second):
+		cleanupS1()
+		cleanupC1()
+		t.Fatal("client timed out on initialize")
+	}
+
+	s1.nodeConnsLock.Lock()
+	delete(s1.nodeConns, c2.NodeID())
+	s1.nodeConnsLock.Unlock()
+
+	s2.nodeConnsLock.Lock()
+	delete(s2.nodeConns, c1.NodeID())
+	s2.nodeConnsLock.Unlock()
+
+	waitForNodes(t, s2, 1, 2)
+
+	plugins := map[string]*structs.CSIInfo{
+		"minnie": {PluginID: "minnie",
+			Healthy:                  true,
+			ControllerInfo:           &structs.CSIControllerInfo{},
+			NodeInfo:                 &structs.CSINodeInfo{},
+			RequiresControllerPlugin: true,
+		},
+	}
+
+	// update w/ plugin
+	node1 := c1.Node()
+	node1.Attributes["nomad.version"] = "0.11.0" // client RPCs not supported on early versions
+	node1.CSIControllerPlugins = plugins
+
+	s1.fsm.state.UpsertNode(1000, node1)
+
+	cleanup := func() {
+		cleanupS1()
+		cleanupC1()
+		cleanupS2()
+		cleanupC2()
+	}
+
+	return codec, cleanup
 }
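With setupForward (and setupLocal, below) in place, adding coverage for another controller RPC against either topology is a few lines. An illustrative test following the same shape as those above (hypothetical, not part of this change):

```go
func TestClientCSIController_Example_Local(t *testing.T) {
	t.Parallel()
	codec, cleanup := setupLocal(t)
	defer cleanup()

	req := &cstructs.ClientCSIControllerAttachVolumeRequest{
		CSIControllerQuery: cstructs.CSIControllerQuery{PluginID: "minnie"},
	}
	var resp structs.GenericResponse
	err := msgpackrpc.CallWithCodec(codec, "ClientCSI.ControllerAttachVolume", req, &resp)

	// the plugin is fingerprinted in server state but not actually running
	// on the client, so the RPC is expected to fail
	require.Error(t, err)
}
```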
 
-// waitForNodes waits until the server is connected to expectedNodes
-// clients and they are in the state store
-func waitForNodes(t *testing.T, s *Server, expectedNodes int) {
+// setupLocal sets up a single server with a client, and registers a plugin
+// to the client.
+func setupLocal(t *testing.T) (rpc.ClientCodec, func()) {
+
+	s1, cleanupS1 := TestServer(t, func(c *Config) { c.BootstrapExpect = 1 })
+
+	testutil.WaitForLeader(t, s1.RPC)
+	codec := rpcClient(t, s1)
+
+	c1, cleanupC1 := client.TestClient(t, func(c *config.Config) {
+		c.Servers = []string{s1.config.RPCAddr.String()}
+	})
+
+	// Wait for client initialization
+	select {
+	case <-c1.Ready():
+	case <-time.After(10 * time.Second):
+		cleanupS1()
+		cleanupC1()
+		t.Fatal("client timed out on initialize")
+	}
+
+	waitForNodes(t, s1, 1, 1)
+
+	plugins := map[string]*structs.CSIInfo{
+		"minnie": {PluginID: "minnie",
+			Healthy:                  true,
+			ControllerInfo:           &structs.CSIControllerInfo{},
+			NodeInfo:                 &structs.CSINodeInfo{},
+			RequiresControllerPlugin: true,
+		},
+	}
+
+	// update w/ plugin
+	node1 := c1.Node()
+	node1.Attributes["nomad.version"] = "0.11.0" // client RPCs not supported on early versions
+	node1.CSIControllerPlugins = plugins
+
+	s1.fsm.state.UpsertNode(1000, node1)
+
+	cleanup := func() {
+		cleanupS1()
+		cleanupC1()
+	}
+
+	return codec, cleanup
+}
+
+// waitForNodes waits until the server is connected to connectedNodes clients
+// and totalNodes clients are in the state store
+func waitForNodes(t *testing.T, s *Server, connectedNodes, totalNodes int) {
 	codec := rpcClient(t, s)
 
 	testutil.WaitForResult(func() (bool, error) {
 		connNodes := s.connectedNodes()
-		if len(connNodes) != expectedNodes {
-			return false, fmt.Errorf("expected %d nodes but found %d", expectedNodes, len(connNodes))
-
+		if len(connNodes) != connectedNodes {
+			return false, fmt.Errorf("expected %d connected nodes but found %d", connectedNodes, len(connNodes))
 		}
 
 		get := &structs.NodeListRequest{
@@ -218,10 +303,9 @@ func waitForNodes(t *testing.T, s *Server, expectedNodes int) {
 		if err != nil {
 			return false, fmt.Errorf("failed to list nodes: %v", err)
 		}
-		if len(resp.Nodes) != 1 {
-			return false, fmt.Errorf("expected %d nodes but found %d", 1, len(resp.Nodes))
+		if len(resp.Nodes) != totalNodes {
+			return false, fmt.Errorf("expected %d total nodes but found %d", totalNodes, len(resp.Nodes))
 		}
-
 		return true, nil
 	}, func(err error) {
 		require.NoError(t, err)
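Finally, since isRetryable is pure string matching over the error text, it could also be exercised directly without a running server. A possible unit test sketch (hypothetical, not part of this diff):

```go
func TestClientCSI_isRetryable(t *testing.T) {
	t.Parallel()
	a := &ClientCSI{}

	// "Unknown node" is only retryable for the matching client ID
	err := fmt.Errorf("Unknown node: node-1")
	require.True(t, a.isRetryable(err, "node-1", "minnie"))
	require.False(t, a.isRetryable(err, "node-2", "minnie"))

	// plugin-not-found is matched per plugin ID
	err = fmt.Errorf("plugin minnie for type controller not found")
	require.True(t, a.isRetryable(err, "node-1", "minnie"))
	require.False(t, a.isRetryable(err, "node-1", "mickey"))

	// anything else is treated as fatal rather than retried
	require.False(t, a.isRetryable(fmt.Errorf("some other RPC failure"), "node-1", "minnie"))
}
```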