Skip to content

Commit

Permalink
lifecycle: add poststop hook
Browse files Browse the repository at this point in the history
  • Loading branch information
jazzyfresh committed Nov 11, 2020
1 parent 306cfab commit 2eca731
Show file tree
Hide file tree
Showing 12 changed files with 280 additions and 15 deletions.
1 change: 1 addition & 0 deletions api/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ type DispatchPayloadConfig struct {
const (
TaskLifecycleHookPrestart = "prestart"
TaskLifecycleHookPoststart = "poststart"
TaskLifecycleHookPoststop = "poststop"
)

type TaskLifecycle struct {
Expand Down
24 changes: 22 additions & 2 deletions client/allocrunner/alloc_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,12 +353,26 @@ func (ar *allocRunner) shouldRun() bool {

// runTasks is used to run the task runners and block until they exit.
func (ar *allocRunner) runTasks() {
// Start all tasks
for _, task := range ar.tasks {
go task.Run()
}

// Block on all tasks except poststop tasks
for _, task := range ar.tasks {
<-task.WaitCh()
if !task.IsPoststopTask() {
<-task.WaitCh()
}
}

// Signal poststop tasks to proceed to main runtime
ar.taskHookCoordinator.StartPoststopTasks()

// Wait for poststop tasks to finish before proceeding
for _, task := range ar.tasks {
if task.IsPoststopTask() {
<-task.WaitCh()
}
}
}

Expand Down Expand Up @@ -485,6 +499,10 @@ func (ar *allocRunner) handleTaskStateUpdates() {
state := tr.TaskState()
states[name] = state

if tr.IsPoststopTask() {
continue
}

// Capture live task runners in case we need to kill them
if state.State != structs.TaskStateDead {
liveRunners = append(liveRunners, tr)
Expand Down Expand Up @@ -535,6 +553,7 @@ func (ar *allocRunner) handleTaskStateUpdates() {
// prevent looping before TaskRunners have transitioned
// to Dead.
for _, tr := range liveRunners {
ar.logger.Info("killing task: ", tr.Task().Name)
select {
case <-tr.WaitCh():
case <-ar.waitCh:
Expand Down Expand Up @@ -586,7 +605,8 @@ func (ar *allocRunner) killTasks() map[string]*structs.TaskState {
// Kill the rest concurrently
wg := sync.WaitGroup{}
for name, tr := range ar.tasks {
if tr.IsLeader() {
// Filter out poststop tasks so they run after all the other tasks are killed
if tr.IsLeader() || tr.IsPoststopTask() {
continue
}

Expand Down
83 changes: 83 additions & 0 deletions client/allocrunner/alloc_runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,89 @@ func TestAllocRunner_TaskMain_KillTG(t *testing.T) {
})
}

// TestAllocRunner_Lifecycle_Poststop asserts that a service job with 1
// postop lifecycle hook starts all 3 tasks, only
// the ephemeral one finishes, and the other 2 exit when the alloc is stopped.
func TestAllocRunner_Lifecycle_Poststop(t *testing.T) {
alloc := mock.LifecycleAlloc()
tr := alloc.AllocatedResources.Tasks[alloc.Job.TaskGroups[0].Tasks[0].Name]

alloc.Job.Type = structs.JobTypeService
mainTask := alloc.Job.TaskGroups[0].Tasks[0]
mainTask.Config["run_for"] = "100s"

ephemeralTask := alloc.Job.TaskGroups[0].Tasks[1]
ephemeralTask.Name = "quit"
ephemeralTask.Lifecycle.Hook = structs.TaskLifecycleHookPoststop
ephemeralTask.Config["run_for"] = "10s"

alloc.Job.TaskGroups[0].Tasks = []*structs.Task{mainTask, ephemeralTask}
alloc.AllocatedResources.Tasks = map[string]*structs.AllocatedTaskResources{
mainTask.Name: tr,
ephemeralTask.Name: tr,
}

conf, cleanup := testAllocRunnerConfig(t, alloc)
defer cleanup()
ar, err := NewAllocRunner(conf)
require.NoError(t, err)
defer destroy(ar)
go ar.Run()

upd := conf.StateUpdater.(*MockStateUpdater)

// Wait for main task to be running
testutil.WaitForResult(func() (bool, error) {
last := upd.Last()
if last == nil {
return false, fmt.Errorf("No updates")
}

if last.ClientStatus != structs.AllocClientStatusRunning {
return false, fmt.Errorf("expected alloc to be running not %s", last.ClientStatus)
}

if s := last.TaskStates[mainTask.Name].State; s != structs.TaskStateRunning {
return false, fmt.Errorf("expected main task to be running not %s", s)
}

if s := last.TaskStates[ephemeralTask.Name].State; s != structs.TaskStatePending {
return false, fmt.Errorf("expected ephemeral task to be pending not %s", s)
}

return true, nil
}, func(err error) {
t.Fatalf("error waiting for initial state:\n%v", err)
})

// Tell the alloc to stop
stopAlloc := alloc.Copy()
stopAlloc.DesiredStatus = structs.AllocDesiredStatusStop
ar.Update(stopAlloc)

// Wait for main task to die & poststop task to run.
testutil.WaitForResult(func() (bool, error) {
last := upd.Last()

if last.ClientStatus != structs.AllocClientStatusRunning {
return false, fmt.Errorf("expected alloc to be running not %s", last.ClientStatus)
}

if s := last.TaskStates[mainTask.Name].State; s != structs.TaskStateDead {
return false, fmt.Errorf("expected main task to be dead not %s", s)
}

if s := last.TaskStates[ephemeralTask.Name].State; s != structs.TaskStateRunning {
return false, fmt.Errorf("expected poststop task to be running not %s", s)
}

return true, nil
}, func(err error) {
t.Fatalf("error waiting for initial state:\n%v", err)
})

}

func TestAllocRunner_TaskGroup_ShutdownDelay(t *testing.T) {
t.Parallel()

Expand Down
44 changes: 41 additions & 3 deletions client/allocrunner/task_hook_coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,29 @@ import (
"github.com/hashicorp/nomad/nomad/structs"
)

// TaskHookCoordinator helps coordinate when main start tasks can launch
// TaskHookCoordinator helps coordinate when mainTasks start tasks can launch
// namely after all Prestart Tasks have run, and after all BlockUntilCompleted have completed
type taskHookCoordinator struct {
logger hclog.Logger

// constant for quickly starting all prestart tasks
closedCh chan struct{}

// Each context is used to gate task runners launching the tasks. A task
// runner waits until the context associated its lifecycle context is
// done/cancelled.
mainTaskCtx context.Context
mainTaskCtxCancel func()

poststartTaskCtx context.Context
poststartTaskCtxCancel func()
poststopTaskCtx context.Context
poststopTaskCtxCancel context.CancelFunc

prestartSidecar map[string]struct{}
prestartEphemeral map[string]struct{}
mainTasksPending map[string]struct{}
mainTasksRunning map[string]struct{} // poststop: main tasks running -> finished
mainTasksPending map[string]struct{} // poststart: main tasks pending -> running
}

func newTaskHookCoordinator(logger hclog.Logger, tasks []*structs.Task) *taskHookCoordinator {
Expand All @@ -32,6 +39,7 @@ func newTaskHookCoordinator(logger hclog.Logger, tasks []*structs.Task) *taskHoo

mainTaskCtx, mainCancelFn := context.WithCancel(context.Background())
poststartTaskCtx, poststartCancelFn := context.WithCancel(context.Background())
poststopTaskCtx, poststopTaskCancelFn := context.WithCancel(context.Background())

c := &taskHookCoordinator{
logger: logger,
Expand All @@ -40,9 +48,12 @@ func newTaskHookCoordinator(logger hclog.Logger, tasks []*structs.Task) *taskHoo
mainTaskCtxCancel: mainCancelFn,
prestartSidecar: map[string]struct{}{},
prestartEphemeral: map[string]struct{}{},
mainTasksRunning: map[string]struct{}{},
mainTasksPending: map[string]struct{}{},
poststartTaskCtx: poststartTaskCtx,
poststartTaskCtxCancel: poststartCancelFn,
poststopTaskCtx: poststopTaskCtx,
poststopTaskCtxCancel: poststopTaskCancelFn,
}
c.setTasks(tasks)
return c
Expand All @@ -53,6 +64,7 @@ func (c *taskHookCoordinator) setTasks(tasks []*structs.Task) {

if task.Lifecycle == nil {
c.mainTasksPending[task.Name] = struct{}{}
c.mainTasksRunning[task.Name] = struct{}{}
continue
}

Expand All @@ -65,6 +77,8 @@ func (c *taskHookCoordinator) setTasks(tasks []*structs.Task) {
}
case structs.TaskLifecycleHookPoststart:
// Poststart hooks don't need to be tracked.
case structs.TaskLifecycleHookPoststop:
// Poststop hooks don't need to be tracked.
default:
c.logger.Error("invalid lifecycle hook", "task", task.Name, "hook", task.Lifecycle.Hook)
}
Expand All @@ -79,6 +93,10 @@ func (c *taskHookCoordinator) hasPrestartTasks() bool {
return len(c.prestartSidecar)+len(c.prestartEphemeral) > 0
}

func (c *taskHookCoordinator) hasRunningMainTasks() bool {
return len(c.mainTasksRunning) > 0
}

func (c *taskHookCoordinator) hasPendingMainTasks() bool {
return len(c.mainTasksPending) > 0
}
Expand All @@ -94,7 +112,11 @@ func (c *taskHookCoordinator) startConditionForTask(task *structs.Task) <-chan s
return c.closedCh
case structs.TaskLifecycleHookPoststart:
return c.poststartTaskCtx.Done()
case structs.TaskLifecycleHookPoststop:
return c.poststopTaskCtx.Done()
default:
// it should never have a lifecycle stanza w/o a hook, so report an error but allow the task to start normally
c.logger.Error("invalid lifecycle hook", "task", task.Name, "hook", task.Lifecycle.Hook)
return c.mainTaskCtx.Done()
}
}
Expand All @@ -119,6 +141,16 @@ func (c *taskHookCoordinator) taskStateUpdated(states map[string]*structs.TaskSt
delete(c.prestartEphemeral, task)
}

for task := range c.mainTasksRunning {
st := states[task]

if st == nil || st.State != structs.TaskStateDead {
continue
}

delete(c.mainTasksRunning, task)
}

for task := range c.mainTasksPending {
st := states[task]
if st == nil || st.StartedAt.IsZero() {
Expand All @@ -128,14 +160,20 @@ func (c *taskHookCoordinator) taskStateUpdated(states map[string]*structs.TaskSt
delete(c.mainTasksPending, task)
}

// everything well
if !c.hasPrestartTasks() {
c.mainTaskCtxCancel()
}

if !c.hasPendingMainTasks() {
c.poststartTaskCtxCancel()
}
if !c.hasRunningMainTasks() {
c.poststopTaskCtxCancel()
}
}

func (c *taskHookCoordinator) StartPoststopTasks() {
c.poststopTaskCtxCancel()
}

// hasNonSidecarTasks returns false if all the passed tasks are sidecar tasks
Expand Down
5 changes: 5 additions & 0 deletions client/allocrunner/taskrunner/restarts/restarts.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ func NewRestartTracker(policy *structs.RestartPolicy, jobType string, tlc *struc
onSuccess = tlc.Sidecar
}

// Poststop should never be restarted on success
if tlc != nil && tlc.Hook == structs.TaskLifecycleHookPoststop {
onSuccess = false
}

return &RestartTracker{
startTime: time.Now(),
onSuccess: onSuccess,
Expand Down
15 changes: 14 additions & 1 deletion client/allocrunner/taskrunner/task_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,14 +495,15 @@ func (tr *TaskRunner) Run() {

select {
case <-tr.startConditionMetCtx:
tr.logger.Debug("lifecycle start condition has been met, proceeding")
// yay proceed
case <-tr.killCtx.Done():
case <-tr.shutdownCtx.Done():
return
}

MAIN:
for !tr.Alloc().TerminalStatus() {
for !tr.shouldShutdown() {
select {
case <-tr.killCtx.Done():
break MAIN
Expand Down Expand Up @@ -625,6 +626,18 @@ MAIN:
tr.logger.Debug("task run loop exiting")
}

func (tr *TaskRunner) shouldShutdown() bool {
if tr.alloc.ClientTerminalStatus() {
return true
}

if !tr.IsPoststopTask() && tr.alloc.ServerTerminalStatus() {
return true
}

return false
}

// handleTaskExitResult handles the results returned by the task exiting. If
// retryWait is true, the caller should attempt to wait on the task again since
// it has not actually finished running. This can happen if the driver plugin
Expand Down
5 changes: 5 additions & 0 deletions client/allocrunner/taskrunner/task_runner_getters.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ func (tr *TaskRunner) IsLeader() bool {
return tr.taskLeader
}

// IsPoststopTask returns true if this task is a poststop task in its task group.
func (tr *TaskRunner) IsPoststopTask() bool {
return tr.Task().Lifecycle != nil && tr.Task().Lifecycle.Hook == structs.TaskLifecycleHookPoststop
}

func (tr *TaskRunner) Task() *structs.Task {
tr.taskLock.RLock()
defer tr.taskLock.RUnlock()
Expand Down
3 changes: 1 addition & 2 deletions client/allocrunner/taskrunner/task_runner_hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,7 @@ func (tr *TaskRunner) emitHookError(err error, hookName string) {
func (tr *TaskRunner) prestart() error {
// Determine if the allocation is terminal and we should avoid running
// prestart hooks.
alloc := tr.Alloc()
if alloc.TerminalStatus() {
if tr.shouldShutdown() {
tr.logger.Trace("skipping prestart hooks since allocation is terminal")
return nil
}
Expand Down
Loading

0 comments on commit 2eca731

Please sign in to comment.