diff --git a/manager/dispatcher/dispatcher.go b/manager/dispatcher/dispatcher.go
index 12c2a81e34..7b9e175438 100644
--- a/manager/dispatcher/dispatcher.go
+++ b/manager/dispatcher/dispatcher.go
@@ -125,8 +125,17 @@ type clusterUpdate struct {
 
 // Dispatcher is responsible for dispatching tasks and tracking agent health.
 type Dispatcher struct {
-	mu                   sync.Mutex
-	wg                   sync.WaitGroup
+	// Mutex to synchronize access to dispatcher shared state e.g. nodes,
+	// lastSeenManagers, networkBootstrapKeys etc.
+	// TODO(anshul): This can potentially be removed and rpcRW used in its place.
+	mu sync.Mutex
+	// WaitGroup to handle the case when Stop() gets called before Run()
+	// has finished initializing the dispatcher.
+	wg sync.WaitGroup
+	// This RWMutex synchronizes RPC handlers and the dispatcher stop().
+	// The RPC handlers use the read lock while stop() uses the write lock
+	// and acts as a barrier to shutdown.
+	rpcRW                sync.RWMutex
 	nodes                *nodeStore
 	store                *store.MemoryStore
 	lastSeenManagers     []*api.WeightedPeer
@@ -318,7 +327,13 @@ func (d *Dispatcher) Stop() error {
 	}
 	d.cancel()
 	d.mu.Unlock()
+	// The active nodes list can be cleaned out only when all
+	// existing RPCs have finished.
+	// RPCs that start after rpcRW.Unlock() should find the context
+	// cancelled and should fail organically.
+	d.rpcRW.Lock()
 	d.nodes.Clean()
+	d.rpcRW.Unlock()
 
 	d.processUpdatesLock.Lock()
 	// In case there are any waiters. There is no chance of any starting
@@ -329,6 +344,11 @@
 
 	d.clusterUpdateQueue.Close()
 
+	// TODO(anshul): This use of Wait() could be unsafe.
+	// According to go's documentation on WaitGroup,
+	// Add() with a positive delta that occur when the counter is zero
+	// must happen before a Wait().
+	// As is, dispatcher Stop() can race with Run().
 	d.wg.Wait()
 
 	return nil
@@ -532,6 +552,14 @@ func (d *Dispatcher) register(ctx context.Context, nodeID string, description *a
 // UpdateTaskStatus updates status of task. Node should send such updates
 // on every status change of its tasks.
 func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStatusRequest) (*api.UpdateTaskStatusResponse, error) {
+	d.rpcRW.RLock()
+	defer d.rpcRW.RUnlock()
+
+	dctx, err := d.isRunningLocked()
+	if err != nil {
+		return nil, err
+	}
+
 	nodeInfo, err := ca.RemoteNode(ctx)
 	if err != nil {
 		return nil, err
@@ -547,11 +575,6 @@ func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStat
 	}
 	log := log.G(ctx).WithFields(fields)
 
-	dctx, err := d.isRunningLocked()
-	if err != nil {
-		return nil, err
-	}
-
 	if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
 		return nil, err
 	}
@@ -723,16 +746,19 @@ func (d *Dispatcher) processUpdates(ctx context.Context) {
 // of tasks which should be run on node, if task is not present in that list,
 // it should be terminated.
 func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServer) error {
-	nodeInfo, err := ca.RemoteNode(stream.Context())
+	d.rpcRW.RLock()
+	defer d.rpcRW.RUnlock()
+
+	dctx, err := d.isRunningLocked()
 	if err != nil {
 		return err
 	}
-	nodeID := nodeInfo.NodeID
 
-	dctx, err := d.isRunningLocked()
+	nodeInfo, err := ca.RemoteNode(stream.Context())
 	if err != nil {
 		return err
 	}
+	nodeID := nodeInfo.NodeID
 
 	fields := logrus.Fields{
 		"node.id": nodeID,
@@ -846,16 +872,19 @@ func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServe
 // Assignments is a stream of assignments for a node. Each message contains
 // either full list of tasks and secrets for the node, or an incremental update.
 func (d *Dispatcher) Assignments(r *api.AssignmentsRequest, stream api.Dispatcher_AssignmentsServer) error {
-	nodeInfo, err := ca.RemoteNode(stream.Context())
+	d.rpcRW.RLock()
+	defer d.rpcRW.RUnlock()
+
+	dctx, err := d.isRunningLocked()
 	if err != nil {
 		return err
 	}
-	nodeID := nodeInfo.NodeID
 
-	dctx, err := d.isRunningLocked()
+	nodeInfo, err := ca.RemoteNode(stream.Context())
 	if err != nil {
 		return err
 	}
+	nodeID := nodeInfo.NodeID
 
 	fields := logrus.Fields{
 		"node.id": nodeID,
@@ -1103,6 +1132,17 @@ func (d *Dispatcher) markNodeNotReady(id string, state api.NodeStatus_State, mes
 // Node should send new heartbeat earlier than now + TTL, otherwise it will
 // be deregistered from dispatcher and its status will be updated to NodeStatus_DOWN
 func (d *Dispatcher) Heartbeat(ctx context.Context, r *api.HeartbeatRequest) (*api.HeartbeatResponse, error) {
+	d.rpcRW.RLock()
+	defer d.rpcRW.RUnlock()
+
+	// Its OK to call isRunning() here instead of isRunningLocked()
+	// because of the rpcRW readlock above.
+	// TODO(anshul) other uses of isRunningLocked() can probably
+	// also be removed.
+	if !d.isRunning() {
+		return nil, status.Errorf(codes.Aborted, "dispatcher is stopped")
+	}
+
 	nodeInfo, err := ca.RemoteNode(ctx)
 	if err != nil {
 		return nil, err
@@ -1137,17 +1177,21 @@ func (d *Dispatcher) getRootCACert() []byte {
 // a special boolean field Disconnect which if true indicates that node should
 // reconnect to another Manager immediately.
 func (d *Dispatcher) Session(r *api.SessionRequest, stream api.Dispatcher_SessionServer) error {
-	ctx := stream.Context()
-	nodeInfo, err := ca.RemoteNode(ctx)
+	d.rpcRW.RLock()
+	defer d.rpcRW.RUnlock()
+
+	dctx, err := d.isRunningLocked()
 	if err != nil {
 		return err
 	}
-	nodeID := nodeInfo.NodeID
 
-	dctx, err := d.isRunningLocked()
+	ctx := stream.Context()
+
+	nodeInfo, err := ca.RemoteNode(ctx)
 	if err != nil {
 		return err
 	}
+	nodeID := nodeInfo.NodeID
 
 	var sessionID string
 	if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {