diff --git a/client/consul.go b/client/consul.go
index 89666e41e451..02e40ef0f09f 100644
--- a/client/consul.go
+++ b/client/consul.go
@@ -10,8 +10,8 @@ import (
 // ConsulServiceAPI is the interface the Nomad Client uses to register and
 // remove services and checks from Consul.
 type ConsulServiceAPI interface {
-    RegisterTask(allocID string, task *structs.Task, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error
+    RegisterTask(allocID string, task *structs.Task, restarter consul.TaskRestarter, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error
     RemoveTask(allocID string, task *structs.Task)
-    UpdateTask(allocID string, existing, newTask *structs.Task, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error
+    UpdateTask(allocID string, existing, newTask *structs.Task, restart consul.TaskRestarter, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error
     AllocRegistrations(allocID string) (*consul.AllocRegistration, error)
 }
diff --git a/client/task_runner.go b/client/task_runner.go
index cd5afbd91975..36326448991b 100644
--- a/client/task_runner.go
+++ b/client/task_runner.go
@@ -141,6 +141,10 @@ type TaskRunner struct {
     // restartCh is used to restart a task
     restartCh chan *structs.TaskEvent
 
+    // lastStart tracks the last time this task was started or restarted
+    lastStart   time.Time
+    lastStartMu sync.Mutex
+
     // signalCh is used to send a signal to a task
     signalCh chan SignalEvent
 
@@ -1362,6 +1366,11 @@ func (r *TaskRunner) killTask(killingEvent *structs.TaskEvent) {
 
 // startTask creates the driver, task dir, and starts the task.
 func (r *TaskRunner) startTask() error {
+    // Update lastStart
+    r.lastStartMu.Lock()
+    r.lastStart = time.Now()
+    r.lastStartMu.Unlock()
+
     // Create a driver
     drv, err := r.createDriver()
     if err != nil {
@@ -1439,7 +1448,7 @@ func (r *TaskRunner) registerServices(d driver.Driver, h driver.DriverHandle, n
         exec = h
     }
     interpolatedTask := interpolateServices(r.envBuilder.Build(), r.task)
-    return r.consul.RegisterTask(r.alloc.ID, interpolatedTask, exec, n)
+    return r.consul.RegisterTask(r.alloc.ID, interpolatedTask, r, exec, n)
 }
 
 // interpolateServices interpolates tags in a service and checks with values from the
@@ -1641,7 +1650,7 @@ func (r *TaskRunner) updateServices(d driver.Driver, h driver.ScriptExecutor, ol
     r.driverNetLock.Lock()
     net := r.driverNet.Copy()
     r.driverNetLock.Unlock()
-    return r.consul.UpdateTask(r.alloc.ID, oldInterpolatedTask, newInterpolatedTask, exec, net)
+    return r.consul.UpdateTask(r.alloc.ID, oldInterpolatedTask, newInterpolatedTask, r, exec, net)
 }
 
 // handleDestroy kills the task handle. In the case that killing fails,
@@ -1671,6 +1680,16 @@ func (r *TaskRunner) handleDestroy(handle driver.DriverHandle) (destroyed bool,
 
 // Restart will restart the task
 func (r *TaskRunner) Restart(source, reason string) {
+    r.lastStartMu.Lock()
+    defer r.lastStartMu.Unlock()
+
+    r.restart(source, reason)
+}
+
+// restart is the internal task restart method. Callers must hold lastStartMu.
+func (r *TaskRunner) restart(source, reason string) {
+    r.lastStart = time.Now()
+
     reasonStr := fmt.Sprintf("%s: %s", source, reason)
     event := structs.NewTaskEvent(structs.TaskRestartSignal).SetRestartReason(reasonStr)
 
@@ -1680,6 +1699,25 @@ func (r *TaskRunner) Restart(source, reason string) {
     }
 }
 
+// RestartBy restarts the task iff the last time it was started was before the
+// given deadline; otherwise the call is a no-op.
+func (r *TaskRunner) RestartBy(deadline time.Time, source, reason string) {
+    r.lastStartMu.Lock()
+    defer r.lastStartMu.Unlock()
+
+    if r.lastStart.Before(deadline) {
+        r.restart(source, reason)
+    }
+}
+
+// LastStart returns the last time this task was started (including restarts).
+func (r *TaskRunner) LastStart() time.Time {
+    r.lastStartMu.Lock()
+    ls := r.lastStart
+    r.lastStartMu.Unlock()
+    return ls
+}
+
 // Signal will send a signal to the task
 func (r *TaskRunner) Signal(source, reason string, s os.Signal) error {
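For context: the `LastStart`/`RestartBy` pair above exists so the check watcher can collapse concurrent restart requests — a restart only happens if the task has not already been (re)started since the moment the caller observed the failure. A minimal sketch of that contract using a hypothetical stand-in type (not the real `TaskRunner`):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// fakeTask is an illustrative stand-in that follows the same
// LastStart/RestartBy contract as the TaskRunner methods above.
type fakeTask struct {
	mu        sync.Mutex
	lastStart time.Time
	restarts  int
}

func (f *fakeTask) LastStart() time.Time {
	f.mu.Lock()
	defer f.mu.Unlock()
	return f.lastStart
}

func (f *fakeTask) RestartBy(deadline time.Time, source, reason string) {
	f.mu.Lock()
	defer f.mu.Unlock()
	// Only restart if the task has not been started since the deadline
	// the caller observed.
	if f.lastStart.Before(deadline) {
		f.lastStart = time.Now()
		f.restarts++
		fmt.Printf("restarting (%s: %s)\n", source, reason)
	}
}

func main() {
	t := &fakeTask{lastStart: time.Now().Add(-time.Minute)}
	now := time.Now()

	// Two unhealthy checks observed on the same poll both request a
	// restart, but only the first request actually restarts the task.
	t.RestartBy(now, "healthcheck", "check \"a\" unhealthy")
	t.RestartBy(now, "healthcheck", "check \"b\" unhealthy")
	fmt.Println("restarts:", t.restarts) // 1
}
```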
diff --git a/command/agent/consul/check_watcher.go b/command/agent/consul/check_watcher.go
new file mode 100644
index 000000000000..46967ef5f300
--- /dev/null
+++ b/command/agent/consul/check_watcher.go
@@ -0,0 +1,244 @@
+package consul
+
+import (
+    "context"
+    "fmt"
+    "log"
+    "time"
+
+    "github.com/hashicorp/consul/api"
+    "github.com/hashicorp/nomad/nomad/structs"
+)
+
+const (
+    // defaultPollFreq is the default rate to poll the Consul Checks API
+    defaultPollFreq = 900 * time.Millisecond
+)
+
+type ConsulChecks interface {
+    Checks() (map[string]*api.AgentCheck, error)
+}
+
+type TaskRestarter interface {
+    LastStart() time.Time
+    RestartBy(deadline time.Time, source, reason string)
+}
+
+// checkRestart handles restarting a task if a check is unhealthy.
+type checkRestart struct {
+    allocID   string
+    taskName  string
+    checkID   string
+    checkName string
+
+    // remove this checkID (if true only checkID will be set)
+    remove bool
+
+    task      TaskRestarter
+    grace     time.Duration
+    interval  time.Duration
+    timeLimit time.Duration
+    warning   bool
+
+    // unhealthyStart is the time a check first went unhealthy. Set to the
+    // zero value if the check passes before timeLimit.
+    // This is the only mutable field on checkRestart.
+    unhealthyStart time.Time
+
+    logger *log.Logger
+}
+
+// update the restart state for a check and restart the task if necessary. The
+// current timestamp is passed in so all check updates have the same view of
+// time (and to ease testing).
+func (c *checkRestart) update(now time.Time, status string) {
+    switch status {
+    case api.HealthCritical:
+    case api.HealthWarning:
+        if !c.warning {
+            // Warnings are ok, reset state and exit
+            c.unhealthyStart = time.Time{}
+            return
+        }
+    default:
+        // All other statuses are ok, reset state and exit
+        c.unhealthyStart = time.Time{}
+        return
+    }
+
+    if now.Before(c.task.LastStart().Add(c.grace)) {
+        // In grace period, reset state and exit
+        c.unhealthyStart = time.Time{}
+        return
+    }
+
+    if c.unhealthyStart.IsZero() {
+        // First failure, set restart deadline
+        c.unhealthyStart = now
+    }
+
+    // restart timeLimit after start of this check becoming unhealthy
+    restartAt := c.unhealthyStart.Add(c.timeLimit)
+
+    // Must test >= because if limit=1, restartAt == first failure
+    if now.UnixNano() >= restartAt.UnixNano() {
+        // hasn't become healthy by deadline, restart!
+        c.logger.Printf("[DEBUG] consul.health: restarting alloc %q task %q due to unhealthy check %q", c.allocID, c.taskName, c.checkName)
+        c.task.RestartBy(now, "healthcheck", fmt.Sprintf("check %q unhealthy", c.checkName))
+    }
+}
+
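The restart arithmetic in `update` is easiest to see with numbers: `Watch` (further down in this file) sets `timeLimit = interval * (limit - 1)`, so a check that first fails at time T is restarted on the poll where `now >= T + timeLimit`, and `limit = 1` restarts on the very first failing poll. A small self-contained sketch with illustrative values (10s interval, limit 3), not taken from the patch:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	// Illustrative values: restart once a check has failed 3 polls in a row.
	limit := 3
	interval := 10 * time.Second

	// Mirrors checkWatcher.Watch: timeLimit = interval * (limit - 1).
	timeLimit := interval * time.Duration(limit-1)

	unhealthyStart := time.Now() // first failing poll
	restartAt := unhealthyStart.Add(timeLimit)

	// Mirrors checkRestart.update: a poll at or after restartAt restarts
	// the task. With limit=1, timeLimit is zero and the first failing
	// poll triggers the restart immediately.
	for i := 0; i < limit; i++ {
		now := unhealthyStart.Add(time.Duration(i) * interval)
		fmt.Printf("poll %d: restart=%v\n", i+1, now.UnixNano() >= restartAt.UnixNano())
	}
	// Output: poll 1 and 2 are false, poll 3 is true.
}
```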
+// checkWatcher watches Consul checks and restarts tasks when they're
+// unhealthy.
+type checkWatcher struct {
+    consul ConsulChecks
+
+    pollFreq time.Duration
+
+    watchCh chan *checkRestart
+
+    // done is closed when Run has exited
+    done chan struct{}
+
+    // lastErr is true if the last Consul call failed. It is used to
+    // squelch repeated error messages.
+    lastErr bool
+
+    logger *log.Logger
+}
+
+// newCheckWatcher creates a new checkWatcher but does not call its Run method.
+func newCheckWatcher(logger *log.Logger, consul ConsulChecks) *checkWatcher {
+    return &checkWatcher{
+        consul:   consul,
+        pollFreq: defaultPollFreq,
+        watchCh:  make(chan *checkRestart, 8),
+        done:     make(chan struct{}),
+        logger:   logger,
+    }
+}
+
+// Run the main Consul checks watching loop to restart tasks when their checks
+// fail. Blocks until context is canceled.
+func (w *checkWatcher) Run(ctx context.Context) {
+    defer close(w.done)
+
+    // map of check IDs to their metadata
+    checks := map[string]*checkRestart{}
+
+    // timer for check polling
+    checkTimer := time.NewTimer(0)
+    defer checkTimer.Stop() // ensure timer is never leaked
+    resetTimer := func(d time.Duration) {
+        if !checkTimer.Stop() {
+            <-checkTimer.C
+        }
+        checkTimer.Reset(d)
+    }
+
+    // Main watch loop
+    for {
+        // Don't start watching until we actually have checks that
+        // trigger restarts.
+        for len(checks) == 0 {
+            select {
+            case c := <-w.watchCh:
+                if c.remove {
+                    // should not happen
+                    w.logger.Printf("[DEBUG] consul.health: told to stop watching an unwatched check: %q", c.checkID)
+                } else {
+                    checks[c.checkID] = c
+
+                    // First check should be after grace period
+                    resetTimer(c.grace)
+                }
+            case <-ctx.Done():
+                return
+            }
+        }
+
+        // As long as there are checks to be watched, keep watching
+        for len(checks) > 0 {
+            select {
+            case c := <-w.watchCh:
+                if c.remove {
+                    delete(checks, c.checkID)
+                } else {
+                    checks[c.checkID] = c
+                    w.logger.Printf("[DEBUG] consul.health: watching alloc %q task %q check %q", c.allocID, c.taskName, c.checkName)
+                }
+            case <-ctx.Done():
+                return
+            case <-checkTimer.C:
+                checkTimer.Reset(w.pollFreq)
+
+                // Set "now" as the point in time the following check results represent
+                now := time.Now()
+
+                results, err := w.consul.Checks()
+                if err != nil {
+                    if !w.lastErr {
+                        w.lastErr = true
+                        w.logger.Printf("[ERR] consul.health: error retrieving health checks: %q", err)
+                    }
+                    continue
+                }
+
+                w.lastErr = false
+
+                // Loop over watched checks and update their status from results
+                for cid, check := range checks {
+                    result, ok := results[cid]
+                    if !ok {
+                        w.logger.Printf("[WARN] consul.health: watched check %q (%s) not found in Consul", check.checkName, cid)
+                        continue
+                    }
+
+                    check.update(now, result.Status)
+                }
+            }
+        }
+    }
+}
+
+// Watch a check and restart its task if the check becomes unhealthy.
+func (w *checkWatcher) Watch(allocID, taskName, checkID string, check *structs.ServiceCheck, restarter TaskRestarter) {
+    if !check.Watched() {
+        // Not watched, noop
+        return
+    }
+
+    c := checkRestart{
+        allocID:   allocID,
+        taskName:  taskName,
+        checkID:   checkID,
+        checkName: check.Name,
+        task:      restarter,
+        interval:  check.Interval,
+        grace:     check.CheckRestart.Grace,
+        timeLimit: check.Interval * time.Duration(check.CheckRestart.Limit-1),
+        warning:   check.CheckRestart.OnWarning,
+        logger:    w.logger,
+    }
+
+    select {
+    case w.watchCh <- &c:
+        // sent watch
+    case <-w.done:
+        // exited; nothing to do
+    }
+}
+
+// Unwatch a check so its failures no longer trigger restarts.
+func (w *checkWatcher) Unwatch(cid string) {
+    c := checkRestart{
+        checkID: cid,
+        remove:  true,
+    }
+    select {
+    case w.watchCh <- &c:
+        // sent remove watch
+    case <-w.done:
+        // exited; nothing to do
+    }
+}
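Because the watcher depends only on the narrow `ConsulChecks` and `TaskRestarter` interfaces, it can be driven end to end with fakes from a package-internal test. The sketch below is illustrative, not part of this patch: `fakeChecksAPI`, `fakeRestarter`, and the test itself are hypothetical, and the `structs.CheckRestart{Limit: 1}` literal assumes the stanza struct added elsewhere in this change.

```go
package consul

import (
	"context"
	"log"
	"os"
	"testing"
	"time"

	"github.com/hashicorp/consul/api"
	"github.com/hashicorp/nomad/nomad/structs"
)

// fakeChecksAPI always reports a single check with a fixed status.
type fakeChecksAPI struct {
	status string
}

func (f *fakeChecksAPI) Checks() (map[string]*api.AgentCheck, error) {
	return map[string]*api.AgentCheck{
		"check1": {CheckID: "check1", Status: f.status},
	}, nil
}

// fakeRestarter records that a restart was requested.
type fakeRestarter struct {
	restarted chan struct{}
}

func (f *fakeRestarter) LastStart() time.Time { return time.Time{} }

func (f *fakeRestarter) RestartBy(deadline time.Time, source, reason string) {
	select {
	case f.restarted <- struct{}{}:
	default: // watcher may ask again on later polls; drop duplicates
	}
}

func TestCheckWatcher_Sketch(t *testing.T) {
	logger := log.New(os.Stderr, "", log.LstdFlags)
	w := newCheckWatcher(logger, &fakeChecksAPI{status: api.HealthCritical})
	w.pollFreq = 10 * time.Millisecond // poll quickly for the test

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go w.Run(ctx)

	restarter := &fakeRestarter{restarted: make(chan struct{}, 1)}
	check := &structs.ServiceCheck{
		Name:         "check1",
		Interval:     10 * time.Millisecond,
		CheckRestart: &structs.CheckRestart{Limit: 1}, // assumed field/type name
	}
	w.Watch("alloc1", "task1", "check1", check, restarter)

	select {
	case <-restarter.restarted:
		// a critical check with limit=1 restarts on the first failing poll
	case <-time.After(time.Second):
		t.Fatalf("expected a restart for a critical check")
	}
}
```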
diff --git a/command/agent/consul/client.go b/command/agent/consul/client.go
index 8285785fbde0..911ba628e569 100644
--- a/command/agent/consul/client.go
+++ b/command/agent/consul/client.go
@@ -1,6 +1,7 @@
 package consul
 
 import (
+    "context"
     "fmt"
     "log"
     "net"
@@ -223,6 +224,9 @@ type ServiceClient struct {
     // seen is 1 if Consul has ever been seen; otherise 0. Accessed with
     // atomics.
     seen int32
+
+    // checkWatcher restarts tasks whose Consul checks become unhealthy.
+    checkWatcher *checkWatcher
 }
 
 // NewServiceClient creates a new Consul ServiceClient from an existing Consul API
@@ -245,6 +249,7 @@ func NewServiceClient(consulClient AgentAPI, skipVerifySupport bool, logger *log
         allocRegistrations: make(map[string]*AllocRegistration),
         agentServices:      make(map[string]struct{}),
         agentChecks:        make(map[string]struct{}),
+        checkWatcher:       newCheckWatcher(logger, consulClient),
     }
 }
 
@@ -267,6 +272,12 @@ func (c *ServiceClient) hasSeen() bool {
 // be called exactly once.
 func (c *ServiceClient) Run() {
     defer close(c.exitCh)
+
+    // start checkWatcher
+    ctx, cancelWatcher := context.WithCancel(context.Background())
+    defer cancelWatcher()
+    go c.checkWatcher.Run(ctx)
+
     retryTimer := time.NewTimer(0)
     <-retryTimer.C // disabled by default
     failures := 0
@@ -274,6 +285,7 @@
         select {
         case <-retryTimer.C:
         case <-c.shutdownCh:
+            cancelWatcher()
         case ops := <-c.opCh:
             c.merge(ops)
         }
@@ -656,7 +668,7 @@ func (c *ServiceClient) checkRegs(ops *operations, allocID, serviceID string, se
 // Checks will always use the IP from the Task struct (host's IP).
 //
 // Actual communication with Consul is done asynchrously (see Run).
-func (c *ServiceClient) RegisterTask(allocID string, task *structs.Task, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error {
+func (c *ServiceClient) RegisterTask(allocID string, task *structs.Task, restarter TaskRestarter, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error {
     // Fast path
     numServices := len(task.Services)
     if numServices == 0 {
@@ -679,6 +691,18 @@ func (c *ServiceClient) RegisterTask(allocID string, task *structs.Task, exec dr
     c.addTaskRegistration(allocID, task.Name, t)
 
     c.commit(ops)
+
+    // Start watching checks. Done after service registrations are built
+    // since an error building them could leak watches.
+    for _, service := range task.Services {
+        serviceID := makeTaskServiceID(allocID, task.Name, service)
+        for _, check := range service.Checks {
+            if check.Watched() {
+                checkID := makeCheckID(serviceID, check)
+                c.checkWatcher.Watch(allocID, task.Name, checkID, check, restarter)
+            }
+        }
+    }
     return nil
 }
 
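The `Run` changes above tie the watcher's lifetime to the service client: the goroutine gets a cancelable context, and both the deferred cancel and the shutdown branch stop it. A simplified, generic sketch of that start/cancel shape (names are illustrative; the real `Run` keeps draining operations after shutdown rather than returning immediately):

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// run mirrors the shape of ServiceClient.Run: a background worker is
// started with a cancelable context and stopped on shutdown.
func run(shutdownCh <-chan struct{}) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // also stops the worker if run returns early

	go worker(ctx)

	for {
		select {
		case <-shutdownCh:
			cancel()
			return
		case <-time.After(100 * time.Millisecond):
			// main loop work would go here
		}
	}
}

func worker(ctx context.Context) {
	<-ctx.Done()
	fmt.Println("worker stopped")
}

func main() {
	shutdownCh := make(chan struct{})
	go func() {
		time.Sleep(200 * time.Millisecond)
		close(shutdownCh)
	}()
	run(shutdownCh)
	time.Sleep(50 * time.Millisecond) // give the worker time to print
}
```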
@@ -686,7 +710,7 @@ func (c *ServiceClient) RegisterTask(allocID string, task *structs.Task, exec dr
 // changed.
 //
 // DriverNetwork must not change between invocations for the same allocation.
-func (c *ServiceClient) UpdateTask(allocID string, existing, newTask *structs.Task, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error {
+func (c *ServiceClient) UpdateTask(allocID string, existing, newTask *structs.Task, restarter TaskRestarter, exec driver.ScriptExecutor, net *cstructs.DriverNetwork) error {
     ops := &operations{}
 
     t := new(TaskRegistration)
@@ -709,7 +733,13 @@ func (c *ServiceClient) UpdateTask(allocID string, existing, newTask *structs.Ta
             // Existing service entry removed
             ops.deregServices = append(ops.deregServices, existingID)
             for _, check := range existingSvc.Checks {
-                ops.deregChecks = append(ops.deregChecks, makeCheckID(existingID, check))
+                cid := makeCheckID(existingID, check)
+                ops.deregChecks = append(ops.deregChecks, cid)
+
+                // Unwatch watched checks
+                if check.Watched() {
+                    c.checkWatcher.Unwatch(cid)
+                }
             }
             continue
         }
@@ -730,9 +760,9 @@ func (c *ServiceClient) UpdateTask(allocID string, existing, newTask *structs.Ta
         }
 
         // Check to see what checks were updated
-        existingChecks := make(map[string]struct{}, len(existingSvc.Checks))
+        existingChecks := make(map[string]*structs.ServiceCheck, len(existingSvc.Checks))
         for _, check := range existingSvc.Checks {
-            existingChecks[makeCheckID(existingID, check)] = struct{}{}
+            existingChecks[makeCheckID(existingID, check)] = check
         }
 
         // Register new checks
@@ -748,15 +778,28 @@ func (c *ServiceClient) UpdateTask(allocID string, existing, newTask *structs.Ta
             if err != nil {
                 return err
             }
+
             for _, checkID := range newCheckIDs {
                 sreg.checkIDs[checkID] = struct{}{}
             }
+
+            // Update all watched checks as CheckRestart fields aren't part of ID
+            if check.Watched() {
+                c.checkWatcher.Watch(allocID, newTask.Name, checkID, check, restarter)
+            }
         }
 
         // Remove existing checks not in updated service
-        for cid := range existingChecks {
+        for cid, check := range existingChecks {
             ops.deregChecks = append(ops.deregChecks, cid)
+
+            // Unwatch checks
+            if check.Watched() {
+                c.checkWatcher.Unwatch(cid)
+            }
         }
     }
 
@@ -774,6 +817,18 @@ func (c *ServiceClient) UpdateTask(allocID string, existing, newTask *structs.Ta
     c.addTaskRegistration(allocID, newTask.Name, t)
 
     c.commit(ops)
+
+    // Start watching checks. Done after service registrations are built
+    // since an error building them could leak watches.
+    for _, service := range newIDs {
+        serviceID := makeTaskServiceID(allocID, newTask.Name, service)
+        for _, check := range service.Checks {
+            if check.Watched() {
+                checkID := makeCheckID(serviceID, check)
+                c.checkWatcher.Watch(allocID, newTask.Name, checkID, check, restarter)
+            }
+        }
+    }
     return nil
 }
 
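Two details of the update path above are easy to miss: checks that disappear from a service must be unwatched, and surviving checks are re-`Watch`ed unconditionally because the `check_restart` settings are not part of the check ID hash, so editing only the restart policy does not change the ID. A self-contained sketch of that bookkeeping with plain maps (the `reconcile` helper is illustrative, not part of the patch):

```go
package main

import "fmt"

// reconcile returns which check IDs to unwatch (removed) and which to
// (re)watch, mirroring the UpdateTask logic above. The bool value stands
// in for check.Watched().
func reconcile(existing, updated map[string]bool) (unwatch, watch []string) {
	for id, watched := range existing {
		if _, stillPresent := updated[id]; !stillPresent && watched {
			unwatch = append(unwatch, id)
		}
	}
	for id, watched := range updated {
		if watched {
			// Re-watch even if the ID already existed: the restart policy
			// may have changed without changing the ID.
			watch = append(watch, id)
		}
	}
	return unwatch, watch
}

func main() {
	existing := map[string]bool{"check-a": true, "check-b": true}
	updated := map[string]bool{"check-a": true, "check-c": false}

	unwatch, watch := reconcile(existing, updated)
	fmt.Println("unwatch:", unwatch) // [check-b]
	fmt.Println("watch:  ", watch)   // [check-a]
}
```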
@@ -788,7 +843,12 @@ func (c *ServiceClient) RemoveTask(allocID string, task *structs.Task) {
         ops.deregServices = append(ops.deregServices, id)
 
         for _, check := range service.Checks {
-            ops.deregChecks = append(ops.deregChecks, makeCheckID(id, check))
+            cid := makeCheckID(id, check)
+            ops.deregChecks = append(ops.deregChecks, cid)
+
+            if check.Watched() {
+                c.checkWatcher.Unwatch(cid)
+            }
         }
     }
 
diff --git a/jobspec/parse.go b/jobspec/parse.go
index 9611578703ef..0058fb2deac5 100644
--- a/jobspec/parse.go
+++ b/jobspec/parse.go
@@ -921,6 +921,7 @@ func parseServices(jobName string, taskGroupName string, task *api.Task, service
         }
 
         delete(m, "check")
+        delete(m, "check_restart")
 
         if err := mapstructure.WeakDecode(m, &service); err != nil {
             return err
         }
@@ -940,6 +941,18 @@ func parseServices(jobName string, taskGroupName string, task *api.Task, service
             }
         }
 
+        // Filter check_restart
+        if cro := checkList.Filter("check_restart"); len(cro.Items) > 0 {
+            if len(cro.Items) > 1 {
+                return fmt.Errorf("check_restart '%s': cannot have more than 1 check_restart", service.Name)
+            }
+            if cr, err := parseCheckRestart(cro.Items[0]); err != nil {
+                return multierror.Prefix(err, fmt.Sprintf("service: '%s',", service.Name))
+            } else {
+                service.CheckRestart = cr
+            }
+        }
+
         task.Services[idx] = &service
     }
 
@@ -964,9 +977,7 @@ func parseChecks(service *api.Service, checkObjs *ast.ObjectList) error {
             "tls_skip_verify",
             "header",
             "method",
-            "restart_grace_period",
-            "restart_on_warning",
-            "restart_after_unhealthy",
+            "check_restart",
         }
         if err := checkHCLKeys(co.Val, valid); err != nil {
             return multierror.Prefix(err, "check ->")
@@ -1008,6 +1019,8 @@ func parseChecks(service *api.Service, checkObjs *ast.ObjectList) error {
             delete(cm, "header")
         }
 
+        delete(cm, "check_restart")
+
         dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
             DecodeHook:       mapstructure.StringToTimeDurationHookFunc(),
             WeaklyTypedInput: true,
@@ -1020,12 +1033,63 @@ func parseChecks(service *api.Service, checkObjs *ast.ObjectList) error {
             return err
         }
 
+        // Filter check_restart
+        var checkRestartList *ast.ObjectList
+        if ot, ok := co.Val.(*ast.ObjectType); ok {
+            checkRestartList = ot.List
+        } else {
+            return fmt.Errorf("check_restart '%s': should be an object", check.Name)
+        }
+
+        if cro := checkRestartList.Filter("check_restart"); len(cro.Items) > 0 {
+            if len(cro.Items) > 1 {
+                return fmt.Errorf("check_restart '%s': cannot have more than 1 check_restart", check.Name)
+            }
+            if cr, err := parseCheckRestart(cro.Items[0]); err != nil {
+                return multierror.Prefix(err, fmt.Sprintf("check: '%s',", check.Name))
+            } else {
+                check.CheckRestart = cr
+            }
+        }
+
         service.Checks[idx] = check
     }
 
     return nil
 }
 
+func parseCheckRestart(cro *ast.ObjectItem) (*api.CheckRestart, error) {
+    valid := []string{
+        "limit",
+        "grace_period",
+        "on_warning",
+    }
+
+    if err := checkHCLKeys(cro.Val, valid); err != nil {
+        return nil, multierror.Prefix(err, "check_restart ->")
+    }
+
+    var checkRestart api.CheckRestart
+    var crm map[string]interface{}
+    if err := hcl.DecodeObject(&crm, cro.Val); err != nil {
+        return nil, err
+    }
+
+    dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
+        DecodeHook:       mapstructure.StringToTimeDurationHookFunc(),
+        WeaklyTypedInput: true,
+        Result:           &checkRestart,
+    })
+    if err != nil {
+        return nil, err
+    }
+    if err := dec.Decode(crm); err != nil {
+        return nil, err
+    }
+
+    return &checkRestart, nil
+}
+
 func parseResources(result *api.Resources, list *ast.ObjectList) error {
     list = list.Elem()
     if len(list.Items) == 0 {
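End to end, `parseCheckRestart` accepts exactly the keys `limit`, `grace_period`, and `on_warning`, nested either under `service` or under an individual `check`. A hedged sketch of exercising it through the public `jobspec.Parse` entry point; the HCL snippet and field access are illustrative and rely on the `CheckRestart` fields this change adds to the `api` structs:

```go
package main

import (
	"fmt"
	"log"
	"strings"

	"github.com/hashicorp/nomad/jobspec"
)

// A minimal job with the new check_restart stanza nested under a check.
const job = `
job "web" {
  group "web" {
    task "server" {
      driver = "docker"
      service {
        name = "web"
        check {
          type     = "http"
          path     = "/health"
          interval = "10s"
          timeout  = "2s"
          check_restart {
            limit        = 3
            grace_period = "30s"
            on_warning   = false
          }
        }
      }
    }
  }
}
`

func main() {
	parsed, err := jobspec.Parse(strings.NewReader(job))
	if err != nil {
		log.Fatal(err)
	}
	check := parsed.TaskGroups[0].Tasks[0].Services[0].Checks[0]
	// CheckRestart is the field populated by parseCheckRestart above.
	fmt.Printf("check_restart parsed: %+v\n", check.CheckRestart)
}
```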
diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go
index 01f6923509f9..b0fdaff04892 100644
--- a/nomad/structs/structs.go
+++ b/nomad/structs/structs.go
@@ -2874,6 +2874,12 @@ func (sc *ServiceCheck) RequiresPort() bool {
     }
 }
 
+// Watched returns true if this check should be watched and trigger a restart
+// on failure.
+func (sc *ServiceCheck) Watched() bool {
+    return sc.CheckRestart != nil && sc.CheckRestart.Limit > 0
+}
+
 // Hash all ServiceCheck fields and the check's corresponding service ID to
 // create an identifier. The identifier is not guaranteed to be unique as if
 // the PortLabel is blank, the Service's PortLabel will be used after Hash is
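`Watched` is the single gate used by both the client and the watcher: only checks whose `check_restart` limit is greater than zero are ever watched. A tiny illustration; the `structs.CheckRestart` literal assumes the stanza struct added elsewhere in this PR (only the `Limit` field is relied on here):

```go
package main

import (
	"fmt"

	"github.com/hashicorp/nomad/nomad/structs"
)

func main() {
	checks := []*structs.ServiceCheck{
		{Name: "no-stanza"}, // CheckRestart == nil, never watched
		{Name: "disabled", CheckRestart: &structs.CheckRestart{Limit: 0}}, // limit 0 disables restarts
		{Name: "watched", CheckRestart: &structs.CheckRestart{Limit: 3}},  // restart after 3 failing polls
	}
	for _, c := range checks {
		fmt.Printf("%s: watched=%v\n", c.Name, c.Watched())
	}
}
```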