Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

csi: retry controller client RPCs on next controller #8561

Merged
merged 1 commit into from
Aug 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 87 additions & 77 deletions nomad/client_csi_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package nomad
import (
"fmt"
"math/rand"
"strings"
"time"

metrics "github.com/armon/go-metrics"
Expand All @@ -20,74 +21,101 @@ type ClientCSI struct {

// ControllerAttachVolume sends the "CSI.ControllerAttachVolume" RPC to a
// client node that runs the controller plugin named by args.PluginID.
//
// NOTE(review): this span is a rendered diff. Removed and added lines are
// interleaved without +/- markers, so it does not compile as shown. In the
// added lines, candidate client IDs come from clientIDsForController and are
// tried in order: a successful RPC returns nil, an error accepted by
// isRetryable moves on to the next client, and any other error is returned
// prefixed with "controller attach volume: ". args.ControllerNodeID is
// overwritten with each client ID that is tried.
func (a *ClientCSI) ControllerAttachVolume(args *cstructs.ClientCSIControllerAttachVolumeRequest, reply *cstructs.ClientCSIControllerAttachVolumeResponse) error {
defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "attach_volume"}, time.Now())
// Get a Nomad client node for the controller
nodeID, err := a.nodeForController(args.PluginID, args.ControllerNodeID)

clientIDs, err := a.clientIDsForController(args.PluginID)
if err != nil {
return err
return fmt.Errorf("controller attach volume: %v", err)
}
args.ControllerNodeID = nodeID

// Get the connection to the client
state, ok := a.srv.getNodeConn(args.ControllerNodeID)
if !ok {
return findNodeConnAndForward(a.srv, args.ControllerNodeID, "ClientCSI.ControllerAttachVolume", args, reply)
}
for _, clientID := range clientIDs {
args.ControllerNodeID = clientID
state, ok := a.srv.getNodeConn(clientID)
if !ok {
// No connection state for this client here: hand the request to
// findNodeConnAndForward and return its result. This path returns
// immediately, so the remaining client IDs are not tried.
return findNodeConnAndForward(a.srv,
clientID, "ClientCSI.ControllerAttachVolume", args, reply)
}

// Make the RPC
err = NodeRpc(state.Session, "CSI.ControllerAttachVolume", args, reply)
if err != nil {
err = NodeRpc(state.Session, "CSI.ControllerAttachVolume", args, reply)
if err == nil {
return nil
}
if a.isRetryable(err, clientID, args.PluginID) {
// NOTE(review): the message uses printf verbs (%q, %v), while the
// a.logger.Trace call elsewhere in this file passes alternating
// key/value arguments. Confirm which form a.logger expects; if it is
// key/value, these verbs are never expanded.
a.logger.Debug("failed to reach controller on client %q: %v", clientID, err)
continue
}
return fmt.Errorf("controller attach volume: %v", err)
}
return nil
// In the added lines this is reached only after every client failed with a
// retryable error; err then holds the error from the last NodeRpc call.
return fmt.Errorf("controller attach volume: %v", err)
}

// ControllerValidateVolume sends the "CSI.ControllerValidateVolume" RPC to a
// client node that runs the controller plugin named by args.PluginID.
//
// NOTE(review): this span is a rendered diff. Removed and added lines are
// interleaved without +/- markers (plus two lines of review-thread UI text),
// so it does not compile as shown. In the added lines, candidate client IDs
// come from clientIDsForController and are tried in order: a successful RPC
// returns nil, an error accepted by isRetryable moves on to the next client,
// and any other error is returned prefixed with "validate volume: ".
// args.ControllerNodeID is overwritten with each client ID that is tried.
func (a *ClientCSI) ControllerValidateVolume(args *cstructs.ClientCSIControllerValidateVolumeRequest, reply *cstructs.ClientCSIControllerValidateVolumeResponse) error {
defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "validate_volume"}, time.Now())

// Get a Nomad client node for the controller
nodeID, err := a.nodeForController(args.PluginID, args.ControllerNodeID)
clientIDs, err := a.clientIDsForController(args.PluginID)
if err != nil {
return err
return fmt.Errorf("validate volume: %v", err)
}
args.ControllerNodeID = nodeID

// Get the connection to the client
state, ok := a.srv.getNodeConn(args.ControllerNodeID)
if !ok {
return findNodeConnAndForward(a.srv, args.ControllerNodeID, "ClientCSI.ControllerValidateVolume", args, reply)
}
for _, clientID := range clientIDs {
args.ControllerNodeID = clientID
state, ok := a.srv.getNodeConn(clientID)
if !ok {
// No connection state for this client here: hand the request to
// findNodeConnAndForward and return its result. This path returns
// immediately, so the remaining client IDs are not tried.
return findNodeConnAndForward(a.srv,
clientID, "ClientCSI.ControllerValidateVolume", args, reply)
}

// Make the RPC
err = NodeRpc(state.Session, "CSI.ControllerValidateVolume", args, reply)
if err != nil {
err = NodeRpc(state.Session, "CSI.ControllerValidateVolume", args, reply)
langmartin marked this conversation as resolved.
Show resolved Hide resolved
if err == nil {
return nil
}
if a.isRetryable(err, clientID, args.PluginID) {
// NOTE(review): the message uses printf verbs (%q, %v), while the
// a.logger.Trace call elsewhere in this file passes alternating
// key/value arguments. Confirm which form a.logger expects; if it is
// key/value, these verbs are never expanded.
a.logger.Debug("failed to reach controller on client %q: %v", clientID, err)
continue
}
return fmt.Errorf("validate volume: %v", err)
}
return nil
// In the added lines this is reached only after every client failed with a
// retryable error; err then holds the error from the last NodeRpc call.
return fmt.Errorf("validate volume: %v", err)
}

// ControllerDetachVolume sends the "CSI.ControllerDetachVolume" RPC to a
// client node that runs the controller plugin named by args.PluginID.
//
// NOTE(review): this span is a rendered diff. Removed and added lines are
// interleaved without +/- markers, so it does not compile as shown. In the
// added lines, candidate client IDs come from clientIDsForController and are
// tried in order: a successful RPC returns nil, an error accepted by
// isRetryable moves on to the next client, and any other error is returned
// prefixed with "controller detach volume: ". args.ControllerNodeID is
// overwritten with each client ID that is tried.
func (a *ClientCSI) ControllerDetachVolume(args *cstructs.ClientCSIControllerDetachVolumeRequest, reply *cstructs.ClientCSIControllerDetachVolumeResponse) error {
defer metrics.MeasureSince([]string{"nomad", "client_csi_controller", "detach_volume"}, time.Now())

// Get a Nomad client node for the controller
nodeID, err := a.nodeForController(args.PluginID, args.ControllerNodeID)
clientIDs, err := a.clientIDsForController(args.PluginID)
if err != nil {
return err
return fmt.Errorf("controller detach volume: %v", err)
}
args.ControllerNodeID = nodeID

// Get the connection to the client
state, ok := a.srv.getNodeConn(args.ControllerNodeID)
if !ok {
return findNodeConnAndForward(a.srv, args.ControllerNodeID, "ClientCSI.ControllerDetachVolume", args, reply)
}
for _, clientID := range clientIDs {
args.ControllerNodeID = clientID
state, ok := a.srv.getNodeConn(clientID)
if !ok {
// No connection state for this client here: hand the request to
// findNodeConnAndForward and return its result. This path returns
// immediately, so the remaining client IDs are not tried.
return findNodeConnAndForward(a.srv,
clientID, "ClientCSI.ControllerDetachVolume", args, reply)
}

// Make the RPC
err = NodeRpc(state.Session, "CSI.ControllerDetachVolume", args, reply)
if err != nil {
err = NodeRpc(state.Session, "CSI.ControllerDetachVolume", args, reply)
if err == nil {
return nil
}
if a.isRetryable(err, clientID, args.PluginID) {
// NOTE(review): the message uses printf verbs (%q, %v), while the
// a.logger.Trace call elsewhere in this file passes alternating
// key/value arguments. Confirm which form a.logger expects; if it is
// key/value, these verbs are never expanded.
a.logger.Debug("failed to reach controller on client %q: %v", clientID, err)
continue
}
return fmt.Errorf("controller detach volume: %v", err)
}
return nil
// In the added lines this is reached only after every client failed with a
// retryable error; err then holds the error from the last NodeRpc call.
return fmt.Errorf("controller detach volume: %v", err)
}

// isRetryable reports whether err, returned by a controller RPC sent to
// clientID for pluginID, justifies retrying the same RPC on a different
// controller. That is the case where the client has stopped and been GC'd,
// or where the controller has stopped but we don't have the fingerprint
// update yet.
//
// The decision is made by substring-matching err's message against three
// known messages: an unknown-node message naming clientID, the "no plugins
// registered" message for the csi-controller type, and a plugin-not-found
// message naming pluginID. A nil err is never retryable.
func (a *ClientCSI) isRetryable(err error, clientID, pluginID string) bool {
	if err == nil {
		// Nothing failed, so there is nothing to retry; this also keeps the
		// err.Error() call below from panicking on a nil error.
		return false
	}

	// TODO(tgross): it would be nicer to use errors.Is here but we
	// need to make sure we're using error wrapping to make that work
	errMsg := err.Error()
	return strings.Contains(errMsg, fmt.Sprintf("Unknown node: %s", clientID)) ||
		strings.Contains(errMsg, "no plugins registered for type: csi-controller") ||
		strings.Contains(errMsg, fmt.Sprintf("plugin %s for type controller not found", pluginID))
}

func (a *ClientCSI) NodeDetachVolume(args *cstructs.ClientCSINodeDetachVolumeRequest, reply *cstructs.ClientCSINodeDetachVolumeResponse) error {
Expand Down Expand Up @@ -119,29 +147,17 @@ func (a *ClientCSI) NodeDetachVolume(args *cstructs.ClientCSINodeDetachVolumeReq

}

// nodeForController validates that the Nomad client node ID for
// a plugin exists and is new enough to support client RPC. If no node
// ID is passed, select a random node ID for the controller to load-balance
// long blocking RPCs across client nodes.
func (a *ClientCSI) nodeForController(pluginID, nodeID string) (string, error) {
// clientIDsForController returns a shuffled list of client IDs where the
// controller plugin is expected to be running.
func (a *ClientCSI) clientIDsForController(pluginID string) ([]string, error) {

snap, err := a.srv.State().Snapshot()
if err != nil {
return "", err
}

if nodeID != "" {
_, err = getNodeForRpc(snap, nodeID)
if err == nil {
return nodeID, nil
} else {
// we'll fall-through and select a node at random
a.logger.Trace("could not be used for client RPC", "node", nodeID, "error", err)
}
return nil, err
}

if pluginID == "" {
return "", fmt.Errorf("missing plugin ID")
return nil, fmt.Errorf("missing plugin ID")
}

ws := memdb.NewWatchSet()
Expand All @@ -151,43 +167,37 @@ func (a *ClientCSI) nodeForController(pluginID, nodeID string) (string, error) {
// region/DC for the volume.
plugin, err := snap.CSIPluginByID(ws, pluginID)
if err != nil {
return "", fmt.Errorf("error getting plugin: %s, %v", pluginID, err)
return nil, fmt.Errorf("error getting plugin: %s, %v", pluginID, err)
}
if plugin == nil {
return "", fmt.Errorf("plugin missing: %s %v", pluginID, err)
}
count := len(plugin.Controllers)
if count == 0 {
return "", fmt.Errorf("no controllers available for plugin %q", plugin.ID)
return nil, fmt.Errorf("plugin missing: %s", pluginID)
}

// iterating maps is "random" but unspecified and isn't particularly
// random with small maps, so not well-suited for load balancing.
// so we shuffle the keys and iterate over them.
clientIDs := make([]string, 0, count)
for clientID := range plugin.Controllers {
clientIDs = append(clientIDs, clientID)
}
rand.Shuffle(count, func(i, j int) {
clientIDs[i], clientIDs[j] = clientIDs[j], clientIDs[i]
})
clientIDs := []string{}

for _, clientID := range clientIDs {
controller := plugin.Controllers[clientID]
for clientID, controller := range plugin.Controllers {
if !controller.IsController() {
// we don't have separate types for CSIInfo depending on
// whether it's a controller or node. this error shouldn't
// make it to production but is to aid developers during
// development
err = fmt.Errorf("plugin is not a controller")
continue
}
_, err = getNodeForRpc(snap, clientID)
if err != nil {
continue
node, err := getNodeForRpc(snap, clientID)
if err == nil && node != nil && node.Ready() {
clientIDs = append(clientIDs, clientID)
}
return clientID, nil
}
if len(clientIDs) == 0 {
return nil, fmt.Errorf("failed to find clients running controller plugin %q", pluginID)
}

rand.Shuffle(len(clientIDs), func(i, j int) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good call

clientIDs[i], clientIDs[j] = clientIDs[j], clientIDs[i]
})

return "", err
return clientIDs, nil
}
Loading