diff --git a/luigi/util.py b/luigi/util.py index f6be20021b..784a778602 100644 --- a/luigi/util.py +++ b/luigi/util.py @@ -52,7 +52,7 @@ def requires(self): more burdensome than the last. Refactoring becomes more difficult. There are several ways one might try and avoid the problem. -**Approach 1**: Parameters via command line or config instead of ``requires``. +**Approach 1**: Parameters via command line or config instead of :func:`~luigi.task.Task.requires`. .. code-block:: python @@ -132,13 +132,13 @@ def requires(self): specified in the wrong order. This contrived example is easy to fix (by swapping the ordering of the parents of ``TaskA``), but real world cases can be more difficult to both spot and fix. Inheriting from multiple classes -derived from ``luigi.Task`` should be undertaken with caution and avoided +derived from :class:`~luigi.task.Task` should be undertaken with caution and avoided where possible. -**Approach 3**: Use ``inherits`` and ``requires`` +**Approach 3**: Use :class:`~luigi.util.inherits` and :class:`~luigi.util.requires` -The ``inherits`` class decorator in this module copies parameters (and +The :class:`~luigi.util.inherits` class decorator in this module copies parameters (and nothing else) from one task class to another, and avoids direct pythonic inheritance. @@ -185,11 +185,12 @@ def requires(self): issues, and keeps the task command line interface as simple (as it can be, anyway). Refactoring task parameters is also much easier. -The ``requires`` helper function can reduce this pattern even further. It -does everything ``inherits`` does, and also attaches a ``requires`` method +The :class:`~luigi.util.requires` helper function can reduce this pattern even further. It +does everything :class:`~luigi.util.inherits` does, +and also attaches a :class:`~luigi.util.requires` method to your task (still all without pythonic inheritance). -But how does it know how to invoke the upstream task? It uses ``clone`` +But how does it know how to invoke the upstream task? It uses :func:`~luigi.task.Task.clone` behind the scenes! .. code-block:: python @@ -251,59 +252,91 @@ class inherits(object): """ Task inheritance. + *New after Luigi 2.7.6:* multiple arguments support. + Usage: .. code-block:: python class AnotherTask(luigi.Task): + m = luigi.IntParameter() + + class YetAnotherTask(luigi.Task): n = luigi.IntParameter() - # ... @inherits(AnotherTask): - class MyTask(luigi.Task): + class MyFirstTask(luigi.Task): def requires(self): return self.clone_parent() + def run(self): + print self.m # this will be defined + # ... + + @inherits(AnotherTask, YetAnotherTask): + class MySecondTask(luigi.Task): + def requires(self): + return self.clone_parents() + def run(self): print self.n # this will be defined # ... """ - def __init__(self, task_to_inherit): + def __init__(self, *tasks_to_inherit): super(inherits, self).__init__() - self.task_to_inherit = task_to_inherit + if not tasks_to_inherit: + raise TypeError("tasks_to_inherit cannot be empty") + + self.tasks_to_inherit = tasks_to_inherit def __call__(self, task_that_inherits): - # Get all parameter objects from the underlying task - for param_name, param_obj in self.task_to_inherit.get_params(): - # Check if the parameter exists in the inheriting task - if not hasattr(task_that_inherits, param_name): - # If not, add it to the inheriting task - setattr(task_that_inherits, param_name, param_obj) + # Get all parameter objects from each of the underlying tasks + for task_to_inherit in self.tasks_to_inherit: + for param_name, param_obj in task_to_inherit.get_params(): + # Check if the parameter exists in the inheriting task + if not hasattr(task_that_inherits, param_name): + # If not, add it to the inheriting task + setattr(task_that_inherits, param_name, param_obj) # Modify task_that_inherits by adding methods - def clone_parent(_self, **args): - return _self.clone(cls=self.task_to_inherit, **args) + def clone_parent(_self, **kwargs): + return _self.clone(cls=self.tasks_to_inherit[0], **kwargs) task_that_inherits.clone_parent = clone_parent + def clone_parents(_self, **kwargs): + return [ + _self.clone(cls=task_to_inherit, **kwargs) + for task_to_inherit in self.tasks_to_inherit + ] + task_that_inherits.clone_parents = clone_parents + return task_that_inherits class requires(object): """ - Same as @inherits, but also auto-defines the requires method. + Same as :class:`~luigi.util.inherits`, but also auto-defines the requires method. + + *New after Luigi 2.7.6:* multiple arguments support. + """ - def __init__(self, task_to_require): + def __init__(self, *tasks_to_require): super(requires, self).__init__() - self.inherit_decorator = inherits(task_to_require) + if not tasks_to_require: + raise TypeError("tasks_to_require cannot be empty") + + self.tasks_to_require = tasks_to_require def __call__(self, task_that_requires): - task_that_requires = self.inherit_decorator(task_that_requires) + task_that_requires = inherits(*self.tasks_to_require)(task_that_requires) - # Modify task_that_requres by adding methods + # Modify task_that_requires by adding requires method. + # If only one task is required, this single task is returned. + # Otherwise, list of tasks is returned def requires(_self): - return _self.clone_parent() + return _self.clone_parent() if len(self.tasks_to_require) == 1 else _self.clone_parents() task_that_requires.requires = requires return task_that_requires diff --git a/test/decorator_test.py b/test/decorator_test.py index 0e113caaf6..e9851a269e 100644 --- a/test/decorator_test.py +++ b/test/decorator_test.py @@ -53,9 +53,14 @@ class D_null(luigi.Task): param1 = None +@inherits(A, B) +class E(luigi.Task): + param4 = luigi.Parameter("class E-specific default") + + @inherits(A) @inherits(B) -class E(luigi.Task): +class E_stacked(luigi.Task): param4 = luigi.Parameter("class E-specific default") @@ -69,6 +74,7 @@ def setUp(self): self.d = D() self.d_null = D_null() self.e = E() + self.e_stacked = E_stacked() def test_has_param(self): b_params = dict(self.b.get_params()).keys() @@ -91,11 +97,22 @@ def test_overwriting_defaults(self): self.assertNotEqual(self.d.param1, self.a.param1) self.assertEqual(self.d.param1, "class D overwriting class A's default") - def test_stacked_inheritance(self): + def test_multiple_inheritance(self): self.assertEqual(self.e.param1, self.a.param1) self.assertEqual(self.e.param1, self.b.param1) self.assertEqual(self.e.param2, self.b.param2) + def test_stacked_inheritance(self): + self.assertEqual(self.e_stacked.param1, self.a.param1) + self.assertEqual(self.e_stacked.param1, self.b.param1) + self.assertEqual(self.e_stacked.param2, self.b.param2) + + def test_empty_inheritance(self): + with self.assertRaises(TypeError): + @inherits() + class shouldfail(luigi.Task): + pass + def test_removing_parameter(self): self.assertFalse("param1" in dict(self.d_null.get_params()).keys()) @@ -226,53 +243,75 @@ def test_wrong_common_params_order(self): self.assertRaises(TypeError, self.k_wrongparamsorder.requires) -class X(luigi.Task): +class V(luigi.Task): n = luigi.IntParameter(default=42) -@inherits(X) -class Y(luigi.Task): +@inherits(V) +class W(luigi.Task): def requires(self): return self.clone_parent() -@requires(X) -class Y2(luigi.Task): +@requires(V) +class W2(luigi.Task): pass -@requires(X) -class Y3(luigi.Task): +@requires(V) +class W3(luigi.Task): n = luigi.IntParameter(default=43) +class X(luigi.Task): + m = luigi.IntParameter(default=56) + + +@requires(V, X) +class Y(luigi.Task): + pass + + class CloneParentTest(unittest.TestCase): def test_clone_parent(self): - y = Y() - x = X() - self.assertEqual(y.requires(), x) - self.assertEqual(y.n, 42) + w = W() + v = V() + self.assertEqual(w.requires(), v) + self.assertEqual(w.n, 42) def test_requires(self): - y2 = Y2() - x = X() - self.assertEqual(y2.requires(), x) - self.assertEqual(y2.n, 42) + w2 = W2() + v = V() + self.assertEqual(w2.requires(), v) + self.assertEqual(w2.n, 42) def test_requires_override_default(self): - y3 = Y3() + w3 = W3() + v = V() + self.assertNotEqual(w3.requires(), v) + self.assertEqual(w3.n, 43) + self.assertEqual(w3.requires().n, 43) + + def test_multiple_requires(self): + y = Y() + v = V() x = X() - self.assertNotEqual(y3.requires(), x) - self.assertEqual(y3.n, 43) - self.assertEqual(y3.requires().n, 43) + self.assertEqual(y.requires()[0], v) + self.assertEqual(y.requires()[1], x) + + def test_empty_requires(self): + with self.assertRaises(TypeError): + @requires() + class shouldfail(luigi.Task): + pass def test_names(self): # Just make sure the decorators retain the original class names - x = X() - self.assertEqual(str(x), 'X(n=42)') - self.assertEqual(x.__class__.__name__, 'X') + v = V() + self.assertEqual(str(v), 'V(n=42)') + self.assertEqual(v.__class__.__name__, 'V') class P(luigi.Task):