diff --git a/pkg/tasks/worker/worker.go b/pkg/tasks/worker/worker.go index b5a8d2c87..216029c54 100644 --- a/pkg/tasks/worker/worker.go +++ b/pkg/tasks/worker/worker.go @@ -75,7 +75,7 @@ func newWorker(config workerConfig, metrics *m.Metrics) worker { func (w *worker) start(ctx context.Context) { log.Logger.Info().Msg("Starting worker") defer w.workerWg.Done() - defer recoverOnPanic(log.Logger) + defer w.recoverOnPanic(log.Logger) w.readyChan <- struct{}{} @@ -127,7 +127,7 @@ func (w *worker) start(ctx context.Context) { func (w *worker) dequeue(ctx context.Context) (*models.TaskInfo, error) { logger := logForTask(w.runningTask) - defer recoverOnPanic(*logger) + defer w.recoverOnPanic(*logger) info, err := w.queue.Dequeue(ctx, w.taskTypes) if err != nil { @@ -138,7 +138,10 @@ func (w *worker) dequeue(ctx context.Context) (*models.TaskInfo, error) { w.readyChan <- struct{}{} return nil, err } - w.metrics.RecordMessageLatency(*info.Queued) + if w.metrics != nil { + w.metrics.RecordMessageLatency(*info.Queued) + } + w.runningTask.setTaskInfo(info) logForTask(w.runningTask).Info().Msg("[Dequeued Task]") @@ -147,7 +150,7 @@ func (w *worker) dequeue(ctx context.Context) (*models.TaskInfo, error) { func (w *worker) requeue(id uuid.UUID) error { logger := logForTask(w.runningTask) - defer recoverOnPanic(*logger) + defer w.recoverOnPanic(*logger) err := w.queue.Requeue(id) if err != nil { @@ -160,7 +163,7 @@ func (w *worker) requeue(id uuid.UUID) error { // process calls the handler for the task specified by taskInfo, finishes the task, then marks worker as ready for new task func (w *worker) process(ctx context.Context, taskInfo *models.TaskInfo) { logger := zerolog.Ctx(ctx) - defer recoverOnPanic(*logger) + defer w.recoverOnPanic(*logger) if handler, ok := w.handlers[taskInfo.Typename]; ok { var finishStr string @@ -174,15 +177,15 @@ func (w *worker) process(ctx context.Context, taskInfo *models.TaskInfo) { if errors.Is(handlerErr, context.Canceled) { finishStr = "task canceled" - w.metrics.RecordMessageResult(true) + w.recordMessageResult(true) logger.Info().Msgf("[Finished Task] %v", finishStr) } else if handlerErr != nil { finishStr = fmt.Sprintf("task failed with error: %v", handlerErr) - w.metrics.RecordMessageResult(false) + w.recordMessageResult(false) logger.Warn().Msgf("[Finished Task] %v", finishStr) } else { finishStr = "task completed" - w.metrics.RecordMessageResult(true) + w.recordMessageResult(true) logger.Info().Msgf("[Finished Task] %v", finishStr) } @@ -194,16 +197,35 @@ func (w *worker) process(ctx context.Context, taskInfo *models.TaskInfo) { w.readyChan <- struct{}{} } +func (w *worker) recordMessageResult(res bool) { + if w.metrics != nil { + w.metrics.RecordMessageResult(res) + } +} func (w *worker) stop() { w.stopChan <- struct{}{} } // Catches a panic so that only the surrounding function is exited -func recoverOnPanic(logger zerolog.Logger) { +func (w *worker) recoverOnPanic(logger zerolog.Logger) { var err error if r := recover(); r != nil { err, _ = r.(error) logger.Error().Err(err).Stack().Msgf("recovered panic in worker with error: %v", err) + logger.Info().Msgf("[Finished Task] task failed (panic)") + + if w.runningTask != nil { + tErr := w.queue.Finish(w.runningTask.id, err) + if tErr != nil { + log.Error().Err(tErr).Msgf("Could not update task during panic recovery, original error: %v", err.Error()) + } + + if w.runningTask.taskCancelFunc != nil { + w.runningTask.taskCancelFunc(queue.ErrNotRunning) + } + w.runningTask.clear() + } + w.readyChan <- struct{}{} } } diff --git a/pkg/tasks/worker/worker_pool_test.go b/pkg/tasks/worker/worker_pool_test.go index f15299a29..6486bc49d 100644 --- a/pkg/tasks/worker/worker_pool_test.go +++ b/pkg/tasks/worker/worker_pool_test.go @@ -5,7 +5,10 @@ import ( "testing" "time" + "github.com/content-services/content-sources-backend/pkg/models" "github.com/content-services/content-sources-backend/pkg/tasks/queue" + uuid2 "github.com/google/uuid" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" "go.uber.org/goleak" ) @@ -30,11 +33,11 @@ func (s *WorkerSuite) TestStartStopWorkers() { s.T().Setenv("TASKING_WORKER_COUNT", "3") ctx, cancelFunc := context.WithCancel(context.Background()) + mockQueue.On("Dequeue", mock.Anything, []string(nil)).Return(&models.TaskInfo{}, nil) + mockQueue.On("ListenForCancel", mock.Anything, uuid2.Nil, mock.Anything).Return(nil, nil) - mockQueue.On("Dequeue", ctx, []string(nil)).Times(3).Return(nil, nil) - - workerPool.StartWorkers(context.Background()) - time.Sleep(time.Millisecond * 5) + workerPool.StartWorkers(ctx) + time.Sleep(time.Millisecond * 2) workerPool.Stop() cancelFunc() }