diff --git a/luigi/worker.py b/luigi/worker.py index a6e59c75fd..a051dcf88c 100644 --- a/luigi/worker.py +++ b/luigi/worker.py @@ -30,6 +30,7 @@ import collections import getpass +import importlib import logging import multiprocessing import os @@ -59,7 +60,7 @@ from luigi.task import Task, flatten, getpaths, Config from luigi.task_register import TaskClassException from luigi.task_status import RUNNING -from luigi.parameter import FloatParameter, IntParameter, BoolParameter +from luigi.parameter import BoolParameter, FloatParameter, IntParameter, Parameter try: import simplejson as json @@ -258,6 +259,27 @@ def terminate(self): return super(TaskProcess, self).terminate() +# TODO be composable with arbitrarily many custom context managers? +# Introduce a convention shared for extension points other than TaskProcess? +# Use https://docs.openstack.org/stevedore? +class ContextManagedTaskProcess(TaskProcess): + def __init__(self, context, *args, **kwargs): + super(ContextManagedTaskProcess, self).__init__(*args, **kwargs) + self.context = context + + def run(self): + if self.context: + logger.debug('Instantiating ' + self.context) + module_path, class_name = self.context.rsplit('.', 1) + module = importlib.import_module(module_path) + cls = getattr(module, class_name) + + with cls(self): + super(ContextManagedTaskProcess, self).run() + else: + super(ContextManagedTaskProcess, self).run() + + class TaskStatusReporter(object): """ Reports task status information to the scheduler. @@ -419,6 +441,12 @@ class worker(Config): force_multiprocessing = BoolParameter(default=False, description='If true, use multiprocessing also when ' 'running with 1 worker') + task_process_context = Parameter(default=None, + description='If set to a fully qualified class name, the class will ' + 'be instantiated with a TaskProcess as its constructor parameter and ' + 'applied as a context manager around its run() call, so this can be ' + 'used for obtaining high level customizable monitoring or logging of ' + 'each individual Task run.') class KeepAliveThread(threading.Thread): @@ -966,7 +994,8 @@ def _create_task_process(self, task): message_queue = multiprocessing.Queue() if task.accepts_messages else None reporter = TaskStatusReporter(self._scheduler, task.task_id, self._id, message_queue) use_multiprocessing = self._config.force_multiprocessing or bool(self.worker_processes > 1) - return TaskProcess( + return ContextManagedTaskProcess( + self._config.task_process_context, task, self._id, self._task_result_queue, reporter, use_multiprocessing=use_multiprocessing, worker_timeout=self._config.timeout, diff --git a/test/worker_task_process_test.py b/test/worker_task_process_test.py new file mode 100644 index 0000000000..79c245ad77 --- /dev/null +++ b/test/worker_task_process_test.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from helpers import LuigiTestCase, temporary_unloaded_module +import luigi +from luigi.worker import Worker +import multiprocessing + + +class ContextManagedTaskProcessTest(LuigiTestCase): + + def _test_context_manager(self, force_multiprocessing): + CONTEXT_MANAGER_MODULE = b''' +class MyContextManager(object): + def __init__(self, task_process): + self.task = task_process.task + def __enter__(self): + assert not self.task.run_event.is_set(), "the task should not have run yet" + self.task.enter_event.set() + return self + def __exit__(self, exc_type=None, exc_value=None, traceback=None): + assert self.task.run_event.is_set(), "the task should have run" + self.task.exit_event.set() +''' + + class DummyEventRecordingTask(luigi.Task): + def __init__(self, *args, **kwargs): + self.enter_event = multiprocessing.Event() + self.exit_event = multiprocessing.Event() + self.run_event = multiprocessing.Event() + super(DummyEventRecordingTask, self).__init__(*args, **kwargs) + + def run(self): + assert self.enter_event.is_set(), "the context manager should have been entered" + assert not self.exit_event.is_set(), "the context manager should not have been exited yet" + assert not self.run_event.is_set(), "the task should have run yet" + self.run_event.set() + + def complete(self): + return self.run_event.is_set() + + with temporary_unloaded_module(CONTEXT_MANAGER_MODULE) as module_name: + t = DummyEventRecordingTask() + w = Worker(task_process_context=module_name + '.MyContextManager', + force_multiprocessing=force_multiprocessing) + w.add(t) + self.assertTrue(w.run()) + self.assertTrue(t.complete()) + self.assertTrue(t.enter_event.is_set()) + self.assertTrue(t.exit_event.is_set()) + + def test_context_manager_without_multiprocessing(self): + self._test_context_manager(False) + + def test_context_manager_with_multiprocessing(self): + self._test_context_manager(True) diff --git a/test/worker_test.py b/test/worker_test.py index 5658a0531c..cf8d37a77a 100644 --- a/test/worker_test.py +++ b/test/worker_test.py @@ -1823,7 +1823,7 @@ def complete(self): class WorkerPurgeEventHandlerTest(unittest.TestCase): - @mock.patch('luigi.worker.TaskProcess') + @mock.patch('luigi.worker.ContextManagedTaskProcess') def test_process_killed_handler(self, task_proc): result = []