diff --git a/client/allocrunner/alloc_runner_unix_test.go b/client/allocrunner/alloc_runner_unix_test.go index e0608efc75cd..8c31832e3c43 100644 --- a/client/allocrunner/alloc_runner_unix_test.go +++ b/client/allocrunner/alloc_runner_unix_test.go @@ -131,3 +131,87 @@ func TestAllocRunner_Restore_RunningTerminal(t *testing.T) { require.Equal(t, events[2].Type, structs.TaskStarted) require.Equal(t, events[3].Type, structs.TaskTerminated) } + +// TestAllocRunner_Restore_CompletedBatch asserts that restoring a completed +// batch alloc doesn't run it again +func TestAllocRunner_Restore_CompletedBatch(t *testing.T) { + t.Parallel() + + // 1. Run task and wait for it to complete + // 2. Start new alloc runner + // 3. Assert task didn't run again + + alloc := mock.Alloc() + alloc.Job.Type = structs.JobTypeBatch + task := alloc.Job.TaskGroups[0].Tasks[0] + task.Driver = "mock_driver" + task.Config = map[string]interface{}{ + "run_for": "2ms", + } + + conf, cleanup := testAllocRunnerConfig(t, alloc.Copy()) + defer cleanup() + + // Maintain state for subsequent run + conf.StateDB = state.NewMemDB(conf.Logger) + + // Start and wait for task to be running + ar, err := NewAllocRunner(conf) + require.NoError(t, err) + go ar.Run() + defer destroy(ar) + + testutil.WaitForResult(func() (bool, error) { + s := ar.AllocState() + if s.ClientStatus != structs.AllocClientStatusComplete { + return false, fmt.Errorf("expected complete, got %s", s.ClientStatus) + } + return true, nil + }, func(err error) { + require.NoError(t, err) + }) + + // once job finishes, it shouldn't run again + require.False(t, ar.shouldRun()) + initialRunEvents := ar.AllocState().TaskStates[task.Name].Events + require.Len(t, initialRunEvents, 4) + + ls, ts, err := conf.StateDB.GetTaskRunnerState(alloc.ID, task.Name) + require.NoError(t, err) + require.NotNil(t, ls) + require.Equal(t, structs.TaskStateDead, ts.State) + + // Start a new alloc runner and assert it gets stopped + conf2, cleanup2 := testAllocRunnerConfig(t, alloc) + defer cleanup2() + + // Use original statedb to maintain hook state + conf2.StateDB = conf.StateDB + + // Restore, start, and wait for task to be killed + ar2, err := NewAllocRunner(conf2) + require.NoError(t, err) + + require.NoError(t, ar2.Restore()) + + go ar2.Run() + defer destroy(ar2) + + // AR waitCh must be closed even when task doesn't run again + select { + case <-ar2.WaitCh(): + case <-time.After(10 * time.Second): + require.Fail(t, "alloc.waitCh wasn't closed") + } + + // TR waitCh must be closed too! + select { + case <-ar2.tasks[task.Name].WaitCh(): + case <-time.After(10 * time.Second): + require.Fail(t, "tr.waitCh wasn't closed") + } + + // Assert that events are unmodified, which they would if task re-run + events := ar2.AllocState().TaskStates[task.Name].Events + require.Equal(t, initialRunEvents, events) +} diff --git a/client/allocrunner/taskrunner/task_runner.go b/client/allocrunner/taskrunner/task_runner.go index b71c3bd4df87..233a97b116c6 100644 --- a/client/allocrunner/taskrunner/task_runner.go +++ b/client/allocrunner/taskrunner/task_runner.go @@ -394,6 +394,24 @@ func (tr *TaskRunner) Run() { defer close(tr.waitCh) var result *drivers.ExitResult + tr.stateLock.RLock() + dead := tr.state.State == structs.TaskStateDead + tr.stateLock.RUnlock() + + // if restoring a dead task, ensure that task is cleared and all post hooks + // are called without additional state updates + if dead { + // do cleanup functions without emitting any additional events/work + // to handle cases where we restored a dead task where client terminated + // after task finished before completing post-run actions. + tr.clearDriverHandle() + tr.stateUpdater.TaskStateUpdated() + if err := tr.stop(); err != nil { + tr.logger.Error("stop failed on terminal task", "error", err) + } + return + } + // Updates are handled asynchronously with the other hooks but each // triggered update - whether due to alloc updates or a new vault token // - should be handled serially. @@ -899,7 +917,7 @@ func (tr *TaskRunner) Restore() error { } alloc := tr.Alloc() - if alloc.TerminalStatus() || alloc.Job.Type == structs.JobTypeSystem { + if tr.state.State == structs.TaskStateDead || alloc.TerminalStatus() || alloc.Job.Type == structs.JobTypeSystem { return nil }