Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[17.12] [manager/dispatcher] Synchronize Dispatcher.Stop() with incoming rpcs. #2522

Merged
merged 1 commit into from
Feb 21, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 62 additions & 17 deletions manager/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,17 @@ type clusterUpdate struct {

// Dispatcher is responsible for dispatching tasks and tracking agent health.
type Dispatcher struct {
mu sync.Mutex
wg sync.WaitGroup
// Mutex to synchronize access to dispatcher shared state e.g. nodes,
// lastSeenManagers, networkBootstrapKeys etc.
// TODO(anshul): This can potentially be removed and rpcRW used in its place.
mu sync.Mutex
// WaitGroup to handle the case when Stop() gets called before Run()
// has finished initializing the dispatcher.
wg sync.WaitGroup
// This RWMutex synchronizes RPC handlers and the dispatcher stop().
// The RPC handlers use the read lock while stop() uses the write lock
// and acts as a barrier to shutdown.
rpcRW sync.RWMutex
nodes *nodeStore
store *store.MemoryStore
lastSeenManagers []*api.WeightedPeer
Expand Down Expand Up @@ -312,7 +321,14 @@ func (d *Dispatcher) Stop() error {
}
d.cancel()
d.mu.Unlock()

// The active nodes list can be cleaned out only when all
// existing RPCs have finished.
// RPCs that start after rpcRW.Unlock() should find the context
// cancelled and should fail organically.
d.rpcRW.Lock()
d.nodes.Clean()
d.rpcRW.Unlock()

d.processUpdatesLock.Lock()
// In case there are any waiters. There is no chance of any starting
Expand All @@ -323,6 +339,11 @@ func (d *Dispatcher) Stop() error {

d.clusterUpdateQueue.Close()

// TODO(anshul): This use of Wait() could be unsafe.
// According to go's documentation on WaitGroup,
// Add() with a positive delta that occur when the counter is zero
// must happen before a Wait().
// As is, dispatcher Stop() can race with Run().
d.wg.Wait()

return nil
Expand Down Expand Up @@ -522,6 +543,14 @@ func (d *Dispatcher) register(ctx context.Context, nodeID string, description *a
// UpdateTaskStatus updates status of task. Node should send such updates
// on every status change of its tasks.
func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStatusRequest) (*api.UpdateTaskStatusResponse, error) {
d.rpcRW.RLock()
defer d.rpcRW.RUnlock()

dctx, err := d.isRunningLocked()
if err != nil {
return nil, err
}

nodeInfo, err := ca.RemoteNode(ctx)
if err != nil {
return nil, err
Expand All @@ -537,11 +566,6 @@ func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStat
}
log := log.G(ctx).WithFields(fields)

dctx, err := d.isRunningLocked()
if err != nil {
return nil, err
}

if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
return nil, err
}
Expand Down Expand Up @@ -713,16 +737,19 @@ func (d *Dispatcher) processUpdates(ctx context.Context) {
// of tasks which should be run on node, if task is not present in that list,
// it should be terminated.
func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServer) error {
nodeInfo, err := ca.RemoteNode(stream.Context())
d.rpcRW.RLock()
defer d.rpcRW.RUnlock()

dctx, err := d.isRunningLocked()
if err != nil {
return err
}
nodeID := nodeInfo.NodeID

dctx, err := d.isRunningLocked()
nodeInfo, err := ca.RemoteNode(stream.Context())
if err != nil {
return err
}
nodeID := nodeInfo.NodeID

fields := logrus.Fields{
"node.id": nodeID,
Expand Down Expand Up @@ -836,16 +863,19 @@ func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServe
// Assignments is a stream of assignments for a node. Each message contains
// either full list of tasks and secrets for the node, or an incremental update.
func (d *Dispatcher) Assignments(r *api.AssignmentsRequest, stream api.Dispatcher_AssignmentsServer) error {
nodeInfo, err := ca.RemoteNode(stream.Context())
d.rpcRW.RLock()
defer d.rpcRW.RUnlock()

dctx, err := d.isRunningLocked()
if err != nil {
return err
}
nodeID := nodeInfo.NodeID

dctx, err := d.isRunningLocked()
nodeInfo, err := ca.RemoteNode(stream.Context())
if err != nil {
return err
}
nodeID := nodeInfo.NodeID

fields := logrus.Fields{
"node.id": nodeID,
Expand Down Expand Up @@ -1088,6 +1118,17 @@ func (d *Dispatcher) markNodeNotReady(id string, state api.NodeStatus_State, mes
// Node should send new heartbeat earlier than now + TTL, otherwise it will
// be deregistered from dispatcher and its status will be updated to NodeStatus_DOWN
func (d *Dispatcher) Heartbeat(ctx context.Context, r *api.HeartbeatRequest) (*api.HeartbeatResponse, error) {
d.rpcRW.RLock()
defer d.rpcRW.RUnlock()

// Its OK to call isRunning() here instead of isRunningLocked()
// because of the rpcRW readlock above.
// TODO(anshul) other uses of isRunningLocked() can probably
// also be removed.
if !d.isRunning() {
return nil, status.Errorf(codes.Aborted, "dispatcher is stopped")
}

nodeInfo, err := ca.RemoteNode(ctx)
if err != nil {
return nil, err
Expand Down Expand Up @@ -1120,17 +1161,21 @@ func (d *Dispatcher) getRootCACert() []byte {
// a special boolean field Disconnect which if true indicates that node should
// reconnect to another Manager immediately.
func (d *Dispatcher) Session(r *api.SessionRequest, stream api.Dispatcher_SessionServer) error {
ctx := stream.Context()
nodeInfo, err := ca.RemoteNode(ctx)
d.rpcRW.RLock()
defer d.rpcRW.RUnlock()

dctx, err := d.isRunningLocked()
if err != nil {
return err
}
nodeID := nodeInfo.NodeID

dctx, err := d.isRunningLocked()
ctx := stream.Context()

nodeInfo, err := ca.RemoteNode(ctx)
if err != nil {
return err
}
nodeID := nodeInfo.NodeID

var sessionID string
if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
Expand Down