Skip to content

Commit

Permalink
Add support for naming tasks in @requires (#3077)
Browse files Browse the repository at this point in the history
* Add support for naming tasks in @requires

When using @requires the requires method is auto-generated. However, as
it just takes a list the tasks run method needs to identify which input
is which.

This adds support for named requirements using luigis existing support
for returning a dictionary from the requires function.

Usage:

    class Parent1(luigi.Task):
        ...
    class Parent2(luigi.Task):
        ...

    @requires(first_parent=Parent1, second_parent=Parent2)
    class Child(luigi.Task):
        def run(self):
           first_parent_target = self.input()["first_parent"]
           second_parent_target = self.input()["second_parent"]

* util_test flake8 fixes
  • Loading branch information
IanCal authored Aug 28, 2021
1 parent 5845d81 commit 00aa83a
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 18 deletions.
49 changes: 31 additions & 18 deletions luigi/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,33 +279,47 @@ def run(self):
# ...
"""

def __init__(self, *tasks_to_inherit):
def __init__(self, *tasks_to_inherit, **kw_tasks_to_inherit):
super(inherits, self).__init__()
if not tasks_to_inherit:
raise TypeError("tasks_to_inherit cannot be empty")

if not tasks_to_inherit and not kw_tasks_to_inherit:
raise TypeError("tasks_to_inherit or kw_tasks_to_inherit must contain at least one task")
if tasks_to_inherit and kw_tasks_to_inherit:
raise TypeError("Only one of tasks_to_inherit or kw_tasks_to_inherit may be present")
self.tasks_to_inherit = tasks_to_inherit
self.kw_tasks_to_inherit = kw_tasks_to_inherit

def __call__(self, task_that_inherits):
# Get all parameter objects from each of the underlying tasks
for task_to_inherit in self.tasks_to_inherit:
task_iterator = self.tasks_to_inherit or self.kw_tasks_to_inherit.values()
for task_to_inherit in task_iterator:
for param_name, param_obj in task_to_inherit.get_params():
# Check if the parameter exists in the inheriting task
if not hasattr(task_that_inherits, param_name):
# If not, add it to the inheriting task
setattr(task_that_inherits, param_name, param_obj)

# Modify task_that_inherits by adding methods
def clone_parent(_self, **kwargs):
return _self.clone(cls=self.tasks_to_inherit[0], **kwargs)
task_that_inherits.clone_parent = clone_parent

def clone_parents(_self, **kwargs):
return [
_self.clone(cls=task_to_inherit, **kwargs)
for task_to_inherit in self.tasks_to_inherit
]
task_that_inherits.clone_parents = clone_parents
# Handle unnamed tasks as a list, named as a dictionary
if self.tasks_to_inherit:
def clone_parent(_self, **kwargs):
return _self.clone(cls=self.tasks_to_inherit[0], **kwargs)
task_that_inherits.clone_parent = clone_parent

def clone_parents(_self, **kwargs):
return [
_self.clone(cls=task_to_inherit, **kwargs)
for task_to_inherit in self.tasks_to_inherit
]
task_that_inherits.clone_parents = clone_parents
elif self.kw_tasks_to_inherit:
# Even if there is just one named task, return a dictionary
def clone_parents(_self, **kwargs):
return {
task_name: _self.clone(cls=task_to_inherit, **kwargs)
for task_name, task_to_inherit in self.kw_tasks_to_inherit.items()
}
task_that_inherits.clone_parents = clone_parents

return task_that_inherits

Expand All @@ -318,15 +332,14 @@ class requires:
"""

def __init__(self, *tasks_to_require):
def __init__(self, *tasks_to_require, **kw_tasks_to_require):
super(requires, self).__init__()
if not tasks_to_require:
raise TypeError("tasks_to_require cannot be empty")

self.tasks_to_require = tasks_to_require
self.kw_tasks_to_require = kw_tasks_to_require

def __call__(self, task_that_requires):
task_that_requires = inherits(*self.tasks_to_require)(task_that_requires)
task_that_requires = inherits(*self.tasks_to_require, **self.kw_tasks_to_require)(task_that_requires)

# Modify task_that_requires by adding requires method.
# If only one task is required, this single task is returned.
Expand Down
29 changes: 29 additions & 0 deletions test/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@ def requires(self):
self.assertEqual(str(child_task), 'blah.ChildTask(my_param=hello)')
self.assertIn(ParentTask(my_param='hello'), luigi.task.flatten(child_task.requires()))

def test_task_ids_using_inherits_kwargs(self):
class ParentTask(luigi.Task):
my_param = luigi.Parameter()
luigi.namespace('blah')

@inherits(parent=ParentTask)
class ChildTask(luigi.Task):
def requires(self):
return self.clone(ParentTask)
luigi.namespace('')
child_task = ChildTask(my_param='hello')
self.assertEqual(str(child_task), 'blah.ChildTask(my_param=hello)')
self.assertIn(ParentTask(my_param='hello'), luigi.task.flatten(child_task.requires()))

def _setup_parent_and_child_inherits(self):
class ParentTask(luigi.Task):
my_parameter = luigi.Parameter()
Expand Down Expand Up @@ -174,3 +188,18 @@ def test_requires_has_effect_MRO(self):
ChildTask = self._setup_requires_inheritence()
self.assertNotEqual(str(ChildTask.__mro__[0]),
str(ChildTask.__mro__[1]))

def test_kwargs_requires_gives_named_inputs(self):
class ParentTask(RunOnceTask):
def output(self):
return "Target"

@requires(parent_1=ParentTask, parent_2=ParentTask)
class ChildTask(RunOnceTask):
resulting_input = 'notset'

def run(self):
self.__class__.resulting_input = self.input()

self.assertTrue(self.run_locally_split('ChildTask'))
self.assertEqual(ChildTask.resulting_input, {'parent_1': 'Target', 'parent_2': 'Target'})

0 comments on commit 00aa83a

Please sign in to comment.