diff --git a/manager/dispatcher/dispatcher.go b/manager/dispatcher/dispatcher.go
index bf48818ea5..12c2a81e34 100644
--- a/manager/dispatcher/dispatcher.go
+++ b/manager/dispatcher/dispatcher.go
@@ -125,12 +125,8 @@ type clusterUpdate struct {
 
 // Dispatcher is responsible for dispatching tasks and tracking agent health.
 type Dispatcher struct {
-	// mu is a lock to provide mutually exclusive access to dispatcher fields
-	// e.g. lastSeenManagers, networkBootstrapKeys, lastSeenRootCert etc.
-	mu sync.Mutex
-	// shutdownWait is used by stop() to wait for existing operations to finish.
-	shutdownWait sync.WaitGroup
-
+	mu sync.Mutex
+	wg sync.WaitGroup
 	nodes                *nodeStore
 	store                *store.MemoryStore
 	lastSeenManagers     []*api.WeightedPeer
@@ -253,11 +249,8 @@ func (d *Dispatcher) Run(ctx context.Context) error {
 	defer cancel()
 	d.ctx, d.cancel = context.WithCancel(ctx)
 	ctx = d.ctx
-
-	// If Stop() is called, it should wait
-	// for Run() to complete.
-	d.shutdownWait.Add(1)
-	defer d.shutdownWait.Done()
+	d.wg.Add(1)
+	defer d.wg.Done()
 	d.mu.Unlock()
 
 	publishManagers := func(peers []*api.Peer) {
@@ -320,14 +313,11 @@ func (d *Dispatcher) Stop() error {
 		return errors.New("dispatcher is already stopped")
 	}
 
-	// Cancel dispatcher context.
-	// This should also close the the streams in Tasks(), Assignments().
+	log := log.G(d.ctx).WithField("method", "(*Dispatcher).Stop")
+	log.Info("dispatcher stopping")
 	d.cancel()
 	d.mu.Unlock()
 
-	// Wait for the RPCs that are in-progress to finish.
-	d.shutdownWait.Wait()
-
 	d.nodes.Clean()
 
 	d.processUpdatesLock.Lock()
@@ -338,6 +328,9 @@ func (d *Dispatcher) Stop() error {
 	d.processUpdatesLock.Unlock()
 
 	d.clusterUpdateQueue.Close()
+
+	d.wg.Wait()
+
 	return nil
 }
 
@@ -485,13 +478,13 @@ func nodeIPFromContext(ctx context.Context) (string, error) {
 
 // register is used for registration of node with particular dispatcher.
 func (d *Dispatcher) register(ctx context.Context, nodeID string, description *api.NodeDescription) (string, error) {
+	logLocal := log.G(ctx).WithField("method", "(*Dispatcher).register")
+
 	// prevent register until we're ready to accept it
 	dctx, err := d.isRunningLocked()
 	if err != nil {
 		return "", err
 	}
 
-	logLocal := log.G(ctx).WithField("method", "(*Dispatcher).register")
-
 	if err := d.nodes.CheckRateLimit(nodeID); err != nil {
 		return "", err
 	}
@@ -539,21 +532,6 @@ func (d *Dispatcher) register(ctx context.Context, nodeID string, description *a
 
 // UpdateTaskStatus updates status of task. Node should send such updates
 // on every status change of its tasks.
 func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStatusRequest) (*api.UpdateTaskStatusResponse, error) {
-	// shutdownWait.Add() followed by isRunning() to ensures that
-	// if this rpc sees the dispatcher running,
-	// it will already have called Add() on the shutdownWait wait,
-	// which ensures that Stop() will wait for this rpc to complete.
-	// Note that Stop() first does Dispatcher.ctx.cancel() followed by
-	// shutdownWait.Wait() to make sure new rpc's don't start before waiting
-	// for existing ones to finish.
-	d.shutdownWait.Add(1)
-	defer d.shutdownWait.Done()
-
-	dctx, err := d.isRunningLocked()
-	if err != nil {
-		return nil, err
-	}
-
 	nodeInfo, err := ca.RemoteNode(ctx)
 	if err != nil {
 		return nil, err
@@ -569,6 +547,11 @@ func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStat
 	}
 	log := log.G(ctx).WithFields(fields)
 
+	dctx, err := d.isRunningLocked()
+	if err != nil {
+		return nil, err
+	}
+
 	if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
 		return nil, err
 	}
@@ -740,26 +723,16 @@ func (d *Dispatcher) processUpdates(ctx context.Context) {
 // of tasks which should be run on node, if task is not present in that list,
 // it should be terminated.
 func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServer) error {
-	// shutdownWait.Add() followed by isRunning() to ensures that
-	// if this rpc sees the dispatcher running,
-	// it will already have called Add() on the shutdownWait wait,
-	// which ensures that Stop() will wait for this rpc to complete.
-	// Note that Stop() first does Dispatcher.ctx.cancel() followed by
-	// shutdownWait.Wait() to make sure new rpc's don't start before waiting
-	// for existing ones to finish.
-	d.shutdownWait.Add(1)
-	defer d.shutdownWait.Done()
-
-	dctx, err := d.isRunningLocked()
+	nodeInfo, err := ca.RemoteNode(stream.Context())
 	if err != nil {
 		return err
 	}
+	nodeID := nodeInfo.NodeID
 
-	nodeInfo, err := ca.RemoteNode(stream.Context())
+	dctx, err := d.isRunningLocked()
 	if err != nil {
 		return err
 	}
-	nodeID := nodeInfo.NodeID
 
 	fields := logrus.Fields{
 		"node.id": nodeID,
@@ -873,26 +846,16 @@ func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServe
 
 // Assignments is a stream of assignments for a node. Each message contains
 // either full list of tasks and secrets for the node, or an incremental update.
 func (d *Dispatcher) Assignments(r *api.AssignmentsRequest, stream api.Dispatcher_AssignmentsServer) error {
-	// shutdownWait.Add() followed by isRunning() to ensures that
-	// if this rpc sees the dispatcher running,
-	// it will already have called Add() on the shutdownWait wait,
-	// which ensures that Stop() will wait for this rpc to complete.
-	// Note that Stop() first does Dispatcher.ctx.cancel() followed by
-	// shutdownWait.Wait() to make sure new rpc's don't start before waiting
-	// for existing ones to finish.
-	d.shutdownWait.Add(1)
-	defer d.shutdownWait.Done()
-
-	dctx, err := d.isRunningLocked()
+	nodeInfo, err := ca.RemoteNode(stream.Context())
 	if err != nil {
 		return err
 	}
+	nodeID := nodeInfo.NodeID
 
-	nodeInfo, err := ca.RemoteNode(stream.Context())
+	dctx, err := d.isRunningLocked()
 	if err != nil {
 		return err
 	}
-	nodeID := nodeInfo.NodeID
 
 	fields := logrus.Fields{
 		"node.id": nodeID,
@@ -1140,24 +1103,6 @@ func (d *Dispatcher) markNodeNotReady(id string, state api.NodeStatus_State, mes
 // Node should send new heartbeat earlier than now + TTL, otherwise it will
 // be deregistered from dispatcher and its status will be updated to NodeStatus_DOWN
 func (d *Dispatcher) Heartbeat(ctx context.Context, r *api.HeartbeatRequest) (*api.HeartbeatResponse, error) {
-	// shutdownWait.Add() followed by isRunning() to ensures that
-	// if this rpc sees the dispatcher running,
-	// it will already have called Add() on the shutdownWait wait,
-	// which ensures that Stop() will wait for this rpc to complete.
-	// Note that Stop() first does Dispatcher.ctx.cancel() followed by
-	// shutdownWait.Wait() to make sure new rpc's don't start before waiting
-	// for existing ones to finish.
-	d.shutdownWait.Add(1)
-	defer d.shutdownWait.Done()
-
-	// isRunningLocked() is not needed since its OK if
-	// the dispatcher context is cancelled while this call is in progress
-	// since Stop() which cancels the dispatcher context will wait for
-	// Heartbeat() to complete.
-	if !d.isRunning() {
-		return nil, status.Errorf(codes.Aborted, "dispatcher is stopped")
-	}
-
 	nodeInfo, err := ca.RemoteNode(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -1192,27 +1137,17 @@ func (d *Dispatcher) getRootCACert() []byte {
 // a special boolean field Disconnect which if true indicates that node should
 // reconnect to another Manager immediately.
 func (d *Dispatcher) Session(r *api.SessionRequest, stream api.Dispatcher_SessionServer) error {
-	// shutdownWait.Add() followed by isRunning() to ensures that
-	// if this rpc sees the dispatcher running,
-	// it will already have called Add() on the shutdownWait wait,
-	// which ensures that Stop() will wait for this rpc to complete.
-	// Note that Stop() first does Dispatcher.ctx.cancel() followed by
-	// shutdownWait.Wait() to make sure new rpc's don't start before waiting
-	// for existing ones to finish.
-	d.shutdownWait.Add(1)
-	defer d.shutdownWait.Done()
-
-	dctx, err := d.isRunningLocked()
+	ctx := stream.Context()
+	nodeInfo, err := ca.RemoteNode(ctx)
 	if err != nil {
 		return err
 	}
+	nodeID := nodeInfo.NodeID
 
-	ctx := stream.Context()
-	nodeInfo, err := ca.RemoteNode(ctx)
+	dctx, err := d.isRunningLocked()
 	if err != nil {
 		return err
 	}
-	nodeID := nodeInfo.NodeID
 
 	var sessionID string
 	if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
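
The lifecycle this diff moves to — a single `wg.Add(1)` held for the whole of `Run()`, with `Stop()` cancelling the dispatcher context first and calling `d.wg.Wait()` last — avoids a hazard in the per-RPC `shutdownWait.Add(1)` scheme it removes: `sync.WaitGroup` requires that an `Add` with a positive delta that starts while the counter is zero happen before `Wait`, so per-request `Add` calls racing a concurrent `Wait()` in `Stop()` are unsynchronized. Below is a minimal standalone sketch of that replacement pattern; the `service` type and `main` are hypothetical illustrations, not swarmkit code.

```go
// Sketch of the pattern adopted by the patch: one Add(1) for the
// lifetime of Run(), cancel-then-Wait in Stop(). Not swarmkit code.
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

type service struct {
	mu     sync.Mutex
	wg     sync.WaitGroup
	cancel context.CancelFunc
}

func (s *service) Run(ctx context.Context) error {
	s.mu.Lock()
	ctx, s.cancel = context.WithCancel(ctx)
	// A single Add(1) held for the lifetime of Run(). Per-RPC Add(1)
	// calls could begin while the counter is zero and race a concurrent
	// Wait() in Stop(), which the WaitGroup contract forbids.
	s.wg.Add(1)
	defer s.wg.Done()
	s.mu.Unlock()

	<-ctx.Done() // stand-in for the dispatcher's main loop
	return ctx.Err()
}

func (s *service) Stop() {
	s.mu.Lock()
	cancel := s.cancel
	s.mu.Unlock()
	if cancel == nil {
		return // Run() never started
	}
	cancel()    // unblock Run() and any context-aware streams first
	s.wg.Wait() // then wait for Run() to unwind
}

func main() {
	s := &service{}
	go s.Run(context.Background())
	time.Sleep(50 * time.Millisecond) // crude, sketch-only: let Run() start
	s.Stop()
	fmt.Println("stopped cleanly")
}
```

Consistent with this, the RPC handlers above now resolve the caller via `ca.RemoteNode()` before checking `isRunningLocked()`, and `Heartbeat()` drops its running check entirely: since `Stop()` cancels `d.ctx` before waiting, in-flight streams observe the cancellation through their own contexts rather than through per-RPC WaitGroup bookkeeping.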