diff --git a/manager/dispatcher/dispatcher.go b/manager/dispatcher/dispatcher.go index 12c2a81e34..bf48818ea5 100644 --- a/manager/dispatcher/dispatcher.go +++ b/manager/dispatcher/dispatcher.go @@ -125,8 +125,12 @@ type clusterUpdate struct { // Dispatcher is responsible for dispatching tasks and tracking agent health. type Dispatcher struct { - mu sync.Mutex - wg sync.WaitGroup + // mu is a lock to provide mutually exclusive access to dispatcher fields + // e.g. lastSeenManagers, networkBootstrapKeys, lastSeenRootCert etc. + mu sync.Mutex + // shutdownWait is used by Stop() to wait for existing operations to finish. + shutdownWait sync.WaitGroup + nodes *nodeStore store *store.MemoryStore lastSeenManagers []*api.WeightedPeer @@ -249,8 +253,11 @@ func (d *Dispatcher) Run(ctx context.Context) error { defer cancel() d.ctx, d.cancel = context.WithCancel(ctx) ctx = d.ctx - d.wg.Add(1) - defer d.wg.Done() + + // If Stop() is called, it should wait + // for Run() to complete. + d.shutdownWait.Add(1) + defer d.shutdownWait.Done() d.mu.Unlock() publishManagers := func(peers []*api.Peer) { @@ -313,11 +320,14 @@ func (d *Dispatcher) Stop() error { return errors.New("dispatcher is already stopped") } - log := log.G(d.ctx).WithField("method", "(*Dispatcher).Stop") - log.Info("dispatcher stopping") + // Cancel dispatcher context. + // This should also close the streams in Tasks(), Assignments(). d.cancel() d.mu.Unlock() + // Wait for the RPCs that are in-progress to finish. + d.shutdownWait.Wait() + d.nodes.Clean() d.processUpdatesLock.Lock() @@ -328,9 +338,6 @@ func (d *Dispatcher) Stop() error { d.processUpdatesLock.Unlock() d.clusterUpdateQueue.Close() - - d.wg.Wait() - return nil } @@ -478,13 +485,13 @@ func nodeIPFromContext(ctx context.Context) (string, error) { // register is used for registration of node with particular dispatcher. 
func (d *Dispatcher) register(ctx context.Context, nodeID string, description *api.NodeDescription) (string, error) { - logLocal := log.G(ctx).WithField("method", "(*Dispatcher).register") - // prevent register until we're ready to accept it dctx, err := d.isRunningLocked() if err != nil { return "", err } + logLocal := log.G(ctx).WithField("method", "(*Dispatcher).register") + if err := d.nodes.CheckRateLimit(nodeID); err != nil { return "", err } @@ -532,6 +539,21 @@ func (d *Dispatcher) register(ctx context.Context, nodeID string, description *a // UpdateTaskStatus updates status of task. Node should send such updates // on every status change of its tasks. func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStatusRequest) (*api.UpdateTaskStatusResponse, error) { + // shutdownWait.Add() followed by isRunning() ensures that + // if this rpc sees the dispatcher running, + // it will already have called Add() on the shutdownWait WaitGroup, + // which ensures that Stop() will wait for this rpc to complete. + // Note that Stop() first does Dispatcher.ctx.cancel() followed by + // shutdownWait.Wait() to make sure new RPCs don't start before waiting + // for existing ones to finish. + d.shutdownWait.Add(1) + defer d.shutdownWait.Done() + + dctx, err := d.isRunningLocked() + if err != nil { + return nil, err + } + nodeInfo, err := ca.RemoteNode(ctx) if err != nil { return nil, err @@ -547,11 +569,6 @@ func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStat } log := log.G(ctx).WithFields(fields) - dctx, err := d.isRunningLocked() - if err != nil { - return nil, err - } - if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil { return nil, err } @@ -723,16 +740,26 @@ func (d *Dispatcher) processUpdates(ctx context.Context) { // of tasks which should be run on node, if task is not present in that list, // it should be terminated. 
func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServer) error { - nodeInfo, err := ca.RemoteNode(stream.Context()) + // shutdownWait.Add() followed by isRunning() ensures that + // if this rpc sees the dispatcher running, + // it will already have called Add() on the shutdownWait WaitGroup, + // which ensures that Stop() will wait for this rpc to complete. + // Note that Stop() first does Dispatcher.ctx.cancel() followed by + // shutdownWait.Wait() to make sure new RPCs don't start before waiting + // for existing ones to finish. + d.shutdownWait.Add(1) + defer d.shutdownWait.Done() + + dctx, err := d.isRunningLocked() if err != nil { return err } - nodeID := nodeInfo.NodeID - dctx, err := d.isRunningLocked() + nodeInfo, err := ca.RemoteNode(stream.Context()) if err != nil { return err } + nodeID := nodeInfo.NodeID fields := logrus.Fields{ "node.id": nodeID, @@ -846,16 +873,26 @@ func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServe // Assignments is a stream of assignments for a node. Each message contains // either full list of tasks and secrets for the node, or an incremental update. func (d *Dispatcher) Assignments(r *api.AssignmentsRequest, stream api.Dispatcher_AssignmentsServer) error { - nodeInfo, err := ca.RemoteNode(stream.Context()) + // shutdownWait.Add() followed by isRunning() ensures that + // if this rpc sees the dispatcher running, + // it will already have called Add() on the shutdownWait WaitGroup, + // which ensures that Stop() will wait for this rpc to complete. + // Note that Stop() first does Dispatcher.ctx.cancel() followed by + // shutdownWait.Wait() to make sure new RPCs don't start before waiting + // for existing ones to finish. 
+ d.shutdownWait.Add(1) + defer d.shutdownWait.Done() + + dctx, err := d.isRunningLocked() if err != nil { return err } - nodeID := nodeInfo.NodeID - dctx, err := d.isRunningLocked() + nodeInfo, err := ca.RemoteNode(stream.Context()) if err != nil { return err } + nodeID := nodeInfo.NodeID fields := logrus.Fields{ "node.id": nodeID, @@ -1103,6 +1140,24 @@ func (d *Dispatcher) markNodeNotReady(id string, state api.NodeStatus_State, mes // Node should send new heartbeat earlier than now + TTL, otherwise it will // be deregistered from dispatcher and its status will be updated to NodeStatus_DOWN func (d *Dispatcher) Heartbeat(ctx context.Context, r *api.HeartbeatRequest) (*api.HeartbeatResponse, error) { + // shutdownWait.Add() followed by isRunning() ensures that + // if this rpc sees the dispatcher running, + // it will already have called Add() on the shutdownWait WaitGroup, + // which ensures that Stop() will wait for this rpc to complete. + // Note that Stop() first does Dispatcher.ctx.cancel() followed by + // shutdownWait.Wait() to make sure new RPCs don't start before waiting + // for existing ones to finish. + d.shutdownWait.Add(1) + defer d.shutdownWait.Done() + + // isRunningLocked() is not needed since it's OK if + // the dispatcher context is cancelled while this call is in progress + // since Stop() which cancels the dispatcher context will wait for + // Heartbeat() to complete. + if !d.isRunning() { + return nil, status.Errorf(codes.Aborted, "dispatcher is stopped") + } + nodeInfo, err := ca.RemoteNode(ctx) if err != nil { return nil, err @@ -1137,17 +1192,27 @@ func (d *Dispatcher) getRootCACert() []byte { // a special boolean field Disconnect which if true indicates that node should // reconnect to another Manager immediately. 
func (d *Dispatcher) Session(r *api.SessionRequest, stream api.Dispatcher_SessionServer) error { - ctx := stream.Context() - nodeInfo, err := ca.RemoteNode(ctx) + // shutdownWait.Add() followed by isRunning() ensures that + // if this rpc sees the dispatcher running, + // it will already have called Add() on the shutdownWait WaitGroup, + // which ensures that Stop() will wait for this rpc to complete. + // Note that Stop() first does Dispatcher.ctx.cancel() followed by + // shutdownWait.Wait() to make sure new RPCs don't start before waiting + // for existing ones to finish. + d.shutdownWait.Add(1) + defer d.shutdownWait.Done() + + dctx, err := d.isRunningLocked() if err != nil { return err } - nodeID := nodeInfo.NodeID - dctx, err := d.isRunningLocked() + ctx := stream.Context() + nodeInfo, err := ca.RemoteNode(ctx) if err != nil { return err } + nodeID := nodeInfo.NodeID var sessionID string if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {