Skip to content

Commit

Permalink
Merge pull request #2449 from ulzha/capturer
Browse files Browse the repository at this point in the history
Allow to inject a context manager around TaskProcess.run
  • Loading branch information
NatashaL authored Jul 11, 2018
2 parents bdfcb6c + ed27ffc commit b791d57
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 3 deletions.
32 changes: 30 additions & 2 deletions luigi/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import collections
import getpass
import importlib
import logging
import multiprocessing
import os
Expand Down Expand Up @@ -59,7 +60,7 @@
from luigi.task import Task, flatten, getpaths, Config
from luigi.task_register import TaskClassException
from luigi.task_status import RUNNING
from luigi.parameter import FloatParameter, IntParameter, BoolParameter
from luigi.parameter import BoolParameter, FloatParameter, IntParameter, Parameter

try:
import simplejson as json
Expand Down Expand Up @@ -258,6 +259,26 @@ def terminate(self):
return super(TaskProcess, self).terminate()


# This code and the task_process_context config key currently feel a bit ad-hoc.
# Discussion on generalizing it into a plugin system: https://github.com/spotify/luigi/issues/1897
class ContextManagedTaskProcess(TaskProcess):
    """
    A TaskProcess whose ``run()`` is optionally wrapped in a context manager.

    :param context: fully qualified class name (``package.module.ClassName``)
                    of a context manager class, or a falsy value to run the
                    task without any wrapping.

    All remaining positional and keyword arguments are forwarded unchanged to
    ``TaskProcess.__init__``.
    """

    def __init__(self, context, *args, **kwargs):
        super(ContextManagedTaskProcess, self).__init__(*args, **kwargs)
        self.context = context

    def run(self):
        """
        Run the task, inside the configured context manager when one is set.

        The context manager class is imported lazily here, instantiated with
        this process object as its only argument, and entered around the
        parent class's ``run()``.

        :raises ValueError: if ``context`` contains no ``.`` separating the
                            module path from the class name.
        :raises ImportError: if the module cannot be imported.
        :raises AttributeError: if the module has no attribute with the
                                given class name.
        """
        if not self.context:
            super(ContextManagedTaskProcess, self).run()
            return

        # Lazy %-style arguments: the message is only formatted when DEBUG is enabled.
        logger.debug('Importing module and instantiating %s', self.context)

        module_path, separator, class_name = self.context.rpartition('.')
        if not separator:
            # Without this check a bare name fails with an opaque
            # tuple-unpacking error that does not mention the setting.
            raise ValueError(
                'task_process_context must be a fully qualified class name '
                '(module.ClassName), got %r' % (self.context,)
            )

        module = importlib.import_module(module_path)
        cls = getattr(module, class_name)

        with cls(self):
            super(ContextManagedTaskProcess, self).run()


class TaskStatusReporter(object):
"""
Reports task status information to the scheduler.
Expand Down Expand Up @@ -419,6 +440,12 @@ class worker(Config):
force_multiprocessing = BoolParameter(default=False,
description='If true, use multiprocessing also when '
'running with 1 worker')
task_process_context = Parameter(default=None,
description='If set to a fully qualified class name, the class will '
'be instantiated with a TaskProcess as its constructor parameter and '
'applied as a context manager around its run() call, so this can be '
'used for obtaining high level customizable monitoring or logging of '
'each individual Task run.')


class KeepAliveThread(threading.Thread):
Expand Down Expand Up @@ -966,7 +993,8 @@ def _create_task_process(self, task):
message_queue = multiprocessing.Queue() if task.accepts_messages else None
reporter = TaskStatusReporter(self._scheduler, task.task_id, self._id, message_queue)
use_multiprocessing = self._config.force_multiprocessing or bool(self.worker_processes > 1)
return TaskProcess(
return ContextManagedTaskProcess(
self._config.task_process_context,
task, self._id, self._task_result_queue, reporter,
use_multiprocessing=use_multiprocessing,
worker_timeout=self._config.timeout,
Expand Down
70 changes: 70 additions & 0 deletions test/worker_task_process_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from helpers import LuigiTestCase, temporary_unloaded_module
import luigi
from luigi.worker import Worker
import multiprocessing


class ContextManagedTaskProcessTest(LuigiTestCase):
    """
    Verifies that the worker's ``task_process_context`` setting wraps a task
    run in the configured context manager.
    """

    def _test_context_manager(self, force_multiprocessing):
        """
        Run one dummy task through a Worker configured with a context manager
        and check the enter -> run -> exit ordering.

        :param force_multiprocessing: forwarded to the Worker's
                                      ``force_multiprocessing`` option.
        """
        # Source of a throwaway module defining the context manager under test.
        # It must stay flush-left: it is module-level source inside a bytes
        # literal, not part of this method's body.
        CONTEXT_MANAGER_MODULE = b'''
class MyContextManager(object):
    def __init__(self, task_process):
        self.task = task_process.task
    def __enter__(self):
        assert not self.task.run_event.is_set(), "the task should not have run yet"
        self.task.enter_event.set()
        return self
    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        assert self.task.run_event.is_set(), "the task should have run"
        self.task.exit_event.set()
'''

        class DummyEventRecordingTask(luigi.Task):
            """Task that records, via events, when it ran relative to the context manager."""

            def __init__(self, *args, **kwargs):
                # multiprocessing events rather than plain flags -- presumably so
                # that state set while the task runs in a child process is also
                # visible to the assertions made in the test process.
                self.enter_event = multiprocessing.Event()
                self.exit_event = multiprocessing.Event()
                self.run_event = multiprocessing.Event()
                super(DummyEventRecordingTask, self).__init__(*args, **kwargs)

            def run(self):
                assert self.enter_event.is_set(), "the context manager should have been entered"
                assert not self.exit_event.is_set(), "the context manager should not have been exited yet"
                assert not self.run_event.is_set(), "the task should not have run yet"
                self.run_event.set()

            def complete(self):
                # The task counts as complete once run() has set its event.
                return self.run_event.is_set()

        # NOTE(review): temporary_unloaded_module comes from the test helpers;
        # it appears to expose the given source under a temporary module name
        # and yield that name -- confirm against helpers.py.
        with temporary_unloaded_module(CONTEXT_MANAGER_MODULE) as module_name:
            t = DummyEventRecordingTask()
            w = Worker(task_process_context=module_name + '.MyContextManager',
                       force_multiprocessing=force_multiprocessing)
            w.add(t)
            self.assertTrue(w.run())
            self.assertTrue(t.complete())
            self.assertTrue(t.enter_event.is_set())
            self.assertTrue(t.exit_event.is_set())

    def test_context_manager_without_multiprocessing(self):
        """Context manager is applied with ``force_multiprocessing=False``."""
        self._test_context_manager(False)

    def test_context_manager_with_multiprocessing(self):
        """Context manager is applied with ``force_multiprocessing=True``."""
        self._test_context_manager(True)
2 changes: 1 addition & 1 deletion test/worker_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1823,7 +1823,7 @@ def complete(self):

class WorkerPurgeEventHandlerTest(unittest.TestCase):

@mock.patch('luigi.worker.TaskProcess')
@mock.patch('luigi.worker.ContextManagedTaskProcess')
def test_process_killed_handler(self, task_proc):
result = []

Expand Down

0 comments on commit b791d57

Please sign in to comment.