Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to inject a context manager around TaskProcess.run #2449

Merged
merged 2 commits into from
Jul 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions luigi/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import collections
import getpass
import importlib
import logging
import multiprocessing
import os
Expand Down Expand Up @@ -59,7 +60,7 @@
from luigi.task import Task, flatten, getpaths, Config
from luigi.task_register import TaskClassException
from luigi.task_status import RUNNING
from luigi.parameter import FloatParameter, IntParameter, BoolParameter
from luigi.parameter import BoolParameter, FloatParameter, IntParameter, Parameter

try:
import simplejson as json
Expand Down Expand Up @@ -258,6 +259,26 @@ def terminate(self):
return super(TaskProcess, self).terminate()


# This code and the task_process_context config key currently feels a bit ad-hoc.
# Discussion on generalizing it into a plugin system: https://github.com/spotify/luigi/issues/1897
class ContextManagedTaskProcess(TaskProcess):
def __init__(self, context, *args, **kwargs):
super(ContextManagedTaskProcess, self).__init__(*args, **kwargs)
self.context = context

def run(self):
if self.context:
logger.debug('Importing module and instantiating ' + self.context)
module_path, class_name = self.context.rsplit('.', 1)
module = importlib.import_module(module_path)
cls = getattr(module, class_name)

with cls(self):
super(ContextManagedTaskProcess, self).run()
else:
super(ContextManagedTaskProcess, self).run()


class TaskStatusReporter(object):
"""
Reports task status information to the scheduler.
Expand Down Expand Up @@ -419,6 +440,12 @@ class worker(Config):
force_multiprocessing = BoolParameter(default=False,
description='If true, use multiprocessing also when '
'running with 1 worker')
task_process_context = Parameter(default=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you consider making this a "list" by making a comma-separate name of classnames?

But I guess users can make their single context-manager itself have multiple ones.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly, that would be the next step I could think of; about to add in a separate PR.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will throw a warning because None is not allowed in string parameter:

UserWarning: Parameter "task_process_context" with value "None" is not of type string.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will throw a warning because None is not allowed in string parameter:

UserWarning: Parameter "task_process_context" with value "None" is not of type string.

I wonder why can't it be defaulted to "" instead, like in

task_process_context = Parameter(default="",

@ulzha any thoughts?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, you can leave the None default and make it an OptionalParameter

description='If set to a fully qualified class name, the class will '
'be instantiated with a TaskProcess as its constructor parameter and '
'applied as a context manager around its run() call, so this can be '
'used for obtaining high level customizable monitoring or logging of '
'each individual Task run.')


class KeepAliveThread(threading.Thread):
Expand Down Expand Up @@ -966,7 +993,8 @@ def _create_task_process(self, task):
message_queue = multiprocessing.Queue() if task.accepts_messages else None
reporter = TaskStatusReporter(self._scheduler, task.task_id, self._id, message_queue)
use_multiprocessing = self._config.force_multiprocessing or bool(self.worker_processes > 1)
return TaskProcess(
return ContextManagedTaskProcess(
self._config.task_process_context,
task, self._id, self._task_result_queue, reporter,
use_multiprocessing=use_multiprocessing,
worker_timeout=self._config.timeout,
Expand Down
70 changes: 70 additions & 0 deletions test/worker_task_process_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from helpers import LuigiTestCase, temporary_unloaded_module
import luigi
from luigi.worker import Worker
import multiprocessing


class ContextManagedTaskProcessTest(LuigiTestCase):

def _test_context_manager(self, force_multiprocessing):
CONTEXT_MANAGER_MODULE = b'''
class MyContextManager(object):
def __init__(self, task_process):
self.task = task_process.task
def __enter__(self):
assert not self.task.run_event.is_set(), "the task should not have run yet"
self.task.enter_event.set()
return self
def __exit__(self, exc_type=None, exc_value=None, traceback=None):
assert self.task.run_event.is_set(), "the task should have run"
self.task.exit_event.set()
'''

class DummyEventRecordingTask(luigi.Task):
def __init__(self, *args, **kwargs):
self.enter_event = multiprocessing.Event()
self.exit_event = multiprocessing.Event()
self.run_event = multiprocessing.Event()
super(DummyEventRecordingTask, self).__init__(*args, **kwargs)

def run(self):
assert self.enter_event.is_set(), "the context manager should have been entered"
assert not self.exit_event.is_set(), "the context manager should not have been exited yet"
assert not self.run_event.is_set(), "the task should not have run yet"
self.run_event.set()

def complete(self):
return self.run_event.is_set()

with temporary_unloaded_module(CONTEXT_MANAGER_MODULE) as module_name:
t = DummyEventRecordingTask()
w = Worker(task_process_context=module_name + '.MyContextManager',
force_multiprocessing=force_multiprocessing)
w.add(t)
self.assertTrue(w.run())
self.assertTrue(t.complete())
self.assertTrue(t.enter_event.is_set())
self.assertTrue(t.exit_event.is_set())

def test_context_manager_without_multiprocessing(self):
self._test_context_manager(False)

def test_context_manager_with_multiprocessing(self):
self._test_context_manager(True)
2 changes: 1 addition & 1 deletion test/worker_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1823,7 +1823,7 @@ def complete(self):

class WorkerPurgeEventHandlerTest(unittest.TestCase):

@mock.patch('luigi.worker.TaskProcess')
@mock.patch('luigi.worker.ContextManagedTaskProcess')
def test_process_killed_handler(self, task_proc):
result = []

Expand Down