diff --git a/luigi/util.py b/luigi/util.py index 438ff385ee..adce164580 100644 --- a/luigi/util.py +++ b/luigi/util.py @@ -279,16 +279,19 @@ def run(self): # ... """ - def __init__(self, *tasks_to_inherit): + def __init__(self, *tasks_to_inherit, **kw_tasks_to_inherit): super(inherits, self).__init__() - if not tasks_to_inherit: - raise TypeError("tasks_to_inherit cannot be empty") - + if not tasks_to_inherit and not kw_tasks_to_inherit: + raise TypeError("tasks_to_inherit or kw_tasks_to_inherit must contain at least one task") + if tasks_to_inherit and kw_tasks_to_inherit: + raise TypeError("Only one of tasks_to_inherit or kw_tasks_to_inherit may be present") self.tasks_to_inherit = tasks_to_inherit + self.kw_tasks_to_inherit = kw_tasks_to_inherit def __call__(self, task_that_inherits): # Get all parameter objects from each of the underlying tasks - for task_to_inherit in self.tasks_to_inherit: + task_iterator = self.tasks_to_inherit or self.kw_tasks_to_inherit.values() + for task_to_inherit in task_iterator: for param_name, param_obj in task_to_inherit.get_params(): # Check if the parameter exists in the inheriting task if not hasattr(task_that_inherits, param_name): @@ -296,16 +299,27 @@ def __call__(self, task_that_inherits): setattr(task_that_inherits, param_name, param_obj) # Modify task_that_inherits by adding methods - def clone_parent(_self, **kwargs): - return _self.clone(cls=self.tasks_to_inherit[0], **kwargs) - task_that_inherits.clone_parent = clone_parent - def clone_parents(_self, **kwargs): - return [ - _self.clone(cls=task_to_inherit, **kwargs) - for task_to_inherit in self.tasks_to_inherit - ] - task_that_inherits.clone_parents = clone_parents + # Handle unnamed tasks as a list, named as a dictionary + if self.tasks_to_inherit: + def clone_parent(_self, **kwargs): + return _self.clone(cls=self.tasks_to_inherit[0], **kwargs) + task_that_inherits.clone_parent = clone_parent + + def clone_parents(_self, **kwargs): + return [ + _self.clone(cls=task_to_inherit, **kwargs) + for task_to_inherit in self.tasks_to_inherit + ] + task_that_inherits.clone_parents = clone_parents + elif self.kw_tasks_to_inherit: + # Even if there is just one named task, return a dictionary + def clone_parents(_self, **kwargs): + return { + task_name: _self.clone(cls=task_to_inherit, **kwargs) + for task_name, task_to_inherit in self.kw_tasks_to_inherit.items() + } + task_that_inherits.clone_parents = clone_parents return task_that_inherits @@ -318,15 +332,14 @@ class requires: """ - def __init__(self, *tasks_to_require): + def __init__(self, *tasks_to_require, **kw_tasks_to_require): super(requires, self).__init__() - if not tasks_to_require: - raise TypeError("tasks_to_require cannot be empty") self.tasks_to_require = tasks_to_require + self.kw_tasks_to_require = kw_tasks_to_require def __call__(self, task_that_requires): - task_that_requires = inherits(*self.tasks_to_require)(task_that_requires) + task_that_requires = inherits(*self.tasks_to_require, **self.kw_tasks_to_require)(task_that_requires) # Modify task_that_requires by adding requires method. # If only one task is required, this single task is returned. diff --git a/test/util_test.py b/test/util_test.py index c2e3bb49c1..84ed62d57f 100644 --- a/test/util_test.py +++ b/test/util_test.py @@ -54,6 +54,20 @@ def requires(self): self.assertEqual(str(child_task), 'blah.ChildTask(my_param=hello)') self.assertIn(ParentTask(my_param='hello'), luigi.task.flatten(child_task.requires())) + def test_task_ids_using_inherits_kwargs(self): + class ParentTask(luigi.Task): + my_param = luigi.Parameter() + luigi.namespace('blah') + + @inherits(parent=ParentTask) + class ChildTask(luigi.Task): + def requires(self): + return self.clone(ParentTask) + luigi.namespace('') + child_task = ChildTask(my_param='hello') + self.assertEqual(str(child_task), 'blah.ChildTask(my_param=hello)') + self.assertIn(ParentTask(my_param='hello'), luigi.task.flatten(child_task.requires())) + def _setup_parent_and_child_inherits(self): class ParentTask(luigi.Task): my_parameter = luigi.Parameter() @@ -174,3 +188,18 @@ def test_requires_has_effect_MRO(self): ChildTask = self._setup_requires_inheritence() self.assertNotEqual(str(ChildTask.__mro__[0]), str(ChildTask.__mro__[1])) + + def test_kwargs_requires_gives_named_inputs(self): + class ParentTask(RunOnceTask): + def output(self): + return "Target" + + @requires(parent_1=ParentTask, parent_2=ParentTask) + class ChildTask(RunOnceTask): + resulting_input = 'notset' + + def run(self): + self.__class__.resulting_input = self.input() + + self.assertTrue(self.run_locally_split('ChildTask')) + self.assertEqual(ChildTask.resulting_input, {'parent_1': 'Target', 'parent_2': 'Target'})