Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow release of resources during task running. #2346

Merged
merged 13 commits into from
Apr 8, 2018
Merged
27 changes: 27 additions & 0 deletions doc/luigi_patterns.rst
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,33 @@ the task parameters or other dynamic attributes:
Since, by default, resources have a usage limit of 1, no two instances of Task A
will now run if they have the same `important_file_name` property.

Decreasing resources of running tasks
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

At scheduling time, the luigi scheduler needs to be aware of the maximum
resource consumption a task might have once it runs. For some tasks, however,
it can be beneficial to decrease the amount of consumed resources between two
steps within their run method (e.g. after some heavy computation). In this
case, a different task waiting for that particular resource can already be
scheduled.

.. code-block:: python

class A(luigi.Task):

# set maximum resources a priori
resources = {"some_resource": 3}

def run(self):
# do something
...

# decrease consumption of "some_resource" by one
self.decrease_running_resources({"some_resource": 1})

# continue with reduced resources
...

Monitoring task pipelines
~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
39 changes: 35 additions & 4 deletions luigi/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,8 @@ def add_task(self, task_id=None, status=PENDING, runnable=True,
worker_id = worker
worker = self._update_worker(worker_id)

resources = {} if resources is None else resources.copy()

if retry_policy_dict is None:
retry_policy_dict = {}

Expand Down Expand Up @@ -815,7 +817,9 @@ def add_task(self, task_id=None, status=PENDING, runnable=True,
if status == RUNNING and not task.worker_running:
task.worker_running = worker_id
if batch_id:
task.resources_running = self._state.get_batch_running_tasks(batch_id)[0].resources_running
# copy resources_running of the first batch task
batch_tasks = self._state.get_batch_running_tasks(batch_id)
task.resources_running = batch_tasks[0].resources_running.copy()
task.time_running = time.time()

if tracking_url is not None or task.status != RUNNING:
Expand Down Expand Up @@ -970,8 +974,9 @@ def _used_resources(self):
used_resources = collections.defaultdict(int)
if self._resources is not None:
for task in self._state.get_active_tasks_by_status(RUNNING):
if getattr(task, 'resources_running', task.resources):
for resource, amount in six.iteritems(getattr(task, 'resources_running', task.resources)):
resources_running = getattr(task, "resources_running", task.resources)
if resources_running:
for resource, amount in six.iteritems(resources_running):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice code improvement :)

used_resources[resource] += amount
return used_resources

Expand Down Expand Up @@ -1175,7 +1180,7 @@ def get_work(self, host=None, assistant=False, current_tasks=None, worker=None,
elif best_task:
self._state.set_status(best_task, RUNNING, self._config)
best_task.worker_running = worker_id
best_task.resources_running = best_task.resources
best_task.resources_running = best_task.resources.copy()
best_task.time_running = time.time()
self._update_task_history(best_task, RUNNING, host=host)

Expand Down Expand Up @@ -1237,6 +1242,7 @@ def _serialize_task(self, task_id, include_deps=True, deps=None):
'name': task.family,
'priority': task.priority,
'resources': task.resources,
'resources_running': getattr(task, "resources_running", None),
'tracking_url': getattr(task, "tracking_url", None),
'status_message': getattr(task, "status_message", None),
'progress_percentage': getattr(task, "progress_percentage", None)
Expand Down Expand Up @@ -1521,6 +1527,31 @@ def get_task_progress_percentage(self, task_id):
else:
return {"taskId": task_id, "progressPercentage": None}

@rpc_method()
def decrease_running_task_resources(self, task_id, decrease_resources):
if self._state.has_task(task_id):
task = self._state.get_task(task_id)
if task.status != RUNNING:
return

def decrease(resources, decrease_resources):
for resource, decrease_amount in six.iteritems(decrease_resources):
if decrease_amount > 0 and resource in resources:
resources[resource] = max(0, resources[resource] - decrease_amount)

decrease(task.resources_running, decrease_resources)
if task.batch_id is not None:
for batch_task in self._state.get_batch_running_tasks(task.batch_id):
decrease(batch_task.resources_running, decrease_resources)

@rpc_method()
def get_running_task_resources(self, task_id):
if self._state.has_task(task_id):
task = self._state.get_task(task_id)
return {"taskId": task_id, "resources": getattr(task, "resources_running", None)}
else:
return {"taskId": task_id, "resources": None}

def _update_task_history(self, task, status, host=None):
try:
if status == DONE or status == FAILED:
Expand Down
2 changes: 1 addition & 1 deletion luigi/static/visualiser/js/visualiserApp.js
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ function visualiserApp(luigi) {
taskParams: taskParams,
displayName: task.display_name,
priority: task.priority,
resources: JSON.stringify(task.resources).replace(/,"/g, ', "'),
resources: JSON.stringify(task.resources_running || task.resources).replace(/,"/g, ', "'),
displayTime: displayTime,
displayTimestamp: task.last_updated,
timeRunning: time_running,
Expand Down
3 changes: 2 additions & 1 deletion luigi/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import copy
import functools

import luigi
from luigi import six

from luigi import parameter
Expand Down Expand Up @@ -686,7 +687,7 @@ def _dump(self):
pickle.dumps(self)
"""
unpicklable_properties = ('set_tracking_url', 'set_status_message', 'set_progress_percentage')
unpicklable_properties = tuple(luigi.worker.TaskProcess.forward_reporter_callbacks.values())
reserved_properties = {}
for property_name in unpicklable_properties:
if hasattr(self, property_name):
Expand Down
24 changes: 18 additions & 6 deletions luigi/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,15 @@ class TaskProcess(multiprocessing.Process):

Mainly for convenience since this is run in a separate process. """

# mapping of status_reporter methods to task callbacks that are added to the task
# before they actually run, and removed afterwards
forward_reporter_callbacks = {
"update_tracking_url": "set_tracking_url",
"update_status_message": "set_status_message",
"update_progress_percentage": "set_progress_percentage",
"decrease_running_resources": "decrease_running_resources",
}

def __init__(self, task, worker_id, result_queue, status_reporter,
use_multiprocessing=False, worker_timeout=0, check_unfulfilled_deps=True):
super(TaskProcess, self).__init__()
Expand All @@ -124,15 +133,15 @@ def __init__(self, task, worker_id, result_queue, status_reporter,
self.check_unfulfilled_deps = check_unfulfilled_deps

def _run_get_new_deps(self):
self.task.set_tracking_url = self.status_reporter.update_tracking_url
self.task.set_status_message = self.status_reporter.update_status_message
self.task.set_progress_percentage = self.status_reporter.update_progress_percentage
# set task callbacks before running
for reporter_attr, task_attr in six.iteritems(self.forward_reporter_callbacks):
setattr(self.task, task_attr, getattr(self.status_reporter, reporter_attr))

task_gen = self.task.run()

self.task.set_tracking_url = None
self.task.set_status_message = None
self.task.set_progress_percentage = None
# reset task callbacks
for reporter_attr, task_attr in six.iteritems(self.forward_reporter_callbacks):
setattr(self.task, task_attr, None)

if not isinstance(task_gen, types.GeneratorType):
return None
Expand Down Expand Up @@ -274,6 +283,9 @@ def update_status_message(self, message):
def update_progress_percentage(self, percentage):
self._scheduler.set_task_progress_percentage(self._task_id, percentage)

def decrease_running_resources(self, decrease_resources):
self._scheduler.decrease_running_task_resources(self._task_id, decrease_resources)


class SingleProcessPool(object):
"""
Expand Down
13 changes: 10 additions & 3 deletions test/scheduler_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,12 +408,12 @@ def test_set_batch_runner_max(self):
self.sch.add_task(worker=WORKER, task_id='A_2', status=DONE)
self.assertEqual({'A_1', 'A_2'}, set(self.sch.task_list(DONE, '').keys()))

def _start_simple_batch(self, use_max=False, mark_running=True):
def _start_simple_batch(self, use_max=False, mark_running=True, resources=None):
self.sch.add_task_batcher(worker=WORKER, task_family='A', batched_args=['a'])
self.sch.add_task(worker=WORKER, task_id='A_1', family='A', params={'a': '1'},
batchable=True)
batchable=True, resources=resources)
self.sch.add_task(worker=WORKER, task_id='A_2', family='A', params={'a': '2'},
batchable=True)
batchable=True, resources=resources)
response = self.sch.get_work(worker=WORKER)
if mark_running:
batch_id = response['batch_id']
Expand Down Expand Up @@ -496,6 +496,13 @@ def test_batch_update_progress(self):
for task_id in ('A_1', 'A_2', 'A_1_2'):
self.assertEqual(30, self.sch.get_task_progress_percentage(task_id)['progressPercentage'])

def test_batch_decrease_resources(self):
self.sch.update_resources(x=3)
self._start_simple_batch(resources={'x': 3})
self.sch.decrease_running_task_resources('A_1_2', {'x': 1})
for task_id in ('A_1', 'A_2', 'A_1_2'):
self.assertEqual(2, self.sch.get_running_task_resources(task_id)['resources']['x'])

def test_batch_tracking_url(self):
self._start_simple_batch()
self.sch.add_task(worker=WORKER, task_id='A_1_2', tracking_url='http://test.tracking.url/')
Expand Down
138 changes: 138 additions & 0 deletions test/task_running_resources_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import time
import signal
import multiprocessing
from contextlib import contextmanager

from helpers import unittest, RunOnceTask

import luigi
import luigi.server


class ResourceTestTask(RunOnceTask):

param = luigi.Parameter()
reduce_foo = luigi.BoolParameter()

def process_resources(self):
return {"foo": 2}

def run(self):
if self.reduce_foo:
self.decrease_running_resources({"foo": 1})

time.sleep(2)

super(ResourceTestTask, self).run()


class ResourceWrapperTask(RunOnceTask):

reduce_foo = ResourceTestTask.reduce_foo

def requires(self):
return [
ResourceTestTask(param="a", reduce_foo=self.reduce_foo),
ResourceTestTask(param="b"),
]


class LocalRunningResourcesTest(unittest.TestCase):

def test_resource_reduction(self):
# trivial resource reduction on local scheduler
# test the running_task_resources setter and getter
sch = luigi.scheduler.Scheduler(resources={"foo": 2})

with luigi.worker.Worker(scheduler=sch) as w:
task = ResourceTestTask(param="a", reduce_foo=True)

w.add(task)
w.run()

self.assertEqual(sch.get_running_task_resources(task.task_id)["resources"]["foo"], 1)


class ConcurrentRunningResourcesTest(unittest.TestCase):

def get_app(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you remove this in a follow-up? You don't need it anymore right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup!

return luigi.server.app(luigi.scheduler.Scheduler())

def setUp(self):
super(ConcurrentRunningResourcesTest, self).setUp()

# run the luigi server in a new process and wait for its startup
self._process = multiprocessing.Process(target=luigi.server.run)
self._process.start()
time.sleep(0.5)

# configure the rpc scheduler, update the foo resource
self.sch = luigi.rpc.RemoteScheduler()
self.sch.update_resource("foo", 3)

def tearDown(self):
super(ConcurrentRunningResourcesTest, self).tearDown()

# graceful server shutdown
self._process.terminate()
self._process.join(timeout=1)
if self._process.is_alive():
os.kill(self._process.pid, signal.SIGKILL)

@contextmanager
def worker(self, scheduler=None, processes=2):
with luigi.worker.Worker(scheduler=scheduler or self.sch, worker_processes=processes) as w:
w._config.wait_interval = 0.2
w._config.check_unfulfilled_deps = False
yield w

@contextmanager
def assert_duration(self, min_duration=0, max_duration=-1):
t0 = time.time()
try:
yield
finally:
duration = time.time() - t0
self.assertGreater(duration, min_duration)
if max_duration > 0:
self.assertLess(duration, max_duration)

def test_tasks_serial(self):
# serial test
# run two tasks that do not reduce the "foo" resource
# as the total foo resource (3) is smaller than the requirement of two tasks (4),
# the scheduler is forced to run them serially which takes longer than 4 seconds
with self.worker() as w:
w.add(ResourceWrapperTask(reduce_foo=False))

with self.assert_duration(min_duration=4):
w.run()

def test_tasks_parallel(self):
# parallel test
# run two tasks and the first one lowers its requirement on the "foo" resource, so that
# the total "foo" resource (3) is sufficient to run both tasks in parallel shortly after
# the first task started, so the entire process should not exceed 4 seconds
with self.worker() as w:
w.add(ResourceWrapperTask(reduce_foo=True))

with self.assert_duration(max_duration=4):
w.run()