Skip to content

Commit

Permalink
Improve luigi tools
Browse files Browse the repository at this point in the history
* Add several utilities to apply a function on luigi inputs and outputs.
* Add a ForceableTask class, which is a luigi task that can be forced
to run again by setting the 'rerun' parameter to True.
* Add a WorkflowTask which is both a GlobalParamTask and a ForceableTask.

Change-Id: Idf2b45094048e07671997a2a0ce6713ababee6b5
  • Loading branch information
adrien-berchet committed Oct 20, 2020
1 parent a853dd5 commit e8eae3b
Show file tree
Hide file tree
Showing 6 changed files with 330 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ max-line-length=100
max-args=10
# Argument names that match this expression will be ignored. Default to name
# with leading underscore
ignored-argument-names=_.*
ignored-argument-names=args|kwargs|_.*
# Maximum number of locals for function / method body
max-locals=25
# Maximum number of return / yield for function / method body
Expand Down
100 changes: 86 additions & 14 deletions synthesis_workflow/tasks/luigi_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,77 @@
L = logging.getLogger(__name__)


@luigi.Task.event_handler(luigi.Event.SUCCESS)
def log_targets(task):
"""Hook to log output target of the task"""
class_name = task.__class__.__name__
def recursive_check(task, attr="rerun"):
    """Check if a task or any of its recursive dependencies has a given attribute set to True.

    Args:
        task: the task to inspect; it must provide a 'deps()' method returning its direct
            dependencies.
        attr: the name of the attribute to look for. A missing attribute is treated as False.

    Returns:
        True if the attribute is truthy on the task itself or on any task reachable through
        'deps()', False otherwise.
    """
    # Return as soon as a truthy value is found: 'deps()' is then not evaluated for the
    # current task nor for the remaining part of the dependency tree.
    if getattr(task, attr, False):
        return True

    # The recursive call already inspects the attribute of each dependency itself.
    return any(recursive_check(dep, attr) for dep in task.deps())


def target_remove(target, *args, **kwargs):
"""Remove a given target by calling its 'exists()' and 'remove()' methods.

'remove()' is only called when 'exists()' returns a truthy value. The extra positional
and keyword arguments are accepted but not used.

Raises:
AttributeError: if an AttributeError is raised while calling 'exists()' or 'remove()';
the new error is chained to the original one.
"""
try:
output = task.output()
if target.exists():
target.remove()
except AttributeError as e:
raise AttributeError(
"The target must have 'exists()' and 'remove()' methods"
) from e


def apply_over_luigi_iterable(luigi_iterable, func):
    """Apply the given function to a luigi iterable (task.input() or task.output()).

    If the iterable provides an 'items()' method (e.g. a dictionary), the function is called
    as ``func(element, key)`` for each item. Otherwise the iterable is flattened with
    ``luigi.task.flatten`` and the function is called as ``func(element)``.

    Args:
        luigi_iterable: the inputs or outputs of a task.
        func: a callable accepting an element and, optionally, its key.
    """
    # Only the 'items()' lookup is guarded: an AttributeError raised by 'func' itself must
    # propagate instead of silently triggering a second pass over the flattened iterable.
    try:
        items = luigi_iterable.items()
    except AttributeError:
        items = None

    if items is None:
        for element in luigi.task.flatten(luigi_iterable):
            func(element)
    else:
        for key, element in items:
            func(element, key)


def apply_over_inputs(task, func):
    """Call the given function on every input of a luigi task.

    The function is called once per element of task.input() and should accept the
    following arguments:
        * target: one element of task.input()
        * key=None: the key of the element when task.input() is a dictionary

    Nothing is done when calling task.input() raises an AttributeError.
    """
    try:
        task_inputs = task.input()
    except AttributeError:
        # Best effort: there is nothing to apply the function to
        return
    else:
        apply_over_luigi_iterable(task_inputs, func)


def apply_over_outputs(task, func):
"""Apply the given function to all outputs of a luigi task.
The given function is called once per output and should accept the following arguments:
* target: one element of task.output()
* key=None: the key of the element when task.output() is a dictionary
"""
try:
L.debug("Output of %s task: %s", class_name, output.path)
outputs = task.output()
except AttributeError:
try:
for k, i in output.items():
L.debug("Output %s of %s task: %s", k, class_name, i.path)
except AttributeError:
for i in output:
L.debug("Output of %s task: %s", class_name, i.path)
return

apply_over_luigi_iterable(outputs, func)


@luigi.Task.event_handler(luigi.Event.SUCCESS)
def log_targets(task):
    """Hook logging the path of each output target of a task that succeeded"""
    task_name = task.__class__.__name__

    def _log_target(target, key=None):
        # Keyed outputs (dictionaries) also report the key of the target
        if key is not None:
            L.debug("Output %s of %s task: %s", key, task_name, target.path)
            return
        L.debug("Output of %s task: %s", task_name, target.path)

    apply_over_outputs(task, _log_target)


@luigi.Task.event_handler(luigi.Event.START)
Expand All @@ -40,8 +94,20 @@ def log_parameters(task):
L.debug("Can't print '%s' attribute for unknown reason", name)


class ForceableTask(luigi.Task):
    """A luigi task that can be forced to run again by setting the 'rerun' parameter to True.

    At instantiation, the outputs of the task are removed with target_remove() when
    recursive_check() reports that the task or one of its recursive dependencies has
    'rerun' set.
    """

    rerun = luigi.BoolParameter(default=False, significant=False)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        must_rerun = recursive_check(self)
        if must_rerun:
            # Remove the existing outputs of this task
            apply_over_outputs(self, target_remove)


class GlobalParamTask(luigi.Task):
"""Base class used to add customisable global parameters"""
"""Mixin used to add customisable global parameters"""

def __getattribute__(self, name):
tmp = super().__getattribute__(name)
Expand All @@ -64,7 +130,13 @@ def __setattr__(self, name, value):
return super().__setattr__(name, value)


class BaseWrapperTask(GlobalParamTask, luigi.WrapperTask):
class WorkflowTask(GlobalParamTask, ForceableTask):
"""Default task used in workflows.

This task can be forced to run again by setting the 'rerun' parameter to True.
It can also use copy and link parameters from other tasks.
"""


class WorkflowWrapperTask(WorkflowTask, luigi.WrapperTask):
"""Base wrapper task with global parameters (a WorkflowTask that is also a luigi.WrapperTask)"""


Expand Down
9 changes: 6 additions & 3 deletions synthesis_workflow/tasks/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import pandas as pd

from ..validation import plot_morphometrics
from .luigi_tools import WorkflowTask
from .luigi_tools import WorkflowWrapperTask
from .synthesis import ApplySubstitutionRules
from .validation import PlotCollage
from .validation import PlotDensityProfiles
Expand All @@ -11,7 +13,7 @@
from .vacuum_synthesis import PlotVacuumMorphologies


class ValidateSynthesis(luigi.WrapperTask):
class ValidateSynthesis(WorkflowWrapperTask):
"""Workflow to validate synthesis"""

with_collage = luigi.BoolParameter(default=True)
Expand All @@ -33,7 +35,7 @@ def requires(self):
return tasks


class ValidateVacuumSynthesis(luigi.WrapperTask):
class ValidateVacuumSynthesis(WorkflowWrapperTask):
"""Workflow to validate vacuum synthesis"""

with_vacuum_morphologies = luigi.BoolParameter(default=True)
Expand All @@ -58,7 +60,7 @@ def requires(self):
return tasks


class ValidateRescaling(luigi.Task):
class ValidateRescaling(WorkflowTask):
"""Workflow to validate rescaling"""

morphometrics_path = luigi.Parameter(default="morphometrics")
Expand All @@ -71,6 +73,7 @@ class ValidateRescaling(luigi.Task):

def requires(self):
""""""
# pylint: disable=no-self-use
return ApplySubstitutionRules()

def run(self):
Expand Down
File renamed without changes.
Loading

0 comments on commit e8eae3b

Please sign in to comment.