From f814a895180cb8c885dfe43201338383f8b1db13 Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 7 Nov 2017 19:32:47 +0300 Subject: [PATCH 01/23] Visible parameter for luigi.Parameters --- luigi/parameter.py | 3 ++- luigi/task.py | 4 ++-- luigi/worker.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/luigi/parameter.py b/luigi/parameter.py index 7dcbe1f5a8..95522ad746 100644 --- a/luigi/parameter.py +++ b/luigi/parameter.py @@ -113,7 +113,7 @@ def run(self): _counter = 0 # non-atomically increasing counter used for ordering parameters. def __init__(self, default=_no_value, is_global=False, significant=True, description=None, - config_path=None, positional=True, always_in_help=False, batch_method=None): + config_path=None, positional=True, always_in_help=False, batch_method=None, visible=True): """ :param default: the default value for this parameter. This should match the type of the Parameter, i.e. ``datetime.date`` for ``DateParameter`` or ``int`` for @@ -150,6 +150,7 @@ def __init__(self, default=_no_value, is_global=False, significant=True, descrip positional = False self.significant = significant # Whether different values for this parameter will differentiate otherwise equal tasks self.positional = positional + self.visible = visible self.description = description self.always_in_help = always_in_help diff --git a/luigi/task.py b/luigi/task.py index 1e10da38cb..d9818ec469 100644 --- a/luigi/task.py +++ b/luigi/task.py @@ -469,14 +469,14 @@ def from_str_params(cls, params_str): return cls(**kwargs) - def to_str_params(self, only_significant=False): + def to_str_params(self, only_significant=False, only_visible=False): """ Convert all parameters to a str->str hash. """ params_str = {} params = dict(self.get_params()) for param_name, param_value in six.iteritems(self.param_kwargs): - if (not only_significant) or params[param_name].significant: + if ((not only_significant) or params[param_name].significant) and ((not only_visible) or params[param_name].visible): params_str[param_name] = params[param_name].serialize(param_value) return params_str diff --git a/luigi/worker.py b/luigi/worker.py index d2270df0b4..28f8fa6192 100644 --- a/luigi/worker.py +++ b/luigi/worker.py @@ -775,7 +775,7 @@ def _add(self, task, is_complete): runnable=runnable, priority=task.priority, resources=task.process_resources(), - params=task.to_str_params(), + params=task.to_str_params(only_visible=True), family=task.task_family, module=task.task_module, batchable=task.batchable, From 9d8355027e3267105cb83bc03c246aabf1f70040 Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 15 Nov 2017 00:10:40 +0300 Subject: [PATCH 02/23] Docs and test for visible param --- doc/parameters.rst | 5 +++ luigi/parameter.py | 4 ++ test/visible_parameters_test.py | 66 +++++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+) create mode 100644 test/visible_parameters_test.py diff --git a/doc/parameters.rst b/doc/parameters.rst index 1e9f774416..fd66d33cab 100644 --- a/doc/parameters.rst +++ b/doc/parameters.rst @@ -88,6 +88,11 @@ are not the same instance: >>> hash(c) == hash(d) True +Invisible parameters + +If a parameter is created with ``visible=False``, +it is ignored in central scheduler Web-view. + Parameter types ^^^^^^^^^^^^^^^ diff --git a/luigi/parameter.py b/luigi/parameter.py index 95522ad746..cff14a9aae 100644 --- a/luigi/parameter.py +++ b/luigi/parameter.py @@ -140,6 +140,10 @@ def __init__(self, default=_no_value, is_global=False, significant=True, descrip parameter values into a single value. 
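As a quick illustration of the flag these first two patches add (the task and parameter names below are invented for the example): a parameter declared with ``visible=False`` still behaves like any other parameter, but is filtered out when the worker serializes parameters for the central scheduler via ``to_str_params(only_visible=True)``:

    import luigi

    class CopyToDatabase(luigi.Task):
        table = luigi.Parameter()
        # hypothetical secret: keep it out of the scheduler Web-view and out of the task id
        password = luigi.Parameter(visible=False, significant=False)

    task = CopyToDatabase(table='users', password='secret')
    task.to_str_params()                   # {'table': 'users', 'password': 'secret'}
    task.to_str_params(only_visible=True)  # {'table': 'users'} -- what the worker now sends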
Used when receiving batched parameter lists from the scheduler. See :ref:`batch_method` + + :param visible: specify ``False`` if the parameter should not be visible in central scheduler WEB-view. + Default: ``True`` + """ self._default = default self._batch_method = batch_method diff --git a/test/visible_parameters_test.py b/test/visible_parameters_test.py new file mode 100644 index 0000000000..e90ee06d8e --- /dev/null +++ b/test/visible_parameters_test.py @@ -0,0 +1,66 @@ +import luigi +from helpers import unittest + + +class TestTask1(luigi.Task): + param_one = luigi.Parameter(default='one', visible=False) + param_two = luigi.Parameter(default='two') + param_three = luigi.Parameter(default='three', visible=False) + param_four = luigi.Parameter(default='four', significant=False) + param_five = luigi.Parameter(default='five', visible=False, significant=False) + + +class TestTask2(luigi.Task): + param_one = luigi.Parameter(default='1', visible=False) + param_two = luigi.Parameter(default='2', visible=False) + param_three = luigi.Parameter(default='3', visible=False) + param_four = luigi.Parameter(default='4', visible=False) + param_five = luigi.Parameter(default='5', visible=False) + + +class TestTask3(luigi.Task): + param_one = luigi.Parameter(default='one') + param_two = luigi.Parameter(default='two') + param_three = luigi.Parameter(default='three') + param_four = luigi.Parameter(default='four', significant=False) + param_five = luigi.Parameter(default='five', significant=False) + + +class Test(unittest.TestCase): + def test_task_visible_vs_invisible(self): + task1 = TestTask1() + task3 = TestTask3() + + self.assertEqual(task1.to_str_params(), task3.to_str_params()) + + def test_task_visible_vs_invisible_using_only_significant(self): + task1 = TestTask1() + task3 = TestTask3() + + self.assertEqual(task1.to_str_params(only_significant=True), task3.to_str_params(only_significant=True)) + + def test_task_params(self): + task = TestTask1() + + self.assertEqual(str(task), 'TestTask1(param_one=one, param_two=two, param_three=three)') + + def test_similar_task_to_str_equality(self): + task1 = TestTask1() + task2 = TestTask1() + + self.assertEqual(task1.to_str_params(), task2.to_str_params()) + + def test_only_visible(self): + task = TestTask1() + + self.assertEqual(task.to_str_params(only_visible=True), {'param_two': 'two', 'param_four': 'four'}) + + def test_to_str(self): + task = TestTask2() + + self.assertEqual(task.to_str_params(), {'param_one': '1', 'param_two': '2', 'param_three': '3', 'param_four': '4', 'param_five': '5'}) + + def test_to_str_all_params_invisible(self): + task = TestTask2() + + self.assertEqual(task.to_str_params(only_visible=True), {}) \ No newline at end of file From 4059db0e4c6c460c599b1c9fb699f2bc3fc0b44a Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 29 Nov 2017 21:30:57 +0300 Subject: [PATCH 03/23] Hidden parameter logic --- luigi/db_task_history.py | 2 +- luigi/parameter.py | 7 ++++--- luigi/scheduler.py | 9 ++++++++- luigi/task.py | 9 +++++---- luigi/worker.py | 2 +- 5 files changed, 19 insertions(+), 10 deletions(-) diff --git a/luigi/db_task_history.py b/luigi/db_task_history.py index 3f175c25c9..63d68dca37 100644 --- a/luigi/db_task_history.py +++ b/luigi/db_task_history.py @@ -124,7 +124,7 @@ def _find_or_create_task(self, task): else: task_record = TaskRecord(task_id=task._task.id, name=task.task_family, host=task.host) for (k, v) in six.iteritems(task.parameters): - task_record.parameters[k] = TaskParameter(name=k, value=v) + task_record.parameters[k] = 
TaskParameter(name=k, value=v[0]) session.add(task_record) yield (task_record, session) if task.host: diff --git a/luigi/parameter.py b/luigi/parameter.py index cff14a9aae..93713fc6bd 100644 --- a/luigi/parameter.py +++ b/luigi/parameter.py @@ -113,7 +113,8 @@ def run(self): _counter = 0 # non-atomically increasing counter used for ordering parameters. def __init__(self, default=_no_value, is_global=False, significant=True, description=None, - config_path=None, positional=True, always_in_help=False, batch_method=None, visible=True): + # config_path=None, positional=True, always_in_help=False, batch_method=None, visible=True): + config_path=None, positional=True, always_in_help=False, batch_method=None, visible=1): """ :param default: the default value for this parameter. This should match the type of the Parameter, i.e. ``datetime.date`` for ``DateParameter`` or ``int`` for @@ -154,7 +155,7 @@ def __init__(self, default=_no_value, is_global=False, significant=True, descrip positional = False self.significant = significant # Whether different values for this parameter will differentiate otherwise equal tasks self.positional = positional - self.visible = visible + self.visible = visible # 1 - public 0 - hidden 2 - private self.description = description self.always_in_help = always_in_help @@ -257,7 +258,7 @@ def serialize(self, x): :param x: the value to serialize. """ - return str(x) + return str(x), self.visible def _warn_on_wrong_param_type(self, param_name, param_value): if self.__class__ != Parameter: diff --git a/luigi/scheduler.py b/luigi/scheduler.py index 47e53856ea..06c2fb8fb7 100644 --- a/luigi/scheduler.py +++ b/luigi/scheduler.py @@ -1218,6 +1218,13 @@ def _upstream_status(self, task_id, upstream_status_table): def _serialize_task(self, task_id, include_deps=True, deps=None): task = self._state.get_task(task_id) + + public_params = {} + + for param_name in task.params: + if task.params[param_name][1] == 1: + public_params[param_name] = task.params[param_name][0] + ret = { 'display_name': task.pretty_id, 'status': task.status, @@ -1226,7 +1233,7 @@ def _serialize_task(self, task_id, include_deps=True, deps=None): 'time_running': getattr(task, "time_running", None), 'start_time': task.time, 'last_updated': getattr(task, "updated", task.time), - 'params': task.params, + 'params': public_params, 'name': task.family, 'priority': task.priority, 'resources': task.resources, diff --git a/luigi/task.py b/luigi/task.py index d9818ec469..0a9f94c82c 100644 --- a/luigi/task.py +++ b/luigi/task.py @@ -129,7 +129,7 @@ def task_id_str(task_family, params): param_hash = hashlib.md5(param_str.encode('utf-8')).hexdigest() param_summary = '_'.join(p[:TASK_ID_TRUNCATE_PARAMS] - for p in (params[p] for p in sorted(params)[:TASK_ID_INCLUDE_PARAMS])) + for p, visible in (params[p] for p in sorted(params)[:TASK_ID_INCLUDE_PARAMS])) param_summary = TASK_ID_INVALID_CHAR_REGEX.sub('_', param_summary) return '{}_{}_{}'.format(task_family, param_summary, param_hash[:TASK_ID_TRUNCATE_HASH]) @@ -469,14 +469,15 @@ def from_str_params(cls, params_str): return cls(**kwargs) - def to_str_params(self, only_significant=False, only_visible=False): + # def to_str_params(self, only_significant=False, only_visible=False): + def to_str_params(self, only_significant=False): """ Convert all parameters to a str->str hash. 
""" params_str = {} params = dict(self.get_params()) for param_name, param_value in six.iteritems(self.param_kwargs): - if ((not only_significant) or params[param_name].significant) and ((not only_visible) or params[param_name].visible): + if ((not only_significant) or params[param_name].significant) and params[param_name].visible != 2: params_str[param_name] = params[param_name].serialize(param_value) return params_str @@ -521,7 +522,7 @@ def __repr__(self): param_objs = dict(params) for param_name, param_value in param_values: if param_objs[param_name].significant: - repr_parts.append('%s=%s' % (param_name, param_objs[param_name].serialize(param_value))) + repr_parts.append('%s=%s' % (param_name, param_objs[param_name].serialize(param_value)[0])) task_str = '{}({})'.format(self.get_task_family(), ', '.join(repr_parts)) diff --git a/luigi/worker.py b/luigi/worker.py index 28f8fa6192..d2270df0b4 100644 --- a/luigi/worker.py +++ b/luigi/worker.py @@ -775,7 +775,7 @@ def _add(self, task, is_complete): runnable=runnable, priority=task.priority, resources=task.process_resources(), - params=task.to_str_params(only_visible=True), + params=task.to_str_params(), family=task.task_family, module=task.task_module, batchable=task.batchable, From debf4d30af1ae8d22224c6e7aa1e9e6cd69942d0 Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 11 Dec 2017 21:57:04 +0300 Subject: [PATCH 04/23] Visible modifiers --- luigi/db_task_history.py | 2 +- luigi/parameter.py | 6 +++--- luigi/scheduler.py | 25 ++++++++++++++----------- luigi/task.py | 18 +++++++++++++----- luigi/worker.py | 4 ++++ test/visible_parameters_test.py | 32 ++++++++++++++++---------------- 6 files changed, 51 insertions(+), 36 deletions(-) diff --git a/luigi/db_task_history.py b/luigi/db_task_history.py index 63d68dca37..3f175c25c9 100644 --- a/luigi/db_task_history.py +++ b/luigi/db_task_history.py @@ -124,7 +124,7 @@ def _find_or_create_task(self, task): else: task_record = TaskRecord(task_id=task._task.id, name=task.task_family, host=task.host) for (k, v) in six.iteritems(task.parameters): - task_record.parameters[k] = TaskParameter(name=k, value=v[0]) + task_record.parameters[k] = TaskParameter(name=k, value=v) session.add(task_record) yield (task_record, session) if task.host: diff --git a/luigi/parameter.py b/luigi/parameter.py index 93713fc6bd..5e738fe3f9 100644 --- a/luigi/parameter.py +++ b/luigi/parameter.py @@ -114,7 +114,7 @@ def run(self): def __init__(self, default=_no_value, is_global=False, significant=True, description=None, # config_path=None, positional=True, always_in_help=False, batch_method=None, visible=True): - config_path=None, positional=True, always_in_help=False, batch_method=None, visible=1): + config_path=None, positional=True, always_in_help=False, batch_method=None, visible=0): """ :param default: the default value for this parameter. This should match the type of the Parameter, i.e. ``datetime.date`` for ``DateParameter`` or ``int`` for @@ -155,7 +155,7 @@ def __init__(self, default=_no_value, is_global=False, significant=True, descrip positional = False self.significant = significant # Whether different values for this parameter will differentiate otherwise equal tasks self.positional = positional - self.visible = visible # 1 - public 0 - hidden 2 - private + self.visible = visible # 0 - public 1 - hidden 2 - private self.description = description self.always_in_help = always_in_help @@ -258,7 +258,7 @@ def serialize(self, x): :param x: the value to serialize. 
""" - return str(x), self.visible + return str(x) def _warn_on_wrong_param_type(self, param_name, param_value): if self.__class__ != Parameter: diff --git a/luigi/scheduler.py b/luigi/scheduler.py index 06c2fb8fb7..ac476063f4 100644 --- a/luigi/scheduler.py +++ b/luigi/scheduler.py @@ -275,7 +275,7 @@ def __eq__(self, other): class Task(object): def __init__(self, task_id, status, deps, resources=None, priority=0, family='', module=None, - params=None, tracking_url=None, status_message=None, progress_percentage=None, retry_policy='notoptional'): + params=None, visibility=None, tracking_url=None, status_message=None, progress_percentage=None, retry_policy='notoptional'): self.id = task_id self.stakeholders = set() # workers ids that are somehow related to this task (i.e. don't prune while any of these workers are still active) self.workers = OrderedSet() # workers ids that can perform task - task is 'BROKEN' if none of these workers are active @@ -296,6 +296,10 @@ def __init__(self, task_id, status, deps, resources=None, priority=0, family='', self.family = family self.module = module self.params = _get_default(params, {}) + self.visibility = visibility + self.public_params = _get_default({key: self.params[key] for key in self.params if self.visibility[key] == 0}, {}) + + print("inside task", self.params, self.visibility) self.retry_policy = retry_policy self.failures = Failures(self.retry_policy.disable_window) @@ -335,7 +339,7 @@ def has_excessive_failures(self): @property def pretty_id(self): - param_str = ', '.join('{}={}'.format(key, value) for key, value in sorted(self.params.items())) + param_str = ', '.join('{}={}'.format(key, value) for key, value in sorted(self.params.items()) if self.visibility[key] == 0) return '{}({})'.format(self.family, param_str) @@ -770,7 +774,7 @@ def forgive_failures(self, task_id=None): @rpc_method() def add_task(self, task_id=None, status=PENDING, runnable=True, deps=None, new_deps=None, expl=None, resources=None, - priority=0, family='', module=None, params=None, + priority=0, family='', module=None, params=None, visibility=None, assistant=False, tracking_url=None, worker=None, batchable=None, batch_id=None, retry_policy_dict={}, owners=None, **kwargs): """ @@ -788,7 +792,7 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, if worker.enabled: _default_task = self._make_task( task_id=task_id, status=PENDING, deps=deps, resources=resources, - priority=priority, family=family, module=module, params=params, + priority=priority, family=family, module=module, params=params, visibility=visibility ) else: _default_task = None @@ -805,6 +809,10 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, task.module = module if not task.params: task.params = _get_default(params, {}) + if not task.visibility: + task.visibility = _get_default(visibility, {}) + + print("inside scheduler", params, visibility) if batch_id is not None: task.batch_id = batch_id @@ -1219,12 +1227,6 @@ def _upstream_status(self, task_id, upstream_status_table): def _serialize_task(self, task_id, include_deps=True, deps=None): task = self._state.get_task(task_id) - public_params = {} - - for param_name in task.params: - if task.params[param_name][1] == 1: - public_params[param_name] = task.params[param_name][0] - ret = { 'display_name': task.pretty_id, 'status': task.status, @@ -1233,7 +1235,8 @@ def _serialize_task(self, task_id, include_deps=True, deps=None): 'time_running': getattr(task, "time_running", None), 'start_time': task.time, 'last_updated': 
getattr(task, "updated", task.time), - 'params': public_params, + 'params': task.public_params, + 'visibility': task.visibility, 'name': task.family, 'priority': task.priority, 'resources': task.resources, diff --git a/luigi/task.py b/luigi/task.py index 0a9f94c82c..6fd3c515cb 100644 --- a/luigi/task.py +++ b/luigi/task.py @@ -129,7 +129,7 @@ def task_id_str(task_family, params): param_hash = hashlib.md5(param_str.encode('utf-8')).hexdigest() param_summary = '_'.join(p[:TASK_ID_TRUNCATE_PARAMS] - for p, visible in (params[p] for p in sorted(params)[:TASK_ID_INCLUDE_PARAMS])) + for p in (params[p] for p in sorted(params)[:TASK_ID_INCLUDE_PARAMS])) param_summary = TASK_ID_INVALID_CHAR_REGEX.sub('_', param_summary) return '{}_{}_{}'.format(task_family, param_summary, param_hash[:TASK_ID_TRUNCATE_HASH]) @@ -452,7 +452,7 @@ def _warn_on_wrong_param_types(self): params[param_name]._warn_on_wrong_param_type(param_name, param_value) @classmethod - def from_str_params(cls, params_str): + def from_str_params(cls, params_str, visibility): """ Creates an instance from a str->str hash. @@ -460,7 +460,7 @@ def from_str_params(cls, params_str): """ kwargs = {} for param_name, param in cls.get_params(): - if param_name in params_str: + if param_name in params_str and visibility[param_name] != 2: param_str = params_str[param_name] if isinstance(param_str, list): kwargs[param_name] = param._parse_list(param_str) @@ -469,7 +469,6 @@ def from_str_params(cls, params_str): return cls(**kwargs) - # def to_str_params(self, only_significant=False, only_visible=False): def to_str_params(self, only_significant=False): """ Convert all parameters to a str->str hash. @@ -482,6 +481,15 @@ def to_str_params(self, only_significant=False): return params_str + def params_visibilities(self, only_significant=False): + visibility = {} + params = dict(self.get_params()) + for param_name, param_value in six.iteritems(self.param_kwargs): + if ((not only_significant) or params[param_name].significant) and params[param_name].visible != 2: + visibility[param_name] = params[param_name].visible + + return visibility + def clone(self, cls=None, **kwargs): """ Creates a new instance from an existing instance where some of the args have changed. 
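To make the three-level scheme introduced in this patch concrete (the class and values below are made up), this is roughly how the new hooks behave for a task mixing all three levels:

    import luigi

    class Example(luigi.Task):
        user = luigi.Parameter()               # visible=0 (default): public everywhere
        session = luigi.Parameter(visible=1)   # hidden: sent to the scheduler, filtered from the Web-view
        password = luigi.Parameter(visible=2)  # private: never serialized

    t = Example(user='bob', session='abc', password='hunter2')
    t.to_str_params()        # {'user': 'bob', 'session': 'abc'} -- visible=2 is dropped
    t.params_visibilities()  # {'user': 0, 'session': 1} -- sent to the scheduler alongside the params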
@@ -522,7 +530,7 @@ def __repr__(self): param_objs = dict(params) for param_name, param_value in param_values: if param_objs[param_name].significant: - repr_parts.append('%s=%s' % (param_name, param_objs[param_name].serialize(param_value)[0])) + repr_parts.append('%s=%s' % (param_name, param_objs[param_name].serialize(param_value))) task_str = '{}({})'.format(self.get_task_family(), ', '.join(repr_parts)) diff --git a/luigi/worker.py b/luigi/worker.py index d2270df0b4..f47c9aebb5 100644 --- a/luigi/worker.py +++ b/luigi/worker.py @@ -573,6 +573,7 @@ def _announce_scheduling_failure(self, task, expl): task_name=str(task), family=task.task_family, params=task.to_str_params(only_significant=True), + visibility=task.params_visibilities(only_significant=True), expl=expl, owners=task._owner_list(), ) @@ -776,6 +777,7 @@ def _add(self, task, is_complete): priority=task.priority, resources=task.process_resources(), params=task.to_str_params(), + visibility=task.params_visibilities(), family=task.task_family, module=task.task_module, batchable=task.batchable, @@ -838,6 +840,7 @@ def _get_work_task_id(self, get_work_response): module=get_work_response.get('task_module'), family=get_work_response['task_family'], params=task.to_str_params(), + visibility=task.params_visibilities(), status=RUNNING, batch_id=get_work_response['batch_id'], ) @@ -993,6 +996,7 @@ def _handle_next_task(self): resources=task.process_resources(), runnable=None, params=task.to_str_params(), + visibility=task.params_visibilities(), family=task.task_family, module=task.task_module, new_deps=new_deps, diff --git a/test/visible_parameters_test.py b/test/visible_parameters_test.py index e90ee06d8e..3616fb167b 100644 --- a/test/visible_parameters_test.py +++ b/test/visible_parameters_test.py @@ -3,27 +3,27 @@ class TestTask1(luigi.Task): - param_one = luigi.Parameter(default='one', visible=False) + param_one = luigi.Parameter(default='one', visible=2) param_two = luigi.Parameter(default='two') - param_three = luigi.Parameter(default='three', visible=False) - param_four = luigi.Parameter(default='four', significant=False) - param_five = luigi.Parameter(default='five', visible=False, significant=False) + param_three = luigi.Parameter(default='three', visible=2) + param_four = luigi.Parameter(default='four', significant=False, visible=1) + param_five = luigi.Parameter(default='five', visible=2, significant=False) class TestTask2(luigi.Task): - param_one = luigi.Parameter(default='1', visible=False) - param_two = luigi.Parameter(default='2', visible=False) - param_three = luigi.Parameter(default='3', visible=False) - param_four = luigi.Parameter(default='4', visible=False) - param_five = luigi.Parameter(default='5', visible=False) + param_one = luigi.Parameter(default='1', visible=2) + param_two = luigi.Parameter(default='2', visible=2) + param_three = luigi.Parameter(default='3', visible=2) + param_four = luigi.Parameter(default='4', visible=2) + param_five = luigi.Parameter(default='5', visible=2) class TestTask3(luigi.Task): - param_one = luigi.Parameter(default='one') + param_one = luigi.Parameter(default='one', visible=2) param_two = luigi.Parameter(default='two') - param_three = luigi.Parameter(default='three') - param_four = luigi.Parameter(default='four', significant=False) - param_five = luigi.Parameter(default='five', significant=False) + param_three = luigi.Parameter(default='three', visible=2) + param_four = luigi.Parameter(default='four', significant=False, visible=1) + param_five = luigi.Parameter(default='five', 
significant=False, visible=2) class Test(unittest.TestCase): @@ -53,14 +53,14 @@ def test_similar_task_to_str_equality(self): def test_only_visible(self): task = TestTask1() - self.assertEqual(task.to_str_params(only_visible=True), {'param_two': 'two', 'param_four': 'four'}) + self.assertEqual(task.to_str_params(), {'param_two': 'two', 'param_four': 'four'}) def test_to_str(self): task = TestTask2() - self.assertEqual(task.to_str_params(), {'param_one': '1', 'param_two': '2', 'param_three': '3', 'param_four': '4', 'param_five': '5'}) + self.assertEqual(task.to_str_params(), {}) def test_to_str_all_params_invisible(self): task = TestTask2() - self.assertEqual(task.to_str_params(only_visible=True), {}) \ No newline at end of file + self.assertEqual(task.to_str_params(), {}) \ No newline at end of file From 81b268a7f96335c1ca89ea1cece7164b84d1edf9 Mon Sep 17 00:00:00 2001 From: Nick Date: Sun, 17 Dec 2017 15:11:10 +0300 Subject: [PATCH 05/23] 3rd state of parameter visibility: hidden --- luigi/scheduler.py | 30 +++++++++++++++++------------- luigi/task.py | 12 ++++++------ luigi/worker.py | 7 ++----- 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/luigi/scheduler.py b/luigi/scheduler.py index ac476063f4..7334c77958 100644 --- a/luigi/scheduler.py +++ b/luigi/scheduler.py @@ -275,7 +275,8 @@ def __eq__(self, other): class Task(object): def __init__(self, task_id, status, deps, resources=None, priority=0, family='', module=None, - params=None, visibility=None, tracking_url=None, status_message=None, progress_percentage=None, retry_policy='notoptional'): + params=None, tracking_url=None, status_message=None, progress_percentage=None, retry_policy='notoptional', + public_params=None, hidden_params=None): self.id = task_id self.stakeholders = set() # workers ids that are somehow related to this task (i.e. 
don't prune while any of these workers are still active) self.workers = OrderedSet() # workers ids that can perform task - task is 'BROKEN' if none of these workers are active @@ -296,10 +297,9 @@ def __init__(self, task_id, status, deps, resources=None, priority=0, family='', self.family = family self.module = module self.params = _get_default(params, {}) - self.visibility = visibility - self.public_params = _get_default({key: self.params[key] for key in self.params if self.visibility[key] == 0}, {}) - print("inside task", self.params, self.visibility) + self.public_params = _get_default(public_params, {}) + self.hidden_params = _get_default(hidden_params, {}) self.retry_policy = retry_policy self.failures = Failures(self.retry_policy.disable_window) @@ -339,7 +339,7 @@ def has_excessive_failures(self): @property def pretty_id(self): - param_str = ', '.join('{}={}'.format(key, value) for key, value in sorted(self.params.items()) if self.visibility[key] == 0) + param_str = ', '.join('{}={}'.format(key, value) for key, value in sorted(self.params.items())) return '{}({})'.format(self.family, param_str) @@ -774,7 +774,7 @@ def forgive_failures(self, task_id=None): @rpc_method() def add_task(self, task_id=None, status=PENDING, runnable=True, deps=None, new_deps=None, expl=None, resources=None, - priority=0, family='', module=None, params=None, visibility=None, + priority=0, family='', module=None, params=None, assistant=False, tracking_url=None, worker=None, batchable=None, batch_id=None, retry_policy_dict={}, owners=None, **kwargs): """ @@ -789,10 +789,15 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, worker = self._update_worker(worker_id) retry_policy = self._generate_retry_policy(retry_policy_dict) + all_params = {key: params[key][0] for key in params} + public_params = {key: params[key][0] for key in params if params[key][1] == 0} + hidden_params = {key: params[key][0] for key in params if params[key][1] == 1} + if worker.enabled: _default_task = self._make_task( task_id=task_id, status=PENDING, deps=deps, resources=resources, - priority=priority, family=family, module=module, params=params, visibility=visibility + priority=priority, family=family, module=module, + params=all_params, public_params=public_params, hidden_params=hidden_params ) else: _default_task = None @@ -808,11 +813,11 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, if not getattr(task, 'module', None): task.module = module if not task.params: - task.params = _get_default(params, {}) - if not task.visibility: - task.visibility = _get_default(visibility, {}) - - print("inside scheduler", params, visibility) + task.params = _get_default(all_params, {}) + if not task.public_params: + task.public_params = _get_default(public_params, {}) + if not task.hidden_params: + task.hidden_params = _get_default(hidden_params, {}) if batch_id is not None: task.batch_id = batch_id @@ -1236,7 +1241,6 @@ def _serialize_task(self, task_id, include_deps=True, deps=None): 'start_time': task.time, 'last_updated': getattr(task, "updated", task.time), 'params': task.public_params, - 'visibility': task.visibility, 'name': task.family, 'priority': task.priority, 'resources': task.resources, diff --git a/luigi/task.py b/luigi/task.py index 6fd3c515cb..e8713a9f97 100644 --- a/luigi/task.py +++ b/luigi/task.py @@ -452,7 +452,7 @@ def _warn_on_wrong_param_types(self): params[param_name]._warn_on_wrong_param_type(param_name, param_value) @classmethod - def from_str_params(cls, params_str, visibility): + def 
from_str_params(cls, params_str): """ Creates an instance from a str->str hash. @@ -460,7 +460,7 @@ def from_str_params(cls, params_str, visibility): """ kwargs = {} for param_name, param in cls.get_params(): - if param_name in params_str and visibility[param_name] != 2: + if param_name in params_str: param_str = params_str[param_name] if isinstance(param_str, list): kwargs[param_name] = param._parse_list(param_str) @@ -481,14 +481,14 @@ def to_str_params(self, only_significant=False): return params_str - def params_visibilities(self, only_significant=False): - visibility = {} + def to_str_params_with_visibility(self, only_significant=False): + params_str_with_visibility = {} params = dict(self.get_params()) for param_name, param_value in six.iteritems(self.param_kwargs): if ((not only_significant) or params[param_name].significant) and params[param_name].visible != 2: - visibility[param_name] = params[param_name].visible + params_str_with_visibility[param_name] = (params[param_name].serialize(param_value), params[param_name].visible) - return visibility + return params_str_with_visibility def clone(self, cls=None, **kwargs): """ diff --git a/luigi/worker.py b/luigi/worker.py index f47c9aebb5..3a2fa4f7ea 100644 --- a/luigi/worker.py +++ b/luigi/worker.py @@ -573,7 +573,6 @@ def _announce_scheduling_failure(self, task, expl): task_name=str(task), family=task.task_family, params=task.to_str_params(only_significant=True), - visibility=task.params_visibilities(only_significant=True), expl=expl, owners=task._owner_list(), ) @@ -776,8 +775,8 @@ def _add(self, task, is_complete): runnable=runnable, priority=task.priority, resources=task.process_resources(), - params=task.to_str_params(), - visibility=task.params_visibilities(), + # params=task.to_str_params(), + params=task.to_str_params_with_visibility(), family=task.task_family, module=task.task_module, batchable=task.batchable, @@ -840,7 +839,6 @@ def _get_work_task_id(self, get_work_response): module=get_work_response.get('task_module'), family=get_work_response['task_family'], params=task.to_str_params(), - visibility=task.params_visibilities(), status=RUNNING, batch_id=get_work_response['batch_id'], ) @@ -996,7 +994,6 @@ def _handle_next_task(self): resources=task.process_resources(), runnable=None, params=task.to_str_params(), - visibility=task.params_visibilities(), family=task.task_family, module=task.task_module, new_deps=new_deps, From 2530c0b29fd1c7b6ec47db1795eeee67a4c13680 Mon Sep 17 00:00:00 2001 From: Nick Date: Sun, 17 Dec 2017 15:13:10 +0300 Subject: [PATCH 06/23] Update parameter.py --- luigi/parameter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/luigi/parameter.py b/luigi/parameter.py index 5e738fe3f9..e5c5305447 100644 --- a/luigi/parameter.py +++ b/luigi/parameter.py @@ -113,7 +113,6 @@ def run(self): _counter = 0 # non-atomically increasing counter used for ordering parameters. def __init__(self, default=_no_value, is_global=False, significant=True, description=None, - # config_path=None, positional=True, always_in_help=False, batch_method=None, visible=True): config_path=None, positional=True, always_in_help=False, batch_method=None, visible=0): """ :param default: the default value for this parameter. 
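For reference, the wire format introduced in the preceding patch packs each serialized value together with its visibility level, and the scheduler's ``add_task`` splits the pairs back out (reusing the hypothetical ``Example`` task from the earlier sketch):

    t = Example(user='bob', session='abc', password='hunter2')
    t.to_str_params_with_visibility()
    # {'user': ('bob', 0), 'session': ('abc', 1)} -- private (visible=2) parameters are dropped

    # scheduler side, as in add_task above:
    params = t.to_str_params_with_visibility()
    all_params = {key: params[key][0] for key in params}
    public_params = {key: params[key][0] for key in params if params[key][1] == 0}
    hidden_params = {key: params[key][0] for key in params if params[key][1] == 1}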
This should match the type of the From 944d1432a110bc1f08215abd27d8fecc3588e6cd Mon Sep 17 00:00:00 2001 From: Nick Date: Sun, 17 Dec 2017 15:13:45 +0300 Subject: [PATCH 07/23] Update worker.py --- luigi/worker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/luigi/worker.py b/luigi/worker.py index 3a2fa4f7ea..fb72e8a47f 100644 --- a/luigi/worker.py +++ b/luigi/worker.py @@ -775,7 +775,6 @@ def _add(self, task, is_complete): runnable=runnable, priority=task.priority, resources=task.process_resources(), - # params=task.to_str_params(), params=task.to_str_params_with_visibility(), family=task.task_family, module=task.task_module, From a3305527a7a3c2e79530291a25973b39553251a1 Mon Sep 17 00:00:00 2001 From: Nick Date: Sun, 17 Dec 2017 15:18:44 +0300 Subject: [PATCH 08/23] Update parameters.rst --- doc/parameters.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/doc/parameters.rst b/doc/parameters.rst index fd66d33cab..813d650dca 100644 --- a/doc/parameters.rst +++ b/doc/parameters.rst @@ -90,8 +90,10 @@ are not the same instance: Invisible parameters -If a parameter is created with ``visible=False``, -it is ignored in central scheduler Web-view. +``visible=0`` (default) - visible everywhere +``visible=1`` - ignored only in WEB-view +``visible=2`` - ignored in WEB-view, central scheduler and task_parameters history in databse + Parameter types ^^^^^^^^^^^^^^^ From 886179b44437ffef1f10804c95e29505ba93a77b Mon Sep 17 00:00:00 2001 From: Nick Date: Sun, 17 Jun 2018 18:50:45 +0300 Subject: [PATCH 09/23] Merge branch 'master' into parameter_visibility --- .gitignore | 4 +- .travis.yml | 3 + README.rst | 11 +- codecov.yml | 35 +- doc/command_line.rst | 38 -- doc/configuration.rst | 15 + doc/index.rst | 2 +- doc/luigi_patterns.rst | 63 +++ doc/parameters.rst | 2 +- doc/running_luigi.rst | 108 ++++ examples/dynamic_requirements.py | 4 +- examples/execution_summary_example.py | 4 +- examples/per_task_retry_policy.py | 2 +- luigi/__init__.py | 4 +- luigi/contrib/batch.py | 216 ++++++++ luigi/contrib/bigquery.py | 14 +- luigi/contrib/bigquery_avro.py | 87 +-- luigi/contrib/dataproc.py | 20 +- luigi/contrib/docker_runner.py | 6 +- luigi/contrib/ecs.py | 4 +- luigi/contrib/esindex.py | 5 +- luigi/contrib/gcp.py | 10 +- luigi/contrib/gcs.py | 47 +- luigi/contrib/hadoop.py | 15 +- luigi/contrib/hdfs/config.py | 11 +- luigi/contrib/hdfs/snakebite_client.py | 9 + luigi/contrib/hive.py | 2 +- luigi/contrib/kubernetes.py | 180 +++++-- luigi/contrib/mongodb.py | 2 +- luigi/contrib/opener.py | 5 +- luigi/contrib/redshift.py | 66 ++- luigi/contrib/s3.py | 553 +++++++++----------- luigi/contrib/sqla.py | 5 +- luigi/db_task_history.py | 43 +- luigi/execution_summary.py | 10 +- luigi/lock.py | 2 +- luigi/parameter.py | 28 + luigi/retcodes.py | 7 +- luigi/rpc.py | 13 +- luigi/scheduler.py | 118 ++++- luigi/static/visualiser/index.html | 45 +- luigi/static/visualiser/js/graph.js | 8 +- luigi/static/visualiser/js/luigi.js | 38 +- luigi/static/visualiser/js/visualiserApp.js | 112 +++- luigi/task.py | 21 +- luigi/templates/history.html | 2 +- luigi/tools/deps.py | 4 +- luigi/tools/range.py | 2 +- luigi/util.py | 27 +- luigi/worker.py | 84 ++- setup.py | 2 +- test/cmdline_test.py | 32 ++ test/contrib/batch_test.py | 154 ++++++ test/contrib/bigquery_gcloud_test.py | 6 +- test/contrib/bigquery_test.py | 8 + test/contrib/dataproc_test.py | 6 +- test/contrib/ecs_test.py | 8 +- test/contrib/gcs_test.py | 25 +- test/contrib/hdfs_test.py | 12 +- test/contrib/hive_test.py | 6 +- 
test/contrib/kubernetes_test.py | 1 + test/contrib/redshift_test.py | 289 +++++++++- test/contrib/s3_test.py | 252 ++++----- test/contrib/sqla_test.py | 2 +- test/date_interval_test.py | 2 +- test/date_parameter_test.py | 4 + test/decorator_test.py | 4 +- test/event_callbacks_test.py | 20 +- test/execution_summary_test.py | 26 +- test/local_target_test.py | 4 +- test/lock_test.py | 13 +- test/parameter_test.py | 58 +- test/redshift_test.py | 101 ---- test/rpc_test.py | 22 +- test/scheduler_api_test.py | 15 +- test/scheduler_message_test.py | 119 +++++ test/scheduler_test.py | 17 +- test/simulate_test.py | 2 +- test/snakebite_test.py | 26 + test/task_bulk_complete_test.py | 8 +- test/task_running_resources_test.py | 135 +++++ test/util_test.py | 98 +++- test/worker_multiprocess_test.py | 15 + tox.ini | 8 +- 84 files changed, 2645 insertions(+), 971 deletions(-) delete mode 100644 doc/command_line.rst create mode 100644 doc/running_luigi.rst create mode 100644 luigi/contrib/batch.py create mode 100644 test/contrib/batch_test.py delete mode 100644 test/redshift_test.py create mode 100644 test/scheduler_message_test.py create mode 100644 test/task_running_resources_test.py diff --git a/.gitignore b/.gitignore index b9db8c3e90..819fe424a1 100644 --- a/.gitignore +++ b/.gitignore @@ -15,7 +15,10 @@ pig_property_file packages.tar +# Ignore the data files +data test/data +examples/data Vagrantfile @@ -23,7 +26,6 @@ Vagrantfile *.rej *.orig - # Created by https://www.gitignore.io ### Python ### diff --git a/.travis.yml b/.travis.yml index 300043610c..6afc73689b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -15,6 +15,9 @@ env: - BQ_TEST_PROJECT_ID=luigi-travistestenvironment - BQ_TEST_INPUT_BUCKET=luigi-bigquery-test - GOOGLE_APPLICATION_CREDENTIALS=test/gcloud-credentials.json + - AWS_DEFAULT_REGION=us-east-1 + - AWS_ACCESS_KEY_ID=accesskey + - AWS_SECRET_ACCESS_KEY=secretkey matrix: - TOXENV=flake8 - TOXENV=docs diff --git a/README.rst b/README.rst index 21294c05cc..3a76eb03a4 100644 --- a/README.rst +++ b/README.rst @@ -17,9 +17,9 @@ .. image:: https://img.shields.io/pypi/l/luigi.svg?style=flat :target: https://pypi.python.org/pypi/luigi -Luigi is a Python (2.7, 3.3, 3.4, 3.5) package that helps you build complex pipelines of batch -jobs. It handles dependency resolution, workflow management, visualization, -handling failures, command line integration, and much more. +Luigi is a Python (2.7, 3.3, 3.4, 3.5, 3.6) package that helps you build complex +pipelines of batch jobs. It handles dependency resolution, workflow management, +visualization, handling failures, command line integration, and much more. Getting Started --------------- @@ -147,6 +147,8 @@ or held presentations about Luigi: * `voyages-sncf.com `_ `(presentation, 2017) `__ * `Open Targets `_ `(blog, 2017) `__ * `Leipzig University Library `_ `(presentation, 2016) `__ / `(project) `__ +* `Synetiq `_ `(presentation, 2017) `__ +* `Glossier `_ `(blog, 2018) `__ Some more companies are using Luigi but haven't had a chance yet to write about it: @@ -160,6 +162,9 @@ Some more companies are using Luigi but haven't had a chance yet to write about * `Grovo `_ * `Weebly `_ * `Deloitte `_ +* `Stacktome `_ +* `LINX+Neemu+Chaordic `_ +* `Foxberry `_ We're more than happy to have your company added here. Just send a PR on GitHub. 
diff --git a/codecov.yml b/codecov.yml index 6a96d68172..587cc56f98 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,21 +1,32 @@ -# First just blindly copy paste what is default values from the docs page -# https://github.com/codecov/support/wiki/codecov.yml coverage: - precision: 2 - round: down - range: "70...100" + precision: 2 # Just copied from default + round: down # Just copied from default + range: "70...100" # Just copied from default status: project: + default: false # disable the default status that measures entire project + core: + target: 92% + paths: "luigi/*.py" + patch: # Just copied from default default: - target: auto if_no_uploads: error - patch: - default: - if_no_uploads: error - - changes: true + changes: true # Just copied from default + + ignore: + - "examples/" + - "luigi/tools" # These are tested as actual run commands without coverage + # List modules who's tests are not run by Travis or + # are run in a subprocesses (like on cluster). + - "luigi/contrib/gcs.py" + - "luigi/contrib/bigquery.py" + - "luigi/contrib/bigquery_avro.py" + - "luigi/contrib/hdfs/" + - "luigi/contrib/hadoop.py" + - "luigi/contrib/mrrunner.py" + - "luigi/contrib/kubernetes.py" -# But for luigi we do not want any comments +# For luigi we do not want any comments comment: false diff --git a/doc/command_line.rst b/doc/command_line.rst deleted file mode 100644 index 552aaba18f..0000000000 --- a/doc/command_line.rst +++ /dev/null @@ -1,38 +0,0 @@ -.. _CommandLine: - -Running from the Command Line -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The prefered way to run Luigi tasks is through the ``luigi`` command line tool -that will be installed with the pip package. - -.. code-block:: python - - # my_module.py, available in your sys.path - import luigi - - class MyTask(luigi.Task): - x = luigi.IntParameter() - y = luigi.IntParameter(default=45) - - def run(self): - print self.x + self.y - -Should be run like this - -.. code-block:: console - - $ luigi --module my_module MyTask --x 123 --y 456 --local-scheduler - -Or alternatively like this: - -.. code-block:: console - - $ python -m luigi --module my_module MyTask --x 100 --local-scheduler - -Note that if a parameter name contains '_', it should be replaced by '-'. -For example, if MyTask had a parameter called 'my_parameter': - -.. code-block:: console - - $ luigi --module my_module MyTask --my-parameter 100 --local-scheduler diff --git a/doc/configuration.rst b/doc/configuration.rst index cfa643528c..a00ddd5e85 100644 --- a/doc/configuration.rst +++ b/doc/configuration.rst @@ -270,6 +270,12 @@ check_unfulfilled_deps resource-intensive. Defaults to true. +force_multiprocessing + By default, luigi uses multiprocessing when *more than one* worker process is + requested. Whet set to true, multiprocessing is used independent of the + the number of workers. + Defaults to false. + [elasticsearch] --------------- @@ -716,6 +722,15 @@ worker_disconnect_delay scheduler before removing it and marking all of its running tasks as failed. Defaults to 60. +pause_enabled + If false, disables pause/unpause operations and hides the pause toggle from + the visualiser. + +send_messages + When true, the scheduler is allowed to send messages to running tasks and + the central scheduler provides a simple prompt per task to send messages. + Defaults to true. 
+ [sendgrid] ---------- diff --git a/doc/index.rst b/doc/index.rst index a92bdf662a..477d189dbb 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -15,7 +15,7 @@ Table of Contents workflows.rst tasks.rst parameters.rst - command_line.rst + running_luigi.rst central_scheduler.rst execution_model.rst luigi_patterns.rst diff --git a/doc/luigi_patterns.rst b/doc/luigi_patterns.rst index 532b2d9966..e59d1a4b47 100644 --- a/doc/luigi_patterns.rst +++ b/doc/luigi_patterns.rst @@ -226,6 +226,33 @@ the task parameters or other dynamic attributes: Since, by default, resources have a usage limit of 1, no two instances of Task A will now run if they have the same `important_file_name` property. +Decreasing resources of running tasks +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +At scheduling time, the luigi scheduler needs to be aware of the maximum +resource consumption a task might have once it runs. For some tasks, however, +it can be beneficial to decrease the amount of consumed resources between two +steps within their run method (e.g. after some heavy computation). In this +case, a different task waiting for that particular resource can already be +scheduled. + +.. code-block:: python + + class A(luigi.Task): + + # set maximum resources a priori + resources = {"some_resource": 3} + + def run(self): + # do something + ... + + # decrease consumption of "some_resource" by one + self.decrease_running_resources({"some_resource": 1}) + + # continue with reduced resources + ... + Monitoring task pipelines ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -290,3 +317,39 @@ built-in solutions. In the case of you're dealing with a file system :meth:`~luigi.target.FileSystemTarget.temporary_path`. For other targets, you should ensure that the way you're writing your final output directory is atomic. + +Sending messages to tasks +~~~~~~~~~~~~~~~~~~~~~~~~~ + +The central scheduler is able to send messages to particular tasks. When a running task accepts +messages, it can access a `multiprocessing.Queue `__ +object storing incoming messages. You can implement custom behavior to react and respond to +messages: + +.. code-block:: python + + class Example(luigi.Task): + + # common task setup + ... + + # configure the task to accept all incoming messages + accepts_messages = True + + def run(self): + # this example runs some loop and listens for the + # "terminate" message, and responds to all other messages + for _ in some_loop(): + # check incomming messages + if not self.scheduler_messages.empty(): + msg = self.scheduler_messages.get() + if msg.content == "terminate": + break + else: + msg.respond("unknown message") + + # finalize + ... + +Messages can be sent right from the scheduler UI which also displays responses (if any). Note that +this feature is only available when the scheduler is configured to send messages (see the :ref:`scheduler-config` config), and the task is configured to accept them. diff --git a/doc/parameters.rst b/doc/parameters.rst index 813d650dca..3a3059579d 100644 --- a/doc/parameters.rst +++ b/doc/parameters.rst @@ -25,7 +25,7 @@ i.e. .. code:: python d = DailyReport(datetime.date(2012, 5, 10)) - print d.date + print(d.date) will return the same date that the object was constructed with. Same goes if you invoke Luigi on the command line. diff --git a/doc/running_luigi.rst b/doc/running_luigi.rst new file mode 100644 index 0000000000..85d72f0508 --- /dev/null +++ b/doc/running_luigi.rst @@ -0,0 +1,108 @@ +.. 
_RunningLuigi: + +Running from the Command Line +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The prefered way to run Luigi tasks is through the ``luigi`` command line tool +that will be installed with the pip package. + +.. code-block:: python + + # my_module.py, available in your sys.path + import luigi + + class MyTask(luigi.Task): + x = luigi.IntParameter() + y = luigi.IntParameter(default=45) + + def run(self): + print(self.x + self.y) + +Should be run like this + +.. code-block:: console + + $ luigi --module my_module MyTask --x 123 --y 456 --local-scheduler + +Or alternatively like this: + +.. code-block:: console + + $ python -m luigi --module my_module MyTask --x 100 --local-scheduler + +Note that if a parameter name contains '_', it should be replaced by '-'. +For example, if MyTask had a parameter called 'my_parameter': + +.. code-block:: console + + $ luigi --module my_module MyTask --my-parameter 100 --local-scheduler + + +Running from Python code +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Another way to start tasks from Python code is using ``luigi.build(tasks, worker_scheduler_factory=None, **env_params)`` +from ``luigi.interface`` module. + +This way of running luigi tasks is useful if you want to get some dynamic parameters from another +source, such as database, or provide additional logic before you start tasks. + +One notable difference is that ``build`` defaults to not using the identical process lock. +If you want to change this behaviour, just pass ``no_lock=False``. + + +.. code-block:: python + + class MyTask1(luigi.Task): + x = luigi.IntParameter() + y = luigi.IntParameter(default=0) + + def run(self): + print(self.x + self.y) + + + class MyTask2(luigi.Task): + x = luigi.IntParameter() + y = luigi.IntParameter(default=1) + z = luigi.IntParameter(default=2) + + def run(self): + print(self.x * self.y * self.z) + + + if __name__ == '__main__': + luigi.build([MyTask1(x=10), MyTask2(x=15, z=3)]) + + +Also, it is possible to pass additional parameters to ``build`` such as host, port, workers and local_scheduler: + +.. code-block:: python + + if __name__ == '__main__': + luigi.build([MyTask1(x=1)], workers=5, local_scheduler=True) + +To achieve some special requirements you can pass to ``build`` your ``worker_scheduler_factory`` +which will return your worker and/or scheduler implementations: + +.. code-block:: python + + class MyWorker(Worker): + # some custom logic + + + class MyFactory(object): + def create_local_scheduler(self): + return scheduler.Scheduler(prune_on_get_work=True, record_task_history=False) + + def create_remote_scheduler(self, url): + return rpc.RemoteScheduler(url) + + def create_worker(self, scheduler, worker_processes, assistant=False): + # return your worker instance + return MyWorker( + scheduler=scheduler, worker_processes=worker_processes, assistant=assistant) + + + if __name__ == '__main__': + luigi.build([MyTask1(x=1)], worker_scheduler_factory=MyFactory()) + +In some cases (like task queue) it may be useful. 
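For completeness, a sketch of passing the lock flag mentioned above (reusing ``MyTask1`` from the examples in this file):

    # opt back in to the identical-process lock that the command line tool uses by default
    luigi.build([MyTask1(x=1)], no_lock=False, local_scheduler=True)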
diff --git a/examples/dynamic_requirements.py b/examples/dynamic_requirements.py index 1895569f61..ed2feba81f 100644 --- a/examples/dynamic_requirements.py +++ b/examples/dynamic_requirements.py @@ -21,7 +21,7 @@ import luigi -class Config(luigi.Task): +class Configuration(luigi.Task): seed = luigi.IntParameter() def output(self): @@ -78,7 +78,7 @@ def output(self): def run(self): # This could be done using regular requires method - config = self.clone(Config) + config = self.clone(Configuration) yield config with config.output().open() as f: diff --git a/examples/execution_summary_example.py b/examples/execution_summary_example.py index 00e5901033..d0a39b8fe4 100644 --- a/examples/execution_summary_example.py +++ b/examples/execution_summary_example.py @@ -28,14 +28,14 @@ ===== Luigi Execution Summary ===== Scheduled 218 tasks of which: - * 195 present dependencies were encountered: + * 195 complete ones were encountered: - 195 examples.Bar(num=5...199) * 1 ran successfully: - 1 examples.Boom(...) * 22 were left pending, among these: * 1 were missing external dependencies: - 1 MyExternal() - * 21 had missing external dependencies: + * 21 had missing dependencies: - 1 examples.EntryPoint() - examples.Foo(num=100, num2=16) and 9 other examples.Foo - 10 examples.DateTask(date=1998-03-23...1998-04-01, num=5) diff --git a/examples/per_task_retry_policy.py b/examples/per_task_retry_policy.py index 797dd4d843..ca0f560e23 100644 --- a/examples/per_task_retry_policy.py +++ b/examples/per_task_retry_policy.py @@ -40,7 +40,7 @@ - 1 DynamicErrorTaskSubmitter() * 1 had failed dependencies: - 1 examples.PerTaskRetryPolicy() - * 1 had missing external dependencies: + * 1 had missing dependencies: - 1 examples.PerTaskRetryPolicy() * 1 was not granted run permission by the scheduler: - 1 DynamicErrorTaskSubmitter() diff --git a/luigi/__init__.py b/luigi/__init__.py index cacc619475..2e858b18e4 100644 --- a/luigi/__init__.py +++ b/luigi/__init__.py @@ -36,7 +36,7 @@ DateIntervalParameter, TimeDeltaParameter, IntParameter, FloatParameter, BoolParameter, TaskParameter, EnumParameter, DictParameter, ListParameter, TupleParameter, - NumericalParameter, ChoiceParameter + NumericalParameter, ChoiceParameter, OptionalParameter ) from luigi import configuration @@ -59,5 +59,5 @@ 'FloatParameter', 'BoolParameter', 'TaskParameter', 'ListParameter', 'TupleParameter', 'EnumParameter', 'DictParameter', 'configuration', 'interface', 'local_target', 'run', 'build', 'event', 'Event', - 'NumericalParameter', 'ChoiceParameter' + 'NumericalParameter', 'ChoiceParameter', 'OptionalParameter' ] diff --git a/luigi/contrib/batch.py b/luigi/contrib/batch.py new file mode 100644 index 0000000000..ae29b7e1f9 --- /dev/null +++ b/luigi/contrib/batch.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2018 Outlier Bio, LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +AWS Batch wrapper for Luigi + +From the AWS website: + + AWS Batch enables you to run batch computing workloads on the AWS Cloud. 
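To make the wrapper described here concrete, a subclass might look roughly like the following sketch; the job definition name, the ``input`` parameter, and the S3 path are invented, and the jobDefinition is assumed to be registered with AWS Batch beforehand:

    import luigi
    from luigi.contrib.batch import BatchTask

    class WordCount(BatchTask):
        input_uri = luigi.Parameter()

        @property
        def parameters(self):
            # substituted into the placeholders of the registered jobDefinition
            return {'input': self.input_uri}

    # hypothetical invocation
    luigi.build([WordCount(job_definition='wordcount-jobdef',
                           input_uri='s3://my-bucket/input.txt')],
                local_scheduler=True)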
+ + Batch computing is a common way for developers, scientists, and engineers + to access large amounts of compute resources, and AWS Batch removes the + undifferentiated heavy lifting of configuring and managing the required + infrastructure. AWS Batch is similar to traditional batch computing + software. This service can efficiently provision resources in response to + jobs submitted in order to eliminate capacity constraints, reduce compute + costs, and deliver results quickly. + +See `AWS Batch User Guide`_ for more details. + +To use AWS Batch, you create a jobDefinition JSON that defines a `docker run`_ +command, and then submit this JSON to the API to queue up the task. Behind the +scenes, AWS Batch auto-scales a fleet of EC2 Container Service instances, +monitors the load on these instances, and schedules the jobs. + +This `boto3-powered`_ wrapper allows you to create Luigi Tasks to submit Batch +``jobDefinition``s. You can either pass a dict (mapping directly to the +``jobDefinition`` JSON) OR an Amazon Resource Name (arn) for a previously +registered ``jobDefinition``. + +Requires: + +- boto3 package +- Amazon AWS credentials discoverable by boto3 (e.g., by using ``aws configure`` + from awscli_) +- An enabled AWS Batch job queue configured to run on a compute environment. + +Written and maintained by Jake Feala (@jfeala) for Outlier Bio (@outlierbio) + +.. _`docker run`: https://docs.docker.com/reference/commandline/run +.. _jobDefinition: http://http://docs.aws.amazon.com/batch/latest/userguide/job_definitions.html +.. _`boto3-powered`: https://boto3.readthedocs.io +.. _awscli: https://aws.amazon.com/cli +.. _`AWS Batch User Guide`: http://docs.aws.amazon.com/AmazonECS/latest/developerguide/ECS_GetStarted.html + +""" + +import json +import logging +import random +import string +import time + +import luigi +logger = logging.getLogger(__name__) + +try: + import boto3 +except ImportError: + logger.warning('boto3 is not installed. 
BatchTasks require boto3') + + +class BatchJobException(Exception): + pass + + +POLL_TIME = 10 + + +def _random_id(): + return 'batch-job-' + ''.join(random.sample(string.ascii_lowercase, 8)) + + +class BatchClient(object): + + def __init__(self, poll_time=POLL_TIME): + self.poll_time = poll_time + self._client = boto3.client('batch') + self._log_client = boto3.client('logs') + self._queue = self.get_active_queue() + + def get_active_queue(self): + """Get name of first active job queue""" + + # Get dict of active queues keyed by name + queues = {q['jobQueueName']: q for q in self._client.describe_job_queues()['jobQueues'] + if q['state'] == 'ENABLED' and q['status'] == 'VALID'} + if not queues: + raise Exception('No job queues with state=ENABLED and status=VALID') + + # Pick the first queue as default + return list(queues.keys())[0] + + def get_job_id_from_name(self, job_name): + """Retrieve the first job ID matching the given name""" + jobs = self._client.list_jobs(jobQueue=self._queue, jobStatus='RUNNING')['jobSummaryList'] + matching_jobs = [job for job in jobs if job['jobName'] == job_name] + if matching_jobs: + return matching_jobs[0]['jobId'] + + def get_job_status(self, job_id): + """Retrieve task statuses from ECS API + + :param job_id (str): AWS Batch job uuid + + Returns one of {SUBMITTED|PENDING|RUNNABLE|STARTING|RUNNING|SUCCEEDED|FAILED} + """ + response = self._client.describe_jobs(jobs=[job_id]) + + # Error checking + status_code = response['ResponseMetadata']['HTTPStatusCode'] + if status_code != 200: + msg = 'Job status request received status code {0}:\n{1}' + raise Exception(msg.format(status_code, response)) + + return response['jobs'][0]['status'] + + def get_logs(self, log_stream_name, get_last=50): + """Retrieve log stream from CloudWatch""" + response = self._log_client.get_log_events( + logGroupName='/aws/batch/job', + logStreamName=log_stream_name, + startFromHead=False) + events = response['events'] + return '\n'.join(e['message'] for e in events[-get_last:]) + + def submit_job(self, job_definition, parameters, job_name=None, queue=None): + """Wrap submit_job with useful defaults""" + if job_name is None: + job_name = _random_id() + response = self._client.submit_job( + jobName=job_name, + jobQueue=queue or self.get_active_queue(), + jobDefinition=job_definition, + parameters=parameters + ) + return response['jobId'] + + def wait_on_job(self, job_id): + """Poll task status until STOPPED""" + + while True: + status = self.get_job_status(job_id) + if status == 'SUCCEEDED': + logger.info('Batch job {} SUCCEEDED'.format(job_id)) + return True + elif status == 'FAILED': + # Raise and notify if job failed + jobs = self._client.describe_jobs(jobs=[job_id])['jobs'] + job_str = json.dumps(jobs, indent=4) + logger.debug('Job details:\n' + job_str) + + log_stream_name = jobs[0]['attempts'][0]['container']['logStreamName'] + logs = self.get_logs(log_stream_name) + raise BatchJobException('Job {} failed: {}'.format( + job_id, logs)) + + time.sleep(self.poll_time) + logger.debug('Batch job status for job {0}: {1}'.format( + job_id, status)) + + def register_job_definition(self, json_fpath): + """Register a job definition with AWS Batch, using a JSON""" + with open(json_fpath) as f: + job_def = json.load(f) + response = self._client.register_job_definition(**job_def) + status_code = response['ResponseMetadata']['HTTPStatusCode'] + if status_code != 200: + msg = 'Register job definition request received status code {0}:\n{1}' + raise Exception(msg.format(status_code, response)) + 
return response + + +class BatchTask(luigi.Task): + + """ + Base class for an Amazon Batch job + + Amazon Batch requires you to register "job definitions", which are JSON + descriptions for how to issue the ``docker run`` command. This Luigi Task + requires a pre-registered Batch jobDefinition name passed as a Parameter + + :param job_definition (str): name of pre-registered jobDefinition + :param job_name: name of specific job, for tracking in the queue and logs. + + """ + job_definition = luigi.Parameter() + job_name = luigi.OptionalParameter(default=None) + poll_time = luigi.IntParameter(default=POLL_TIME) + + def run(self): + bc = BatchClient(self.poll_time) + job_id = bc.submit_job( + self.job_definition, + self.parameters, + job_name=self.job_name) + bc.wait_on_job(job_id) + + @property + def parameters(self): + """Override to return a dict of parameters for the Batch Task""" + return {} diff --git a/luigi/contrib/bigquery.py b/luigi/contrib/bigquery.py index 79fa180a26..b7937eb6cf 100644 --- a/luigi/contrib/bigquery.py +++ b/luigi/contrib/bigquery.py @@ -126,7 +126,7 @@ def __init__(self, oauth_credentials=None, descriptor='', http_=None): if descriptor: self.client = discovery.build_from_document(descriptor, **authenticate_kwargs) else: - self.client = discovery.build('bigquery', 'v2', **authenticate_kwargs) + self.client = discovery.build('bigquery', 'v2', cache_discovery=False, **authenticate_kwargs) def dataset_exists(self, dataset): """Returns whether the given dataset exists. @@ -174,7 +174,7 @@ def table_exists(self, table): return True - def make_dataset(self, dataset, raise_if_exists=False, body={}): + def make_dataset(self, dataset, raise_if_exists=False, body=None): """Creates a new dataset with the default permissions. :param dataset: @@ -183,6 +183,9 @@ def make_dataset(self, dataset, raise_if_exists=False, body={}): :raises luigi.target.FileAlreadyExists: if raise_if_exists=True and the dataset exists """ + if body is None: + body = {} + try: body['id'] = '{}:{}'.format(dataset.project_id, dataset.dataset_id) if dataset.location is not None: @@ -417,14 +420,13 @@ class MixinBigQueryBulkComplete(object): @classmethod def bulk_complete(cls, parameter_tuples): - if len(parameter_tuples) < 1: - return - # Instantiate the tasks to inspect them tasks_with_params = [(cls(p), p) for p in parameter_tuples] + if not tasks_with_params: + return # Grab the set of BigQuery datasets we are interested in - datasets = set([t.output().table.dataset for t, p in tasks_with_params]) + datasets = {t.output().table.dataset for t, p in tasks_with_params} logger.info('Checking datasets %s for available tables', datasets) # Query the available tables for all datasets diff --git a/luigi/contrib/bigquery_avro.py b/luigi/contrib/bigquery_avro.py index 975d7d9b04..043c8c3a34 100644 --- a/luigi/contrib/bigquery_avro.py +++ b/luigi/contrib/bigquery_avro.py @@ -19,11 +19,8 @@ class BigQueryLoadAvro(BigQueryLoadTask): """A helper for loading specifically Avro data into BigQuery from GCS. - Additional goodies - takes field documentation from the input data and propagates it - to BigQuery table description and field descriptions. Supports the following Avro schema - types: Primitives, Enums, Records, Arrays, Unions, and Maps. For Map schemas nested maps - and unions are not supported. For Union Schemas only nested Primitive and Record Schemas - are currently supported. + Copies table level description from Avro schema doc, BigQuery internally will copy field-level descriptions + to the table. 
Suitable for use via subclassing: override requires() to return Task(s) that output to GCS Targets; their paths are expected to be URIs of .avro files or URI prefixes @@ -41,7 +38,7 @@ def source_uris(self): return [self._avro_uri(x) for x in flatten(self.input())] def _get_input_schema(self): - '''Arbitrarily picks an object in input and reads the Avro schema from it.''' + """Arbitrarily picks an object in input and reads the Avro schema from it.""" assert avro, 'avro module required' input_target = flatten(self.input())[0] @@ -81,84 +78,8 @@ def _set_output_doc(self, avro_schema): bq_client = self.output().client.client table = self.output().table - current_bq_schema = bq_client.tables().get(projectId=table.project_id, - datasetId=table.dataset_id, - tableId=table.table_id).execute() - - def get_fields_with_description(bq_fields, avro_fields): - new_fields = [] - for field in bq_fields: - avro_field = avro_fields[field[u'name']] - field_type = type(avro_field.type) - - # Primitive Support - if field_type is avro.schema.PrimitiveSchema: - field[u'description'] = avro_field.doc - - # Enum Support - if field_type is avro.schema.EnumSchema: - field[u'description'] = avro_field.type.doc - - # Record Support - if field_type is avro.schema.RecordSchema: - field[u'description'] = avro_field.type.doc - field[u'fields'] = get_fields_with_description(field[u'fields'], avro_field.type.fields_dict) - - # Array Support - if field_type is avro.schema.ArraySchema: - field[u'description'] = avro_field.type.items.doc - field[u'fields'] = get_fields_with_description(field[u'fields'], avro_field.type.items.fields_dict) - - # Union Support - if type(avro_field.type) is avro.schema.UnionSchema: - for schema in avro_field.type.schemas: - if type(schema) is avro.schema.PrimitiveSchema: - field[u'description'] = avro_field.doc - - if type(schema) is avro.schema.RecordSchema: - field[u'description'] = schema.doc - field[u'fields'] = get_fields_with_description(field[u'fields'], schema.fields_dict) - - # Support for Enums, Arrays, Maps, and Unions inside of a union is not yet implemented - - # Map Support - if field_type is avro.schema.MapSchema: - field[u'description'] = avro_field.doc - - # Big Query Avro loader creates artificial key and value attributes in the Big Query schema - # ignoring the key and operating directly on the value - # https://cloud.google.com/bigquery/data-formats#avro_format - bq_map_value_field = field[u'fields'][-1] - avro_map_value = avro_field.type.values - value_field_type = type(avro_map_value) - - # Primitive Support: Unfortunately the map element doesn't directly have a doc attribute - # so there is no way to get documentation on the primitive types for the value attribute - - if value_field_type is avro.schema.EnumSchema: - bq_map_value_field[u'description'] = avro_map_value.type.doc - - if value_field_type is avro.schema.RecordSchema: - # Set values description using type's doc - bq_map_value_field[u'description'] = avro_map_value.doc - - # This is jumping into the map value directly and working with that - bq_map_value_field[u'fields'] = get_fields_with_description(bq_map_value_field[u'fields'], avro_map_value.fields_dict) - - if value_field_type is avro.schema.ArraySchema: - bq_map_value_field[u'description'] = avro_map_value.items.doc - bq_map_value_field[u'fields'] = get_fields_with_description(bq_map_value_field[u'fields'], avro_map_value.items.fields_dict) - - # Support for unions and maps nested inside of a map is not yet implemented - - new_fields.append(field) - return 
new_fields - - field_descriptions = get_fields_with_description(current_bq_schema['schema']['fields'], avro_schema.fields_dict) - patch = { 'description': avro_schema.doc, - 'schema': {'fields': field_descriptions, }, } bq_client.tables().patch(projectId=table.project_id, @@ -174,4 +95,4 @@ def run(self): try: self._set_output_doc(self._get_input_schema()) except Exception as e: - logger.warning('Could not propagate Avro doc to BigQuery table field descriptions: %r', e) + logger.warning('Could not propagate Avro doc to BigQuery table description: %r', e) diff --git a/luigi/contrib/dataproc.py b/luigi/contrib/dataproc.py index e003a122f7..a20ab7abd9 100644 --- a/luigi/contrib/dataproc.py +++ b/luigi/contrib/dataproc.py @@ -11,15 +11,15 @@ _dataproc_client = None try: - import oauth2client.client + import google.auth from googleapiclient import discovery from googleapiclient.errors import HttpError - DEFAULT_CREDENTIALS = oauth2client.client.GoogleCredentials.get_application_default() + DEFAULT_CREDENTIALS, _ = google.auth.default() authenticate_kwargs = gcp.get_authenticate_kwargs(DEFAULT_CREDENTIALS) - _dataproc_client = discovery.build('dataproc', 'v1', **authenticate_kwargs) + _dataproc_client = discovery.build('dataproc', 'v1', cache_discovery=False, **authenticate_kwargs) except ImportError: - logger.warning("Loading Dataproc module without the python packages googleapiclient & oauth2client. \ + logger.warning("Loading Dataproc module without the python packages googleapiclient & google-auth. \ This will crash at runtime if Dataproc functionality is used.") @@ -56,7 +56,11 @@ def submit_job(self, job_config): self._job_id = self._job['reference']['jobId'] return self._job - def submit_spark_job(self, jars, main_class, job_args=[]): + def submit_spark_job(self, jars, main_class, job_args=None): + + if job_args is None: + job_args = [] + job_config = {"job": { "placement": { "clusterName": self.dataproc_cluster_name @@ -72,7 +76,11 @@ def submit_spark_job(self, jars, main_class, job_args=[]): logger.info("Submitted new dataproc job:{} id:{}".format(self._job_name, self._job_id)) return self._job - def submit_pyspark_job(self, job_file, extra_files=[], job_args=[]): + def submit_pyspark_job(self, job_file, extra_files=list(), job_args=None): + + if job_args is None: + job_args = [] + job_config = {"job": { "placement": { "clusterName": self.dataproc_cluster_name diff --git a/luigi/contrib/docker_runner.py b/luigi/contrib/docker_runner.py index fc2af76a08..92be463672 100644 --- a/luigi/contrib/docker_runner.py +++ b/luigi/contrib/docker_runner.py @@ -205,6 +205,10 @@ def run(self): self._client.start(container['Id']) exit_status = self._client.wait(container['Id']) + # docker-py>=3.0.0 returns a dict instead of the status code directly + if type(exit_status) is dict: + exit_status = exit_status['StatusCode'] + if exit_status != 0: stdout = False stderr = True @@ -227,7 +231,7 @@ def run(self): container_name = self.name try: message = e.message - except: + except AttributeError: message = str(e) self.__logger.error("Container " + container_name + " exited with non zero code: " + message) diff --git a/luigi/contrib/ecs.py b/luigi/contrib/ecs.py index 8afe192865..f563e73dc0 100644 --- a/luigi/contrib/ecs.py +++ b/luigi/contrib/ecs.py @@ -133,8 +133,8 @@ class ECSTask(luigi.Task): """ - task_def_arn = luigi.Parameter(default=None) - task_def = luigi.Parameter(default=None) + task_def_arn = luigi.OptionalParameter(default=None) + task_def = luigi.OptionalParameter(default=None) cluster = 
luigi.Parameter(default='default') @property diff --git a/luigi/contrib/esindex.py b/luigi/contrib/esindex.py index ede6429f61..5a550d04d6 100644 --- a/luigi/contrib/esindex.py +++ b/luigi/contrib/esindex.py @@ -117,7 +117,7 @@ class ElasticsearchTarget(luigi.Target): def __init__(self, host, port, index, doc_type, update_id, marker_index_hist_size=0, http_auth=None, timeout=10, - extra_elasticsearch_args={}): + extra_elasticsearch_args=None): """ :param host: Elasticsearch server host :type host: str @@ -136,6 +136,9 @@ def __init__(self, host, port, index, doc_type, update_id, :param extra_elasticsearch_args: extra args for Elasticsearch :type Extra: dict """ + if extra_elasticsearch_args is None: + extra_elasticsearch_args = {} + self.host = host self.port = port self.http_auth = http_auth diff --git a/luigi/contrib/gcp.py b/luigi/contrib/gcp.py index 140b6b541e..bb768b94f6 100644 --- a/luigi/contrib/gcp.py +++ b/luigi/contrib/gcp.py @@ -6,9 +6,9 @@ try: import httplib2 - import oauth2client + import google.auth except ImportError: - logger.warning("Loading GCP module without the python packages httplib2, oauth2client. \ + logger.warning("Loading GCP module without the python packages httplib2, google-auth. \ This *could* crash at runtime if no other credentials are provided.") @@ -33,11 +33,11 @@ def get_authenticate_kwargs(oauth_credentials=None, http_=None): # neither http_ or credentials provided try: # try default credentials - oauth_credentials = oauth2client.client.GoogleCredentials.get_application_default() + credentials, _ = google.auth.default() authenticate_kwargs = { - "credentials": oauth_credentials + "credentials": credentials } - except oauth2client.client.GoogleCredentials.ApplicationDefaultCredentialsError: + except google.auth.exceptions.DefaultCredentialsError: # try http using httplib2 authenticate_kwargs = { "http": httplib2.Http() diff --git a/luigi/contrib/gcs.py b/luigi/contrib/gcs.py index 7a60b509b5..12d6710e57 100644 --- a/luigi/contrib/gcs.py +++ b/luigi/contrib/gcs.py @@ -42,7 +42,7 @@ from googleapiclient import discovery from googleapiclient import http except ImportError: - logger.warning("Loading GCS module without the python packages googleapiclient & oauth2client. \ + logger.warning("Loading GCS module without the python packages googleapiclient & google-auth. \ This will crash at runtime if GCS functionality is used.") else: # Retry transport and file IO errors. @@ -89,9 +89,9 @@ class GCSClient(luigi.target.FileSystem): There are several ways to use this class. By default it will use the app default credentials, as described at https://developers.google.com/identity/protocols/application-default-credentials . - Alternatively, you may pass an oauth2client credentials object. e.g. to use a service account:: + Alternatively, you may pass an google-auth credentials object. e.g. to use a service account:: - credentials = oauth2client.client.SignedJwtAssertionCredentials( + credentials = google.auth.jwt.Credentials.from_service_account_info( '012345678912-ThisIsARandomServiceAccountEmail@developer.gserviceaccount.com', 'These are the contents of the p12 file that came with the service account', scope='https://www.googleapis.com/auth/devstorage.read_write') @@ -108,14 +108,18 @@ class GCSClient(luigi.target.FileSystem): as the ``descriptor`` argument. 
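    Keyword arguments beyond those named below are forwarded to ``googleapiclient.discovery.build``
    (that passthrough is added in the ``__init__`` hunk that follows), so discovery behaviour can be
    tuned per client. A minimal sketch, assuming a reasonably recent ``googleapiclient``; the bucket
    path is hypothetical::

        from luigi.contrib import gcs

        client = gcs.GCSClient(cache_discovery=True)  # forwarded to discovery.build()
        target = gcs.GCSTarget('gs://my-bucket/some/file.txt', client=client)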
""" def __init__(self, oauth_credentials=None, descriptor='', http_=None, - chunksize=CHUNKSIZE): + chunksize=CHUNKSIZE, **discovery_build_kwargs): self.chunksize = chunksize authenticate_kwargs = gcp.get_authenticate_kwargs(oauth_credentials, http_) + build_kwargs = authenticate_kwargs.copy() + build_kwargs.update(discovery_build_kwargs) + if descriptor: - self.client = discovery.build_from_document(descriptor, **authenticate_kwargs) + self.client = discovery.build_from_document(descriptor, **build_kwargs) else: - self.client = discovery.build('storage', 'v1', **authenticate_kwargs) + build_kwargs.setdefault('cache_discovery', False) + self.client = discovery.build('storage', 'v1', **build_kwargs) def _path_to_bucket_and_key(self, path): (scheme, netloc, path, _, _) = urlsplit(path) @@ -245,10 +249,39 @@ def put(self, filename, dest_path, mimetype=None, chunksize=None): resumable = os.path.getsize(filename) > 0 mimetype = mimetype or mimetypes.guess_type(dest_path)[0] or DEFAULT_MIMETYPE - media = http.MediaFileUpload(filename, mimetype, chunksize=chunksize, resumable=resumable) + media = http.MediaFileUpload(filename, mimetype=mimetype, chunksize=chunksize, resumable=resumable) self._do_put(media, dest_path) + def _forward_args_to_put(self, kwargs): + return self.put(**kwargs) + + def put_multiple(self, filepaths, remote_directory, mimetype=None, chunksize=None, num_process=1): + if isinstance(filepaths, str): + raise ValueError( + 'filenames must be a list of strings. If you want to put a single file, ' + 'use the `put(self, filename, ...)` method' + ) + + put_kwargs_list = [ + { + 'filename': filepath, + 'dest_path': os.path.join(remote_directory, os.path.basename(filepath)), + 'mimetype': mimetype, + 'chunksize': chunksize, + } + for filepath in filepaths + ] + + if num_process > 1: + from multiprocessing import Pool + from contextlib import closing + with closing(Pool(num_process)) as p: + return p.map(self._forward_args_to_put, put_kwargs_list) + else: + for put_kwargs in put_kwargs_list: + self._forward_args_to_put(put_kwargs) + def put_string(self, contents, dest_path, mimetype=None): mimetype = mimetype or mimetypes.guess_type(dest_path)[0] or DEFAULT_MIMETYPE assert isinstance(mimetype, six.string_types) diff --git a/luigi/contrib/hadoop.py b/luigi/contrib/hadoop.py index 46c14f9a6a..f92f0ef725 100644 --- a/luigi/contrib/hadoop.py +++ b/luigi/contrib/hadoop.py @@ -72,10 +72,13 @@ class hadoop(luigi.task.Config): - pool = luigi.Parameter(default=None, - description='Hadoop pool so use for Hadoop tasks. ' - 'To specify pools per tasks, see ' - 'BaseHadoopJobTask.pool') + pool = luigi.OptionalParameter( + default=None, + description=( + 'Hadoop pool so use for Hadoop tasks. To specify pools per tasks, ' + 'see BaseHadoopJobTask.pool' + ), + ) def attach(*packages): @@ -668,7 +671,7 @@ def run_job(self, job): class BaseHadoopJobTask(luigi.Task): - pool = luigi.Parameter(default=None, significant=False, positional=False) + pool = luigi.OptionalParameter(default=None, significant=False, positional=False) # This value can be set to change the default batching increment. Default is 1 for backwards compatibility. batch_counter_default = 1 @@ -850,7 +853,7 @@ def writer(self, outputs, stdout, stderr=sys.stderr): # JSON is already serialized, so we put `self.serialize` in a else statement. 
output = map(self.serialize, output) print("\t".join(output), file=stdout) - except: + except BaseException: print(output, file=stderr) raise diff --git a/luigi/contrib/hdfs/config.py b/luigi/contrib/hdfs/config.py index d2bf22a931..83e1c790c3 100644 --- a/luigi/contrib/hdfs/config.py +++ b/luigi/contrib/hdfs/config.py @@ -36,18 +36,19 @@ class hdfs(luigi.Config): client_version = luigi.IntParameter(default=None) - effective_user = luigi.Parameter( + effective_user = luigi.OptionalParameter( default=os.getenv('HADOOP_USER_NAME'), description="Optionally specifies the effective user for snakebite. " "If not set the environment variable HADOOP_USER_NAME is " "used, else USER") snakebite_autoconfig = luigi.BoolParameter(default=False) - namenode_host = luigi.Parameter(default=None) + namenode_host = luigi.OptionalParameter(default=None) namenode_port = luigi.IntParameter(default=None) client = luigi.Parameter(default='hadoopcli') - tmp_dir = luigi.Parameter(default=None, - config_path=dict(section='core', name='hdfs-tmp-dir'), - ) + tmp_dir = luigi.OptionalParameter( + default=None, + config_path=dict(section='core', name='hdfs-tmp-dir'), + ) class hadoopcli(luigi.Config): diff --git a/luigi/contrib/hdfs/snakebite_client.py b/luigi/contrib/hdfs/snakebite_client.py index 0988cf2a5b..5c38787f4f 100644 --- a/luigi/contrib/hdfs/snakebite_client.py +++ b/luigi/contrib/hdfs/snakebite_client.py @@ -222,6 +222,15 @@ def get(self, path, local_destination): return list(self.get_bite().copyToLocal(self.list_path(path), local_destination)) + def get_merge(self, path, local_destination): + """ + Using snakebite getmerge to implement this. + :param path: HDFS directory + :param local_destination: path on the system running Luigi + :return: merge of the directory + """ + return list(self.get_bite().getmerge(path=path, dst=local_destination)) + def mkdir(self, path, parents=True, mode=0o755, raise_if_exists=False): """ Use snakebite.mkdir, if available. diff --git a/luigi/contrib/hive.py b/luigi/contrib/hive.py index a310677b12..91c84e320f 100644 --- a/luigi/contrib/hive.py +++ b/luigi/contrib/hive.py @@ -68,7 +68,7 @@ def run_hive(args, check_return_code=True): if check_return_code and p.returncode != 0: raise HiveCommandError("Hive command: {0} failed with error code: {1}".format(" ".join(cmd), p.returncode), stdout, stderr) - return stdout + return stdout.decode('utf-8') def run_hive_cmd(hivecmd, check_return_code=True): diff --git a/luigi/contrib/kubernetes.py b/luigi/contrib/kubernetes.py index 0714bbd21b..f32f2b238d 100644 --- a/luigi/contrib/kubernetes.py +++ b/luigi/contrib/kubernetes.py @@ -32,18 +32,19 @@ Written and maintained by Marco Capuccini (@mcapuccini). """ - -import luigi import logging -import uuid import time +import uuid +from datetime import datetime + +import luigi logger = logging.getLogger('luigi-interface') try: from pykube.config import KubeConfig from pykube.http import HTTPClient - from pykube.objects import Job + from pykube.objects import Job, Pod except ImportError: logger.warning('pykube is not installed. 
KubernetesJobTask requires pykube.') @@ -61,21 +62,21 @@ class kubernetes(luigi.Config): class KubernetesJobTask(luigi.Task): - __POLL_TIME = 5 # see __track_job - kubernetes_config = kubernetes() + _kubernetes_config = None # Needs to be loaded at runtime def _init_kubernetes(self): self.__logger = logger self.__logger.debug("Kubernetes auth method: " + self.auth_method) - if(self.auth_method == "kubeconfig"): + if self.auth_method == "kubeconfig": self.__kube_api = HTTPClient(KubeConfig.from_file(self.kubeconfig_path)) - elif(self.auth_method == "service-account"): + elif self.auth_method == "service-account": self.__kube_api = HTTPClient(KubeConfig.from_service_account()) else: raise ValueError("Illegal auth_method") self.job_uuid = str(uuid.uuid4().hex) - self.uu_name = self.name + "-luigi-" + self.job_uuid + now = datetime.utcnow() + self.uu_name = "%s-%s-%s" % (self.name, now.strftime('%Y%m%d%H%M%S'), self.job_uuid[:16]) @property def auth_method(self): @@ -157,18 +158,60 @@ def max_retrials(self): """ return self.kubernetes_config.max_retrials + @property + def backoff_limit(self): + """ + Maximum number of retries before considering the job as failed. + See: https://kubernetes.io/docs/concepts/workloads/controllers/jobs-run-to-completion/#pod-backoff-failure-policy + """ + return 6 + + @property + def delete_on_success(self): + """ + Delete the Kubernetes workload if the job has ended successfully. + """ + return True + + @property + def print_pod_logs_on_exit(self): + """ + Fetch and print the pod logs once the job is completed. + """ + return False + + @property + def active_deadline_seconds(self): + """ + Time allowed to successfully schedule pods. + See: https://kubernetes.io/docs/concepts/workloads/controllers/jobs-run-to-completion/#job-termination-and-cleanup + """ + return 100 + + @property + def kubernetes_config(self): + if not self._kubernetes_config: + self._kubernetes_config = kubernetes() + return self._kubernetes_config + def __track_job(self): """Poll job status while active""" - while (self.__get_job_status() == "running"): - self.__logger.debug("Kubernetes job " + self.uu_name - + " is still running") + while not self.__verify_job_has_started(): time.sleep(self.__POLL_TIME) - if(self.__get_job_status() == "succeeded"): - self.__logger.info("Kubernetes job " + self.uu_name + " succeeded") - # Use signal_complete to notify of job completion - self.signal_complete() - else: - raise RuntimeError("Kubernetes job " + self.uu_name + " failed") + self.__logger.debug("Waiting for Kubernetes job " + self.uu_name + " to start") + self.__print_kubectl_hints() + + status = self.__get_job_status() + while status == "RUNNING": + self.__logger.debug("Kubernetes job " + self.uu_name + " is running") + time.sleep(self.__POLL_TIME) + status = self.__get_job_status() + + assert status != "FAILED", "Kubernetes job " + self.uu_name + " failed" + + # status == "SUCCEEDED" + self.__logger.info("Kubernetes job " + self.uu_name + " succeeded") + self.signal_complete() def signal_complete(self): """Signal job completion for scheduler and dependent tasks. 
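To make the new overridables concrete, a minimal subclass sketch (illustrative only: the image,
command and values are hypothetical, and a real task would usually also implement ``output()`` or
``signal_complete()``):

    from luigi.contrib.kubernetes import KubernetesJobTask

    class EchoJob(KubernetesJobTask):
        name = "echo"                      # used as the job name prefix
        backoff_limit = 2                  # new: pod retries allowed before the job is failed
        active_deadline_seconds = 600      # new: seconds allowed to schedule the pods
        print_pod_logs_on_exit = True      # new: dump pod logs when the job finishes
        delete_on_success = False          # new: keep the Job object around for inspection

        @property
        def spec_schema(self):
            return {
                "containers": [{
                    "name": "echo",
                    "image": "alpine:3.7",
                    "command": ["echo", "hello"]
                }]
            }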
@@ -181,27 +224,96 @@ def signal_complete(self): """ pass + def __get_pods(self): + pod_objs = Pod.objects(self.__kube_api) \ + .filter(selector="job-name=" + self.uu_name) \ + .response['items'] + return [Pod(self.__kube_api, p) for p in pod_objs] + + def __get_job(self): + jobs = Job.objects(self.__kube_api) \ + .filter(selector="luigi_task_id=" + self.job_uuid) \ + .response['items'] + assert len(jobs) == 1, "Kubernetes job " + self.uu_name + " not found" + return Job(self.__kube_api, jobs[0]) + + def __print_pod_logs(self): + for pod in self.__get_pods(): + logs = pod.logs(timestamps=True).strip() + self.__logger.info("Fetching logs from " + pod.name) + if len(logs) > 0: + for l in logs.split('\n'): + self.__logger.info(l) + + def __print_kubectl_hints(self): + self.__logger.info("To stream Pod logs, use:") + for pod in self.__get_pods(): + self.__logger.info("`kubectl logs -f pod/%s`" % pod.name) + + def __verify_job_has_started(self): + """Asserts that the job has successfully started""" + # Verify that the job started + self.__get_job() + + # Verify that the pod started + pods = self.__get_pods() + + assert len(pods) > 0, "No pod scheduled by " + self.uu_name + for pod in pods: + status = pod.obj['status'] + for cont_stats in status.get('containerStatuses', []): + if 'terminated' in cont_stats['state']: + t = cont_stats['state']['terminated'] + err_msg = "Pod %s %s (exit code %d). Logs: `kubectl logs pod/%s`" % ( + pod.name, t['reason'], t['exitCode'], pod.name) + assert t['exitCode'] == 0, err_msg + + if 'waiting' in cont_stats['state']: + wr = cont_stats['state']['waiting']['reason'] + assert wr == 'ContainerCreating', "Pod %s %s. Logs: `kubectl logs pod/%s`" % ( + pod.name, wr, pod.name) + + for cond in status.get('conditions', []): + if 'message' in cond: + if cond['reason'] == 'ContainersNotReady': + return False + assert cond['status'] != 'False', \ + "[ERROR] %s - %s" % (cond['reason'], cond['message']) + return True + def __get_job_status(self): """Return the Kubernetes job status""" - # Look for the required job - jobs = Job.objects(self.__kube_api).filter(selector="luigi_task_id=" - + self.job_uuid) - # Raise an exception if no such job found - if len(jobs.response["items"]) == 0: - raise RuntimeError("Kubernetes job " + self.uu_name + " not found") # Figure out status and return it - job = Job(self.__kube_api, jobs.response["items"][0]) - if ("succeeded" in job.obj["status"] and job.obj["status"]["succeeded"] > 0): - job.scale(replicas=0) # Downscale the job, but keep it for logging - return "succeeded" - if ("failed" in job.obj["status"]): + job = self.__get_job() + + if "succeeded" in job.obj["status"] and job.obj["status"]["succeeded"] > 0: + job.scale(replicas=0) + if self.print_pod_logs_on_exit: + self.__print_pod_logs() + if self.delete_on_success: + self.__delete_job_cascade(job) + return "SUCCEEDED" + + if "failed" in job.obj["status"]: failed_cnt = job.obj["status"]["failed"] self.__logger.debug("Kubernetes job " + self.uu_name + " status.failed: " + str(failed_cnt)) - if (failed_cnt > self.max_retrials): + if self.print_pod_logs_on_exit: + self.__print_pod_logs() + if failed_cnt > self.max_retrials: job.scale(replicas=0) # avoid more retrials - return "failed" - return "running" + return "FAILED" + return "RUNNING" + + def __delete_job_cascade(self, job): + delete_options_cascade = { + "kind": "DeleteOptions", + "apiVersion": "v1", + "propagationPolicy": "Background" + } + r = self.__kube_api.delete(json=delete_options_cascade, **job.api_kwargs()) + if 
r.status_code != 200: + self.__kube_api.raise_for_status(r) def run(self): self._init_kubernetes() @@ -217,6 +329,8 @@ def run(self): } }, "spec": { + "activeDeadlineSeconds": self.active_deadline_seconds, + "backoffLimit": self.backoff_limit, "template": { "metadata": { "name": self.uu_name @@ -228,7 +342,7 @@ def run(self): # Update user labels job_json['metadata']['labels'].update(self.labels) # Add default restartPolicy if not specified - if ("restartPolicy" not in self.spec_schema): + if "restartPolicy" not in self.spec_schema: job_json["spec"]["template"]["spec"]["restartPolicy"] = "Never" # Submit job self.__logger.info("Submitting Kubernetes Job: " + self.uu_name) diff --git a/luigi/contrib/mongodb.py b/luigi/contrib/mongodb.py index 3c60dbd829..7fa44cca80 100644 --- a/luigi/contrib/mongodb.py +++ b/luigi/contrib/mongodb.py @@ -166,7 +166,7 @@ def get_empty_ids(self): {'_id': True} ) - return set(self._document_ids) - set([doc['_id'] for doc in cursor]) + return set(self._document_ids) - {doc['_id'] for doc in cursor} class MongoCollectionTarget(MongoTarget): diff --git a/luigi/contrib/opener.py b/luigi/contrib/opener.py index 9640c55164..7607583c52 100644 --- a/luigi/contrib/opener.py +++ b/luigi/contrib/opener.py @@ -69,7 +69,7 @@ class InvalidQuery(OpenerError): class OpenerRegistry(object): - def __init__(self, openers=[]): + def __init__(self, openers=None): """An opener registry that stores a number of opener objects used to parse Target URIs @@ -77,6 +77,9 @@ def __init__(self, openers=[]): :type openers: list """ + if openers is None: + openers = [] + self.registry = {} self.openers = {} self.default_opener = 'file' diff --git a/luigi/contrib/redshift.py b/luigi/contrib/redshift.py index 2d3a80c439..d74468336d 100644 --- a/luigi/contrib/redshift.py +++ b/luigi/contrib/redshift.py @@ -267,6 +267,16 @@ def prune(self, connection): finally: cursor.close() + def create_schema(self, connection): + """ + Will create the schema in the database + """ + if '.' not in self.table: + return + + query = 'CREATE SCHEMA IF NOT EXISTS {schema_name};'.format(schema_name=self.table.split('.')[0]) + connection.cursor().execute(query) + def create_table(self, connection): """ Override to provide code for creating the target table. @@ -298,6 +308,24 @@ def create_table(self, connection): coldefs=coldefs, table_attributes=self.table_attributes) + connection.cursor().execute(query) + elif len(self.columns[0]) == 3: + # if columns is specified as (name, type, encoding) tuples + # possible column encodings: https://docs.aws.amazon.com/redshift/latest/dg/c_Compression_encodings.html + coldefs = ','.join( + '{name} {type} ENCODE {encoding}'.format( + name=name, + type=type, + encoding=encoding) for name, type, encoding in self.columns + ) + query = ("CREATE {type} TABLE " + "{table} ({coldefs}) " + "{table_attributes}").format( + type=self.table_type, + table=self.table, + coldefs=coldefs, + table_attributes=self.table_attributes) + connection.cursor().execute(query) else: raise ValueError("create_table() found no columns for %r" @@ -335,7 +363,7 @@ def copy(self, cursor, f): """ logger.info("Inserting file: %s", f) colnames = '' - if len(self.columns) > 0: + if self.columns and len(self.columns) > 0: colnames = ",".join([x[0] for x in self.columns]) colnames = '({})'.format(colnames) @@ -365,6 +393,27 @@ def output(self): table=self.table, update_id=self.update_id) + def does_schema_exist(self, connection): + """ + Determine whether the schema already exists. + """ + + if '.' 
in self.table: + query = ("select 1 as schema_exists " + "from pg_namespace " + "where nspname = lower(%s) limit 1") + else: + return True + + cursor = connection.cursor() + try: + schema = self.table.split('.')[0] + cursor.execute(query, [schema]) + result = cursor.fetchone() + return bool(result) + finally: + cursor.close() + def does_table_exist(self, connection): """ Determine whether the table already exists. @@ -390,9 +439,12 @@ def init_copy(self, connection): """ Perform pre-copy sql - such as creating table, truncating, or removing data older than x. """ + if not self.does_schema_exist(connection): + logger.info("Creating schema for %s", self.table) + self.create_schema(connection) + if not self.does_table_exist(connection): logger.info("Creating table %s", self.table) - connection.reset() self.create_table(connection) if self.do_truncate_table: @@ -687,12 +739,10 @@ def run(self): credentials=self._credentials()) logger.info('Executing unload query from task: {name}'.format(name=self.__class__)) - try: - cursor = connection.cursor() - cursor.execute(unload_query) - logger.info(cursor.statusmessage) - except: - raise + + cursor = connection.cursor() + cursor.execute(unload_query) + logger.info(cursor.statusmessage) # Update marker table self.output().touch(connection) diff --git a/luigi/contrib/s3.py b/luigi/contrib/s3.py index 890b25abe0..88be62cd75 100644 --- a/luigi/contrib/s3.py +++ b/luigi/contrib/s3.py @@ -17,25 +17,25 @@ """ Implementation of Simple Storage Service support. :py:class:`S3Target` is a subclass of the Target class to support S3 file -system operations. The `boto` library is required to use S3 targets. +system operations. The `boto3` library is required to use S3 targets. """ from __future__ import division import datetime +import io import itertools import logging import os import os.path +import warnings -import time -from multiprocessing.pool import ThreadPool +import botocore try: from urlparse import urlsplit except ImportError: from urllib.parse import urlsplit -import warnings try: from ConfigParser import NoSectionError @@ -43,11 +43,10 @@ from configparser import NoSectionError from luigi import six -from luigi.six.moves import range from luigi import configuration from luigi.format import get_default_format -from luigi.parameter import Parameter +from luigi.parameter import OptionalParameter, Parameter from luigi.target import FileAlreadyExists, FileSystem, FileSystemException, FileSystemTarget, AtomicLocalFile, MissingParentDirectory from luigi.task import ExternalTask @@ -68,16 +67,34 @@ class FileNotFoundException(FileSystemException): pass +class DeprecatedBotoClientException(Exception): + pass + + +class _StreamingBodyAdaptor(io.IOBase): + """ + Adapter class wrapping botocore's StreamingBody to make a file like iterable + """ + + def __init__(self, streaming_body): + self.streaming_body = streaming_body + + def read(self, size): + return self.streaming_body.read(size) + + def close(self): + return self.streaming_body.close() + + class S3Client(FileSystem): """ - boto-powered S3 client. + boto3-powered S3 client. 
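    Explicit keys remain optional; when they are omitted, boto3's default credential chain is used.
    A minimal sketch (bucket and key names are hypothetical)::

        from luigi.contrib import s3

        client = s3.S3Client()   # or S3Client(aws_access_key_id='...', aws_secret_access_key='...')
        target = s3.S3Target('s3://my-bucket/some/key.csv', client=client)
        with target.open('w') as fh:
            fh.write('hello\n')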
""" _s3 = None def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, **kwargs): - from boto.s3.key import Key options = self._get_s3_config() options.update(kwargs) if aws_access_key_id: @@ -85,14 +102,12 @@ def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, if aws_secret_access_key: options['aws_secret_access_key'] = aws_secret_access_key - self.Key = Key self._options = options @property def s3(self): - # only import boto when needed to allow top-lvl s3 module import - import boto - import boto.s3.connection + # only import boto3 when needed to allow top-lvl s3 module import + import boto3 options = dict(self._options) @@ -109,21 +124,42 @@ def s3(self): aws_session_token = None if role_arn and role_session_name: - from boto import sts - - sts_client = sts.STSConnection() - assumed_role = sts_client.assume_role(role_arn, role_session_name) - aws_secret_access_key = assumed_role.credentials.secret_key - aws_access_key_id = assumed_role.credentials.access_key - aws_session_token = assumed_role.credentials.session_token - - for key in ['aws_access_key_id', 'aws_secret_access_key', 'aws_role_session_name', 'aws_role_arn']: + sts_client = boto3.client('sts') + assumed_role = sts_client.assume_role(RoleArn=role_arn, + RoleSessionName=role_session_name) + aws_secret_access_key = assumed_role['Credentials'].get( + 'SecretAccessKey') + aws_access_key_id = assumed_role['Credentials'].get('AccessKeyId') + aws_session_token = assumed_role['Credentials'].get('SessionToken') + logger.debug('using aws credentials via assumed role {} as defined in luigi config' + .format(role_session_name)) + + for key in ['aws_access_key_id', 'aws_secret_access_key', + 'aws_role_session_name', 'aws_role_arn']: if key in options: options.pop(key) - self._s3 = boto.s3.connection.S3Connection(aws_access_key_id, - aws_secret_access_key, - security_token=aws_session_token, - **options) + + # At this stage, if no credentials provided, boto3 would handle their resolution for us + # For finding out about the order in which it tries to find these credentials + # please see here details + # http://boto3.readthedocs.io/en/latest/guide/configuration.html#configuring-credentials + + if not (aws_access_key_id and aws_secret_access_key): + logger.debug('no credentials provided, delegating credentials resolution to boto3') + + try: + self._s3 = boto3.resource('s3', + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + **options) + except TypeError as e: + logger.error(e.message) + if 'got an unexpected keyword argument' in e.message: + raise DeprecatedBotoClientException( + "Now using boto3. 
Check that you're passing the correct arguments") + raise + return self._s3 @s3.setter @@ -136,16 +172,12 @@ def exists(self, path): """ (bucket, key) = self._path_to_bucket_and_key(path) - # grab and validate the bucket - s3_bucket = self.s3.get_bucket(bucket, validate=True) - # root always exists if self._is_root(key): return True # file - s3_key = s3_bucket.get_key(key) - if s3_key: + if self._exists(bucket, key): return True if self.isdir(path): @@ -163,236 +195,153 @@ def remove(self, path, recursive=True): return False (bucket, key) = self._path_to_bucket_and_key(path) - + s3_bucket = self.s3.Bucket(bucket) # root if self._is_root(key): raise InvalidDeleteException('Cannot delete root of bucket at path %s' % path) - # grab and validate the bucket - s3_bucket = self.s3.get_bucket(bucket, validate=True) - # file - s3_key = s3_bucket.get_key(key) - if s3_key: - s3_bucket.delete_key(s3_key) + if self._exists(bucket, key): + self.s3.meta.client.delete_object(Bucket=bucket, Key=key) logger.debug('Deleting %s from bucket %s', key, bucket) return True if self.isdir(path) and not recursive: raise InvalidDeleteException('Path %s is a directory. Must use recursive delete' % path) - delete_key_list = [ - k for k in s3_bucket.list(self._add_path_delimiter(key))] + delete_key_list = [{'Key': obj.key} for obj in s3_bucket.objects.filter(Prefix=self._add_path_delimiter(key))] # delete the directory marker file if it exists - s3_dir_with_suffix_key = s3_bucket.get_key(key + S3_DIRECTORY_MARKER_SUFFIX_0) - if s3_dir_with_suffix_key: - delete_key_list.append(s3_dir_with_suffix_key) + if self._exists(bucket, '{}{}'.format(key, S3_DIRECTORY_MARKER_SUFFIX_0)): + delete_key_list.append({'Key': '{}{}'.format(key, S3_DIRECTORY_MARKER_SUFFIX_0)}) if len(delete_key_list) > 0: - for k in delete_key_list: - logger.debug('Deleting %s from bucket %s', k, bucket) - s3_bucket.delete_keys(delete_key_list) + self.s3.meta.client.delete_objects(Bucket=bucket, Delete={'Objects': delete_key_list}) return True return False - def get_key(self, path): + def move(self, source_path, destination_path, **kwargs): """ - Returns just the key from the path. - - An s3 path is composed of a bucket and a key. + Rename/move an object from one S3 location to another. + :param kwargs: Keyword arguments are passed to the boto3 function `copy` + """ + self.copy(source_path, destination_path, **kwargs) + self.remove(source_path) - Suppose we have a path `s3://my_bucket/some/files/my_file`. The key is `some/files/my_file`. + def get_key(self, path): + """ + Returns the object summary at the path """ (bucket, key) = self._path_to_bucket_and_key(path) - s3_bucket = self.s3.get_bucket(bucket, validate=True) - - return s3_bucket.get_key(key) + if self._exists(bucket, key): + return self.s3.ObjectSummary(bucket, key) def put(self, local_path, destination_s3_path, **kwargs): """ Put an object stored locally to an S3 path. - :param kwargs: Keyword arguments are passed to the boto function `set_contents_from_filename` + :param kwargs: Keyword arguments are passed to the boto function `put_object` """ - (bucket, key) = self._path_to_bucket_and_key(destination_s3_path) - - # grab and validate the bucket - s3_bucket = self.s3.get_bucket(bucket, validate=True) + if 'encrypt_key' in kwargs: + raise DeprecatedBotoClientException( + 'encrypt_key deprecated in boto3. 
Please refer to boto3 documentation for encryption details.') # put the file - s3_key = self.Key(s3_bucket) - s3_key.key = key - s3_key.set_contents_from_filename(local_path, **kwargs) + self.put_multipart(local_path, destination_s3_path, **kwargs) def put_string(self, content, destination_s3_path, **kwargs): """ Put a string to an S3 path. - - :param kwargs: Keyword arguments are passed to the boto function `set_contents_from_string` + :param kwargs: Keyword arguments are passed to the boto3 function `put_object` """ + if 'encrypt_key' in kwargs: + raise DeprecatedBotoClientException( + 'encrypt_key deprecated in boto3. Please refer to boto3 documentation for encryption details.') (bucket, key) = self._path_to_bucket_and_key(destination_s3_path) - # grab and validate the bucket - s3_bucket = self.s3.get_bucket(bucket, validate=True) - # put the content - s3_key = self.Key(s3_bucket) - s3_key.key = key - s3_key.set_contents_from_string(content, **kwargs) + # validate the bucket + self._validate_bucket(bucket) - def put_multipart(self, local_path, destination_s3_path, part_size=67108864, **kwargs): + # put the file + self.s3.meta.client.put_object( + Key=key, Bucket=bucket, Body=content, **kwargs) + + def put_multipart(self, local_path, destination_s3_path, part_size=8388608, **kwargs): """ Put an object stored locally to an S3 path - using S3 multi-part upload (for files > 5GB). - + using S3 multi-part upload (for files > 8Mb). :param local_path: Path to source local file :param destination_s3_path: URL for target S3 location - :param part_size: Part size in bytes. Default: 67108864 (64MB), must be >= 5MB and <= 5 GB. - :param kwargs: Keyword arguments are passed to the boto function `initiate_multipart_upload` + :param part_size: Part size in bytes. Default: 8388608 (8MB) + :param kwargs: Keyword arguments are passed to the boto function `upload_fileobj` as ExtraArgs """ - # calculate number of parts to upload - # based on the size of the file - source_size = os.stat(local_path).st_size + if 'encrypt_key' in kwargs: + raise DeprecatedBotoClientException( + 'encrypt_key deprecated in boto3. Please refer to boto3 documentation for encryption details.') - if source_size <= part_size: - # fallback to standard, non-multipart strategy - return self.put(local_path, destination_s3_path, **kwargs) + import boto3 + # default part size for boto3 is 8Mb, changing it to fit part_size + # provided as a parameter + transfer_config = boto3.s3.transfer.TransferConfig( + multipart_chunksize=part_size) (bucket, key) = self._path_to_bucket_and_key(destination_s3_path) - # grab and validate the bucket - s3_bucket = self.s3.get_bucket(bucket, validate=True) - - # calculate the number of parts (int division). 
- # use modulo to avoid float precision issues - # for exactly-sized fits - num_parts = (source_size + part_size - 1) // part_size - - mp = None - try: - mp = s3_bucket.initiate_multipart_upload(key, **kwargs) - - for i in range(num_parts): - # upload a part at a time to S3 - offset = part_size * i - bytes = min(part_size, source_size - offset) - with open(local_path, 'rb') as fp: - part_num = i + 1 - logger.info('Uploading part %s/%s to %s', part_num, num_parts, destination_s3_path) - fp.seek(offset) - mp.upload_part_from_file(fp, part_num=part_num, size=bytes) - - # finish the upload, making the file available in S3 - mp.complete_upload() - except BaseException: - if mp: - logger.info('Canceling multipart s3 upload for %s', destination_s3_path) - # cancel the upload so we don't get charged for - # storage consumed by uploaded parts - mp.cancel_upload() - raise - - def get(self, s3_path, destination_local_path): - """ - Get an object stored in S3 and write it to a local path. - """ - (bucket, key) = self._path_to_bucket_and_key(s3_path) - - # grab and validate the bucket - s3_bucket = self.s3.get_bucket(bucket, validate=True) - - # download the file - s3_key = self.Key(s3_bucket) - s3_key.key = key - s3_key.get_contents_to_filename(destination_local_path) - - def get_as_string(self, s3_path): - """ - Get the contents of an object stored in S3 as a string. - """ - (bucket, key) = self._path_to_bucket_and_key(s3_path) - - # grab and validate the bucket - s3_bucket = self.s3.get_bucket(bucket, validate=True) + # validate the bucket + self._validate_bucket(bucket) - # get the content - s3_key = self.Key(s3_bucket) - s3_key.key = key - contents = s3_key.get_contents_as_string() + self.s3.meta.client.upload_fileobj( + Fileobj=open(local_path, 'rb'), Bucket=bucket, Key=key, Config=transfer_config, ExtraArgs=kwargs) - return contents - - def copy(self, source_path, destination_path, threads=100, start_time=None, end_time=None, part_size=67108864, **kwargs): + def copy(self, source_path, destination_path, threads=100, start_time=None, end_time=None, part_size=8388608, **kwargs): """ Copy object(s) from one S3 location to another. Works for individual keys or entire directories. - When files are larger than `part_size`, multipart uploading will be used. - :param source_path: The `s3://` path of the directory or key to copy from :param destination_path: The `s3://` path of the directory or key to copy to :param threads: Optional argument to define the number of threads to use when copying (min: 3 threads) :param start_time: Optional argument to copy files with modified dates after start_time :param end_time: Optional argument to copy files with modified dates before end_time - :param part_size: Part size in bytes. Default: 67108864 (64MB), must be >= 5MB and <= 5 GB. - :param kwargs: Keyword arguments are passed to the boto function `copy_key` - + :param part_size: Part size in bytes + :param kwargs: Keyword arguments are passed to the boto function `copy` as ExtraArgs :returns tuple (number_of_files_copied, total_size_copied_in_bytes) """ + start = datetime.datetime.now() (src_bucket, src_key) = self._path_to_bucket_and_key(source_path) (dst_bucket, dst_key) = self._path_to_bucket_and_key(destination_path) - # As the S3 copy command is completely server side, there is no issue with issuing a lot of threads - # to issue a single API call per copy, however, this may in theory cause issues on systems with low ulimits for - # number of threads when copying really large files (e.g. 
with a ~100GB file this will open ~1500 - # threads), or large directories. Around 100 threads seems to work well. + # don't allow threads to be less than 3 + threads = 3 if threads < 3 else threads + import boto3 - threads = 3 if threads < 3 else threads # don't allow threads to be less than 3 + transfer_config = boto3.s3.transfer.TransferConfig( + max_concurrency=threads, multipart_chunksize=part_size) total_keys = 0 - copy_pool = ThreadPool(processes=threads) - if self.isdir(source_path): - # The management pool is to ensure that there's no deadlock between the s3 copying threads, and the - # multipart_copy threads that monitors each group of s3 copy threads and returns a success once the entire file - # is copied. Without this, we could potentially fill up the pool with threads waiting to check if the s3 copies - # have completed, leaving no available threads to actually perform any copying. - copy_jobs = [] - management_pool = ThreadPool(processes=threads) - (bucket, key) = self._path_to_bucket_and_key(source_path) key_path = self._add_path_delimiter(key) key_path_len = len(key_path) - - total_size_bytes = 0 src_prefix = self._add_path_delimiter(src_key) dst_prefix = self._add_path_delimiter(dst_key) + total_size_bytes = 0 for item in self.list(source_path, start_time=start_time, end_time=end_time, return_key=True): path = item.key[key_path_len:] # prevents copy attempt of empty key in folder if path != '' and path != '/': total_keys += 1 total_size_bytes += item.size - job = management_pool.apply_async(self.__copy_multipart, - args=(copy_pool, - src_bucket, src_prefix + path, - dst_bucket, dst_prefix + path, - part_size), - kwds=kwargs) - copy_jobs.append(job) - - # Wait for the pools to finish scheduling all the copies - management_pool.close() - management_pool.join() - copy_pool.close() - copy_pool.join() - - # Raise any errors encountered in any of the copy processes - for result in copy_jobs: - result.get() + copy_source = { + 'Bucket': src_bucket, + 'Key': src_prefix + path + } + + self.s3.meta.client.copy( + copy_source, dst_bucket, dst_prefix + path, Config=transfer_config, ExtraArgs=kwargs) end = datetime.datetime.now() duration = end - start @@ -403,122 +352,31 @@ def copy(self, source_path, destination_path, threads=100, start_time=None, end_ # If the file isn't a directory just perform a simple copy else: - self.__copy_multipart(copy_pool, src_bucket, src_key, dst_bucket, dst_key, part_size, **kwargs) - # Close the pool - copy_pool.close() - copy_pool.join() - - def __copy_multipart(self, pool, src_bucket, src_key, dst_bucket, dst_key, part_size, **kwargs): - """ - Copy a single S3 object to another S3 object, falling back to multipart copy where necessary - - NOTE: This is a private method and should only be called from within the `s3.copy` method - - :param pool: The threadpool to put the s3 copy processes onto - :param src_bucket: source bucket name - :param src_key: source key name - :param dst_bucket: destination bucket name - :param dst_key: destination key name - :param key_size: size of the key to copy in bytes - :param part_size: Part size in bytes. Must be >= 5MB and <= 5 GB. 
- :param kwargs: Keyword arguments are passed to the boto function `initiate_multipart_upload` - """ + copy_source = { + 'Bucket': src_bucket, + 'Key': src_key + } + self.s3.meta.client.copy( + copy_source, dst_bucket, dst_key, Config=transfer_config, ExtraArgs=kwargs) - source_bucket = self.s3.get_bucket(src_bucket, validate=True) - dest_bucket = self.s3.get_bucket(dst_bucket, validate=True) - - key_size = source_bucket.lookup(src_key).size - - # We can't do a multipart copy on an empty Key, so handle this specially. - # Also, don't bother using the multipart machinery if we're only dealing with a small non-multipart file - if key_size == 0 or key_size <= part_size: - result = pool.apply_async(dest_bucket.copy_key, args=(dst_key, src_bucket, src_key), kwds=kwargs) - # Bubble up any errors we may encounter - return result.get() - - mp = None - - try: - mp = dest_bucket.initiate_multipart_upload(dst_key, **kwargs) - cur_pos = 0 - - # Store the results from the apply_async in a list so we can check for failures - results = [] - - # Calculate the number of chunks the file will be - num_parts = (key_size + part_size - 1) // part_size - - for i in range(num_parts): - # Issue an S3 copy request, one part at a time, from one S3 object to another - part_start = cur_pos - cur_pos += part_size - part_end = min(cur_pos - 1, key_size - 1) - part_num = i + 1 - results.append(pool.apply_async(mp.copy_part_from_key, args=(src_bucket, src_key, part_num, part_start, part_end))) - logger.info('Requesting copy of %s/%s to %s/%s', part_num, num_parts, dst_bucket, dst_key) - - logger.info('Waiting for multipart copy of %s/%s to finish', dst_bucket, dst_key) - - # This will raise any exceptions in any of the copy threads - for result in results: - result.get() - - # finish the copy, making the file available in S3 - mp.complete_upload() - return mp.key_name - - except: - logger.info('Error during multipart s3 copy for %s/%s to %s/%s...', src_bucket, src_key, dst_bucket, dst_key) - # cancel the copy so we don't get charged for storage consumed by copied parts - if mp: - mp.cancel_upload() - raise - - def move(self, source_path, destination_path, **kwargs): + def get(self, s3_path, destination_local_path): """ - Rename/move an object from one S3 location to another. - - :param kwargs: Keyword arguments are passed to the boto function `copy_key` + Get an object stored in S3 and write it to a local path. """ - self.copy(source_path, destination_path, **kwargs) - self.remove(source_path) + (bucket, key) = self._path_to_bucket_and_key(s3_path) + # download the file + self.s3.meta.client.download_file(bucket, key, destination_local_path) - def listdir(self, path, start_time=None, end_time=None, return_key=False): + def get_as_string(self, s3_path): """ - Get an iterable with S3 folder contents. - Iterable contains paths relative to queried path. - - :param start_time: Optional argument to list files with modified dates after start_time - :param end_time: Optional argument to list files with modified dates before end_time - :param return_key: Optional argument, when set to True will return a boto.s3.key.Key (instead of the filename) + Get the contents of an object stored in S3 as a string. 
""" - (bucket, key) = self._path_to_bucket_and_key(path) - - # grab and validate the bucket - s3_bucket = self.s3.get_bucket(bucket, validate=True) - - key_path = self._add_path_delimiter(key) - key_path_len = len(key_path) - for item in s3_bucket.list(prefix=key_path): - last_modified_date = time.strptime(item.last_modified, "%Y-%m-%dT%H:%M:%S.%fZ") - if ( - (not start_time and not end_time) or # neither are defined, list all - (start_time and not end_time and start_time < last_modified_date) or # start defined, after start - (not start_time and end_time and last_modified_date < end_time) or # end defined, prior to end - (start_time and end_time and start_time < last_modified_date < end_time) # both defined, between - ): - if return_key: - yield item - else: - yield self._add_path_delimiter(path) + item.key[key_path_len:] + (bucket, key) = self._path_to_bucket_and_key(s3_path) - def list(self, path, start_time=None, end_time=None, return_key=False): # backwards compat - key_path_len = len(self._add_path_delimiter(path)) - for item in self.listdir(path, start_time=start_time, end_time=end_time, return_key=return_key): - if return_key: - yield item - else: - yield item[key_path_len:] + # get the content + obj = self.s3.Object(bucket, key) + contents = obj.get()['Body'].read().decode('utf-8') + return contents def isdir(self, path): """ @@ -526,8 +384,7 @@ def isdir(self, path): """ (bucket, key) = self._path_to_bucket_and_key(path) - # grab and validate the bucket - s3_bucket = self.s3.get_bucket(bucket, validate=True) + s3_bucket = self.s3.Bucket(bucket) # root is a directory if self._is_root(key): @@ -535,13 +392,19 @@ def isdir(self, path): for suffix in (S3_DIRECTORY_MARKER_SUFFIX_0, S3_DIRECTORY_MARKER_SUFFIX_1): - s3_dir_with_suffix_key = s3_bucket.get_key(key + suffix) - if s3_dir_with_suffix_key: + try: + self.s3.meta.client.get_object( + Bucket=bucket, Key=key + suffix) + except botocore.exceptions.ClientError as e: + if not e.response['Error']['Code'] in ['NoSuchKey', '404']: + raise + else: return True # files with this prefix key_path = self._add_path_delimiter(key) - s3_bucket_list_result = list(itertools.islice(s3_bucket.list(prefix=key_path), 1)) + s3_bucket_list_result = list(itertools.islice( + s3_bucket.objects.filter(Prefix=key_path), 1)) if s3_bucket_list_result: return True @@ -553,16 +416,57 @@ def mkdir(self, path, parents=True, raise_if_exists=False): if raise_if_exists and self.isdir(path): raise FileAlreadyExists() - _, key = self._path_to_bucket_and_key(path) + bucket, key = self._path_to_bucket_and_key(path) if self._is_root(key): - return # isdir raises if the bucket doesn't exist; nothing to do here. + # isdir raises if the bucket doesn't exist; nothing to do here. + return - key = self._add_path_delimiter(key) + path = self._add_path_delimiter(path) - if not parents and not self.isdir(os.path.dirname(key)): + if not parents and not self.isdir(os.path.dirname(path)): raise MissingParentDirectory() - return self.put_string("", self._add_path_delimiter(path)) + return self.put_string("", path) + + def listdir(self, path, start_time=None, end_time=None, return_key=False): + """ + Get an iterable with S3 folder contents. + Iterable contains paths relative to queried path. 
+ :param start_time: Optional argument to list files with modified (offset aware) datetime after start_time + :param end_time: Optional argument to list files with modified (offset aware) datetime before end_time + :param return_key: Optional argument, when set to True will return boto3's ObjectSummary (instead of the filename) + """ + (bucket, key) = self._path_to_bucket_and_key(path) + + # grab and validate the bucket + s3_bucket = self.s3.Bucket(bucket) + + key_path = self._add_path_delimiter(key) + key_path_len = len(key_path) + for item in s3_bucket.objects.filter(Prefix=key_path): + last_modified_date = item.last_modified + if ( + # neither are defined, list all + (not start_time and not end_time) or + # start defined, after start + (start_time and not end_time and start_time < last_modified_date) or + # end defined, prior to end + (not start_time and end_time and last_modified_date < end_time) or + (start_time and end_time and start_time < + last_modified_date < end_time) # both defined, between + ): + if return_key: + yield item + else: + yield self._add_path_delimiter(path) + item.key[key_path_len:] + + def list(self, path, start_time=None, end_time=None, return_key=False): # backwards compat + key_path_len = len(self._add_path_delimiter(path)) + for item in self.listdir(path, start_time=start_time, end_time=end_time, return_key=return_key): + if return_key: + yield item + else: + yield item[key_path_len:] def _get_s3_config(self, key=None): defaults = dict(configuration.get_config().defaults()) @@ -579,6 +483,7 @@ def _get_s3_config(self, key=None): if key: return config.get(key) section_only = {k: v for k, v in config.items() if k not in defaults or v != defaults[k]} + return section_only def _path_to_bucket_and_key(self, path): @@ -592,6 +497,33 @@ def _is_root(self, key): def _add_path_delimiter(self, key): return key if key[-1:] == '/' or key == '' else key + '/' + def _validate_bucket(self, bucket_name): + exists = True + + try: + self.s3.meta.client.head_bucket(Bucket=bucket_name) + except botocore.exceptions.ClientError as e: + error_code = e.response['Error']['Code'] + if error_code in ('404', 'NoSuchBucket'): + exists = False + else: + raise + return exists + + def _exists(self, bucket, key): + s3_key = False + try: + self.s3.Object(bucket, key).load() + except botocore.exceptions.ClientError as e: + if e.response['Error']['Code'] in ['NoSuchKey', '404']: + s3_key = False + else: + raise + else: + s3_key = True + if s3_key: + return True + class AtomicS3File(AtomicLocalFile): """ @@ -606,27 +538,19 @@ def __init__(self, path, s3_client, **kwargs): self.s3_options = kwargs def move_to_final_destination(self): - self.s3_client.put_multipart(self.tmp_path, self.path, **self.s3_options) + self.s3_client.put_multipart( + self.tmp_path, self.path, **self.s3_options) class ReadableS3File(object): def __init__(self, s3_key): - self.s3_key = s3_key + self.s3_key = _StreamingBodyAdaptor(s3_key.get()['Body']) self.buffer = [] self.closed = False self.finished = False - def read(self, size=0): - f = self.s3_key.read(size=size) - - # boto will loop on the key forever and it's not what is expected by - # the python io interface - # boto/boto#2805 - if f == b'': - self.finished = True - if self.finished: - return b'' - + def read(self, size=None): + f = self.s3_key.read(size) return f def close(self): @@ -716,7 +640,8 @@ def open(self, mode='r'): if mode == 'r': s3_key = self.fs.get_key(self.path) if not s3_key: - raise FileNotFoundException("Could not find file at %s" % self.path) + 
raise FileNotFoundException( + "Could not find file at %s" % self.path) fileobj = ReadableS3File(s3_key) return self.format.pipe_reader(fileobj) @@ -752,8 +677,6 @@ def __init__(self, path, format=None, client=None, flag='_SUCCESS'): :param path: the directory where the files are stored. :type path: str - :param format: see the luigi.format module for options - :type format: luigi.format.[Text|UTF8|Nop] :param client: :type client: :param flag: @@ -808,7 +731,7 @@ class S3FlagTask(ExternalTask): An external task that requires the existence of EMR output in S3. """ path = Parameter() - flag = Parameter(default=None) + flag = OptionalParameter(default=None) def output(self): return S3FlagTarget(self.path, flag=self.flag) diff --git a/luigi/contrib/sqla.py b/luigi/contrib/sqla.py index 72e35d3c16..ed0f62a6e3 100644 --- a/luigi/contrib/sqla.py +++ b/luigi/contrib/sqla.py @@ -163,7 +163,7 @@ class SQLAlchemyTarget(luigi.Target): _engine_dict = {} # dict of sqlalchemy engine instances Connection = collections.namedtuple("Connection", "engine pid") - def __init__(self, connection_string, target_table, update_id, echo=False, connect_args={}): + def __init__(self, connection_string, target_table, update_id, echo=False, connect_args=None): """ Constructor for the SQLAlchemyTarget. @@ -179,6 +179,9 @@ def __init__(self, connection_string, target_table, update_id, echo=False, conne :type connect_args: dict :return: """ + if connect_args is None: + connect_args = {} + self.target_table = target_table self.update_id = update_id self.connection_string = connection_string diff --git a/luigi/db_task_history.py b/luigi/db_task_history.py index 3f175c25c9..ed313a952a 100644 --- a/luigi/db_task_history.py +++ b/luigi/db_task_history.py @@ -70,7 +70,7 @@ def _session(self, session=None): session = self.session_factory() try: yield session - except: + except BaseException: session.rollback() raise else: @@ -196,7 +196,7 @@ class TaskParameter(Base): __tablename__ = 'task_parameters' task_id = sqlalchemy.Column(sqlalchemy.Integer, sqlalchemy.ForeignKey('tasks.id'), primary_key=True) name = sqlalchemy.Column(sqlalchemy.String(128), primary_key=True) - value = sqlalchemy.Column(sqlalchemy.String(256)) + value = sqlalchemy.Column(sqlalchemy.Text()) def __repr__(self): return "TaskParameter(task_id=%d, name=%s, value=%s)" % (self.task_id, self.name, self.value) @@ -247,10 +247,35 @@ def _upgrade_schema(engine): :param engine: SQLAlchemy engine of the underlying database. """ inspector = reflection.Inspector.from_engine(engine) - conn = engine.connect() - - # Upgrade 1. Add task_id column and index to tasks - if 'task_id' not in [x['name'] for x in inspector.get_columns('tasks')]: - logger.warn('Upgrading DbTaskHistory schema: Adding tasks.task_id') - conn.execute('ALTER TABLE tasks ADD COLUMN task_id VARCHAR(200)') - conn.execute('CREATE INDEX ix_task_id ON tasks (task_id)') + with engine.connect() as conn: + + # Upgrade 1. Add task_id column and index to tasks + if 'task_id' not in [x['name'] for x in inspector.get_columns('tasks')]: + logger.warning('Upgrading DbTaskHistory schema: Adding tasks.task_id') + conn.execute('ALTER TABLE tasks ADD COLUMN task_id VARCHAR(200)') + conn.execute('CREATE INDEX ix_task_id ON tasks (task_id)') + + # Upgrade 2. 
Alter value column to be TEXT, note that this is idempotent so no if-guard + if 'mysql' in engine.dialect.name: + conn.execute('ALTER TABLE task_parameters MODIFY COLUMN value TEXT') + elif 'oracle' in engine.dialect.name: + conn.execute('ALTER TABLE task_parameters MODIFY value TEXT') + elif 'mssql' in engine.dialect.name: + conn.execute('ALTER TABLE task_parameters ALTER COLUMN value TEXT') + elif 'postgresql' in engine.dialect.name: + conn.execute('ALTER TABLE task_parameters ALTER COLUMN value TYPE TEXT') + elif 'sqlite' in engine.dialect.name: + # SQLite does not support changing column types. A database file will need + # to be used to pickup this migration change. + for i in conn.execute('PRAGMA table_info(task_parameters);').fetchall(): + if i['name'] == 'value' and i['type'] != 'TEXT': + logger.warning( + 'SQLite can not change column types. Please use a new database ' + 'to pickup column type changes.' + ) + else: + logger.warning( + 'SQLAlcheny dialect {} could not be migrated to the TEXT type'.format( + engine.dialect + ) + ) diff --git a/luigi/execution_summary.py b/luigi/execution_summary.py index 6a5ecf0fa3..886eb067d4 100644 --- a/luigi/execution_summary.py +++ b/luigi/execution_summary.py @@ -275,8 +275,8 @@ def _get_comments(group_tasks): "not_run", ) _PENDING_SUB_STATUSES = set(_ORDERED_STATUSES[_ORDERED_STATUSES.index("still_pending_ext"):]) -_COMMENTS = set(( - ("already_done", 'present dependencies were encountered'), +_COMMENTS = { + ("already_done", 'complete ones were encountered'), ("completed", 'ran successfully'), ("failed", 'failed'), ("scheduling_error", 'failed scheduling'), @@ -284,11 +284,11 @@ def _get_comments(group_tasks): ("still_pending_ext", 'were missing external dependencies'), ("run_by_other_worker", 'were being run by another worker'), ("upstream_failure", 'had failed dependencies'), - ("upstream_missing_dependency", 'had missing external dependencies'), + ("upstream_missing_dependency", 'had missing dependencies'), ("upstream_run_by_other_worker", 'had dependencies that were being run by other worker'), ("upstream_scheduling_error", 'had dependencies whose scheduling failed'), ("not_run", 'was not granted run permission by the scheduler'), -)) +} def _get_run_by_other_worker(worker): @@ -399,7 +399,7 @@ def _summary_format(set_tasks, worker): reason = "there were missing external dependencies" else: smiley = ":)" - reason = "there were no failed tasks or missing external dependencies" + reason = "there were no failed tasks or missing dependencies" str_output += "\nThis progress looks {0} because {1}".format(smiley, reason) if num_all_tasks == 0: str_output = 'Did not schedule any tasks' diff --git a/luigi/lock.py b/luigi/lock.py index d26a82c517..1b31ed0c90 100644 --- a/luigi/lock.py +++ b/luigi/lock.py @@ -37,7 +37,7 @@ def getpcmd(pid): """ if os.name == "nt": # Use wmic command instead of ps on Windows. 
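
[Editor's aside on the _upgrade_schema hunk in db_task_history.py above: the SQLite branch can only warn, because SQLite cannot alter a column's type in place. A minimal sketch of checking the column type with the same PRAGMA the patch relies on; the in-memory engine and table layout are illustrative assumptions, and SQLAlchemy 1.x engine.execute() semantics (as used in the patch) are assumed.]

    import sqlalchemy

    engine = sqlalchemy.create_engine('sqlite://')  # throwaway in-memory DB, for illustration
    engine.execute('CREATE TABLE task_parameters '
                   '(task_id INTEGER, name VARCHAR(128), value VARCHAR(256))')

    with engine.connect() as conn:
        for row in conn.execute('PRAGMA table_info(task_parameters);').fetchall():
            if row['name'] == 'value':
                # still reports VARCHAR(256): the warning in the patch is the best SQLite can do
                print(row['type'])
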
- cmd = 'wmic path win32_process where ProcessID=%s get Commandline' % (pid, ) + cmd = 'wmic path win32_process where ProcessID=%s get Commandline 2> nul' % (pid, ) with os.popen(cmd, 'r') as p: lines = [line for line in p.readlines() if line.strip("\r\n ") != ""] if lines: diff --git a/luigi/parameter.py b/luigi/parameter.py index e5c5305447..a711156326 100644 --- a/luigi/parameter.py +++ b/luigi/parameter.py @@ -307,6 +307,26 @@ def _parser_action(): return "store" +class OptionalParameter(Parameter): + """ A Parameter that treats empty string as None """ + + def serialize(self, x): + if x is None: + return '' + else: + return str(x) + + def parse(self, x): + return x or None + + def _warn_on_wrong_param_type(self, param_name, param_value): + if self.__class__ != OptionalParameter: + return + if not isinstance(param_value, six.string_types) and param_value is not None: + warnings.warn('OptionalParameter "{}" with value "{}" is not of type string or None.'.format( + param_name, param_value)) + + _UNIX_EPOCH = datetime.datetime.utcfromtimestamp(0) @@ -484,6 +504,12 @@ def serialize(self, dt): return str(dt) return dt.strftime(self.date_format) + @staticmethod + def _convert_to_dt(dt): + if not isinstance(dt, datetime.datetime): + dt = datetime.datetime.combine(dt, datetime.time.min) + return dt + def normalize(self, dt): """ Clamp dt to every Nth :py:attr:`~_DatetimeParameterBase.interval` starting at @@ -492,6 +518,8 @@ def normalize(self, dt): if dt is None: return None + dt = self._convert_to_dt(dt) + dt = dt.replace(microsecond=0) # remove microseconds, to avoid float rounding issues. delta = (dt - self.start).total_seconds() granularity = (self._timedelta * self.interval).total_seconds() diff --git a/luigi/retcodes.py b/luigi/retcodes.py index 1b9c778c41..ef50329559 100644 --- a/luigi/retcodes.py +++ b/luigi/retcodes.py @@ -80,9 +80,10 @@ def run_with_retcodes(argv): logger.exception("Uncaught exception in luigi") sys.exit(retcodes.unhandled_exception) - task_sets = luigi.execution_summary._summary_dict(worker) - root_task = luigi.execution_summary._root_task(worker) - non_empty_categories = {k: v for k, v in task_sets.items() if v}.keys() + with luigi.cmdline_parser.CmdlineParser.global_instance(argv): + task_sets = luigi.execution_summary._summary_dict(worker) + root_task = luigi.execution_summary._root_task(worker) + non_empty_categories = {k: v for k, v in task_sets.items() if v}.keys() def has(status): assert status in luigi.execution_summary._ORDERED_STATUSES diff --git a/luigi/rpc.py b/luigi/rpc.py index 7014eb0c36..a18bd58ded 100644 --- a/luigi/rpc.py +++ b/luigi/rpc.py @@ -19,6 +19,7 @@ rpc.py implements the client side of it, server.py implements the server side. See :doc:`/central_scheduler` for more info. """ +import os import json import logging import socket @@ -80,8 +81,17 @@ def __init__(self, session): from requests import exceptions as requests_exceptions self.raises = requests_exceptions.RequestException self.session = session + self.process_id = os.getpid() + + def check_pid(self): + # if the process id change changed from when the session was created + # a new session needs to be setup since requests isn't multiprocessing safe. 
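
[Editor's aside on the OptionalParameter class added to parameter.py above: a minimal sketch of the intended round-trip between the empty string and None, instantiating the parameter directly (outside a Task) purely for illustration.]

    from luigi.parameter import OptionalParameter

    p = OptionalParameter()
    assert p.parse('') is None        # an empty command-line value means "not set"
    assert p.serialize(None) == ''    # and None serializes back to the empty string
    assert p.parse('x') == 'x'        # ordinary strings pass through unchanged
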
+ if os.getpid() != self.process_id: + self.session = requests.Session() + self.process_id = os.getpid() def fetch(self, full_url, body, timeout): + self.check_pid() resp = self.session.get(full_url, data=body, timeout=timeout) resp.raise_for_status() return resp.text @@ -131,7 +141,8 @@ def _fetch(self, url_suffix, body, log_exceptions=True): except self._fetcher.raises as e: last_exception = e if log_exceptions: - logger.exception("Failed connecting to remote scheduler %r", self._url) + logger.warning("Failed connecting to remote scheduler %r", self._url, + exc_info=True) continue else: raise RPCError( diff --git a/luigi/scheduler.py b/luigi/scheduler.py index 7334c77958..bb49f9df73 100644 --- a/luigi/scheduler.py +++ b/luigi/scheduler.py @@ -38,6 +38,7 @@ import os import re import time +import uuid from luigi import six @@ -147,6 +148,10 @@ class scheduler(Config): prune_on_get_work = parameter.BoolParameter(default=False) + pause_enabled = parameter.BoolParameter(default=True) + + send_messages = parameter.BoolParameter(default=True) + def _get_retry_policy(self): return RetryPolicy(self.retry_count, self.disable_hard_timeout, self.disable_window) @@ -275,8 +280,8 @@ def __eq__(self, other): class Task(object): def __init__(self, task_id, status, deps, resources=None, priority=0, family='', module=None, - params=None, tracking_url=None, status_message=None, progress_percentage=None, retry_policy='notoptional', - public_params=None, hidden_params=None): + params=None, accepts_messages=False, tracking_url=None, status_message=None, + progress_percentage=None, retry_policy='notoptional'): self.id = task_id self.stakeholders = set() # workers ids that are somehow related to this task (i.e. don't prune while any of these workers are still active) self.workers = OrderedSet() # workers ids that can perform task - task is 'BROKEN' if none of these workers are active @@ -297,15 +302,13 @@ def __init__(self, task_id, status, deps, resources=None, priority=0, family='', self.family = family self.module = module self.params = _get_default(params, {}) - - self.public_params = _get_default(public_params, {}) - self.hidden_params = _get_default(hidden_params, {}) - + self.accepts_messages = accepts_messages self.retry_policy = retry_policy self.failures = Failures(self.retry_policy.disable_window) self.tracking_url = tracking_url self.status_message = status_message self.progress_percentage = progress_percentage + self.scheduler_message_responses = {} self.scheduler_disable_time = None self.runnable = False self.batchable = False @@ -339,8 +342,8 @@ def has_excessive_failures(self): @property def pretty_id(self): - param_str = ', '.join('{}={}'.format(key, value) for key, value in sorted(self.params.items())) - return '{}({})'.format(self.family, param_str) + param_str = ', '.join(u'{}={}'.format(key, value) for key, value in sorted(self.params.items())) + return u'{}({})'.format(self.family, param_str) class Worker(object): @@ -717,7 +720,7 @@ def _prune_workers(self): self._state.inactivate_workers(remove_workers) def _prune_tasks(self): - assistant_ids = set(w.id for w in self._state.get_assistants()) + assistant_ids = {w.id for w in self._state.get_assistants()} remove_tasks = [] for task in self._state.get_active_tasks(): @@ -774,9 +777,9 @@ def forgive_failures(self, task_id=None): @rpc_method() def add_task(self, task_id=None, status=PENDING, runnable=True, deps=None, new_deps=None, expl=None, resources=None, - priority=0, family='', module=None, params=None, + priority=0, family='', 
module=None, params=None, accepts_messages=False, assistant=False, tracking_url=None, worker=None, batchable=None, - batch_id=None, retry_policy_dict={}, owners=None, **kwargs): + batch_id=None, retry_policy_dict=None, owners=None, **kwargs): """ * add task identified by task_id if it doesn't exist * if deps is not None, update dependency list @@ -787,6 +790,12 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, assert worker is not None worker_id = worker worker = self._update_worker(worker_id) + + resources = {} if resources is None else resources.copy() + + if retry_policy_dict is None: + retry_policy_dict = {} + retry_policy = self._generate_retry_policy(retry_policy_dict) all_params = {key: params[key][0] for key in params} @@ -796,8 +805,8 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, if worker.enabled: _default_task = self._make_task( task_id=task_id, status=PENDING, deps=deps, resources=resources, - priority=priority, family=family, module=module, - params=all_params, public_params=public_params, hidden_params=hidden_params + priority=priority, family=family, module=module, params=params, + accepts_messages=accepts_messages, ) else: _default_task = None @@ -824,7 +833,9 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, if status == RUNNING and not task.worker_running: task.worker_running = worker_id if batch_id: - task.resources_running = self._state.get_batch_running_tasks(batch_id)[0].resources_running + # copy resources_running of the first batch task + batch_tasks = self._state.get_batch_running_tasks(batch_id) + task.resources_running = batch_tasks[0].resources_running.copy() task.time_running = time.time() if tracking_url is not None or task.status != RUNNING: @@ -845,9 +856,12 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, for batch_task in self._state.get_batch_running_tasks(task.batch_id): batch_task.expl = expl - if not (task.status in (RUNNING, BATCH_RUNNING) and (status not in (DONE, FAILED, RUNNING) or task.worker_running != worker_id)) or new_deps: + task_is_not_running = task.status not in (RUNNING, BATCH_RUNNING) + task_started_a_run = status in (DONE, FAILED, RUNNING) + running_on_this_worker = task.worker_running == worker_id + if task_is_not_running or (task_started_a_run and running_on_this_worker) or new_deps: # don't allow re-scheduling of task while it is running, it must either fail or succeed on the worker actually running it - if status == PENDING or status != task.status: + if status != task.status or status == PENDING: # Update the DB only if there was a acctual change, to prevent noise. 
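
[Editor's aside: both connect_args=None in sqla.py and retry_policy_dict=None in add_task above replace mutable default arguments. A small self-contained illustration of the pitfall they avoid; the names here are made up and are not Luigi API.]

    def remember(item, seen={}):          # one dict object shared by every call
        seen[item] = True
        return seen

    print(remember('a'))                  # {'a': True}
    print(remember('b'))                  # {'a': True, 'b': True} -- state leaked across calls

    def remember_fixed(item, seen=None):  # the pattern the patch switches to
        if seen is None:
            seen = {}
        seen[item] = True
        return seen
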
# We also check for status == PENDING b/c that's the default value # (so checking for status != task.status woule lie) @@ -932,17 +946,48 @@ def disable_worker(self, worker): def set_worker_processes(self, worker, n): self._state.get_worker(worker).add_rpc_message('set_worker_processes', n=n) + @rpc_method() + def send_scheduler_message(self, worker, task, content): + if not self._config.send_messages: + return {"message_id": None} + + message_id = str(uuid.uuid4()) + self._state.get_worker(worker).add_rpc_message('dispatch_scheduler_message', task_id=task, + message_id=message_id, content=content) + + return {"message_id": message_id} + + @rpc_method() + def add_scheduler_message_response(self, task_id, message_id, response): + if self._state.has_task(task_id): + task = self._state.get_task(task_id) + task.scheduler_message_responses[message_id] = response + + @rpc_method() + def get_scheduler_message_response(self, task_id, message_id): + response = None + if self._state.has_task(task_id): + task = self._state.get_task(task_id) + response = task.scheduler_message_responses.pop(message_id, None) + return {"response": response} + + @rpc_method() + def is_pause_enabled(self): + return {'enabled': self._config.pause_enabled} + @rpc_method() def is_paused(self): return {'paused': self._paused} @rpc_method() def pause(self): - self._paused = True + if self._config.pause_enabled: + self._paused = True @rpc_method() def unpause(self): - self._paused = False + if self._config.pause_enabled: + self._paused = False @rpc_method() def update_resources(self, **resources): @@ -976,8 +1021,9 @@ def _used_resources(self): used_resources = collections.defaultdict(int) if self._resources is not None: for task in self._state.get_active_tasks_by_status(RUNNING): - if getattr(task, 'resources_running', task.resources): - for resource, amount in six.iteritems(getattr(task, 'resources_running', task.resources)): + resources_running = getattr(task, "resources_running", task.resources) + if resources_running: + for resource, amount in six.iteritems(resources_running): used_resources[resource] += amount return used_resources @@ -1181,7 +1227,7 @@ def get_work(self, host=None, assistant=False, current_tasks=None, worker=None, elif best_task: self._state.set_status(best_task, RUNNING, self._config) best_task.worker_running = worker_id - best_task.resources_running = best_task.resources + best_task.resources_running = best_task.resources.copy() best_task.time_running = time.time() self._update_task_history(best_task, RUNNING, host=host) @@ -1244,14 +1290,17 @@ def _serialize_task(self, task_id, include_deps=True, deps=None): 'name': task.family, 'priority': task.priority, 'resources': task.resources, + 'resources_running': getattr(task, "resources_running", None), 'tracking_url': getattr(task, "tracking_url", None), 'status_message': getattr(task, "status_message", None), - 'progress_percentage': getattr(task, "progress_percentage", None) + 'progress_percentage': getattr(task, "progress_percentage", None), } if task.status == DISABLED: ret['re_enable_able'] = task.scheduler_disable_time is not None if include_deps: ret['deps'] = list(task.deps if deps is None else deps) + if self._config.send_messages and task.status == RUNNING: + ret['accepts_messages'] = task.accepts_messages return ret @rpc_method() @@ -1528,6 +1577,31 @@ def get_task_progress_percentage(self, task_id): else: return {"taskId": task_id, "progressPercentage": None} + @rpc_method() + def decrease_running_task_resources(self, task_id, 
decrease_resources): + if self._state.has_task(task_id): + task = self._state.get_task(task_id) + if task.status != RUNNING: + return + + def decrease(resources, decrease_resources): + for resource, decrease_amount in six.iteritems(decrease_resources): + if decrease_amount > 0 and resource in resources: + resources[resource] = max(0, resources[resource] - decrease_amount) + + decrease(task.resources_running, decrease_resources) + if task.batch_id is not None: + for batch_task in self._state.get_batch_running_tasks(task.batch_id): + decrease(batch_task.resources_running, decrease_resources) + + @rpc_method() + def get_running_task_resources(self, task_id): + if self._state.has_task(task_id): + task = self._state.get_task(task_id) + return {"taskId": task_id, "resources": getattr(task, "resources_running", None)} + else: + return {"taskId": task_id, "resources": None} + def _update_task_history(self, task, status, host=None): try: if status == DONE or status == FAILED: diff --git a/luigi/static/visualiser/index.html b/luigi/static/visualiser/index.html index 03a52a9b9f..9e957f2149 100644 --- a/luigi/static/visualiser/index.html +++ b/luigi/static/visualiser/index.html @@ -33,13 +33,15 @@ {{#error}}{{/error}} {{#error}}{{/error}} + {{#re_enable}}{{/re_enable}} {{#re_enable}}Re-enable{{/re_enable}} {{#trackingUrl}}{{/trackingUrl}} - {{#statusMessage}}{{/statusMessage}} + {{#statusMessage}}{{/statusMessage}} {{^statusMessage}} - {{#progressPercentage}} + {{#progressPercentage}} {{/progressPercentage}} {{/statusMessage}} + {{#acceptsMessages}}{{/acceptsMessages}} + -
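
[Editor's aside: a standalone sketch of the clamping rule that decrease_running_task_resources above applies to a running task's resource dict; the resource names and amounts are illustrative only.]

    def decrease(resources, decrease_resources):
        # mirrors the inner helper in the patch: only known resources, never below zero
        for resource, amount in decrease_resources.items():
            if amount > 0 and resource in resources:
                resources[resource] = max(0, resources[resource] - amount)

    running = {'gpu': 2, 'ram': 8}
    decrease(running, {'gpu': 1, 'ram': 10, 'disk': 5})
    print(running)   # {'gpu': 1, 'ram': 0} -- 'disk' is ignored, 'ram' is clamped at zero
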
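[Editor's aside: the new pause_enabled and send_messages options are ordinary parameters on the scheduler(Config) class, so they should be controllable from the scheduler section of the Luigi configuration. A hedged sketch of disabling both; key names are taken from the class above, and the usual luigi.cfg location is assumed.]

    [scheduler]
    pause_enabled = false
    send_messages = false
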