From 2f993219c1f11e5ef27df1d3e2fe29c0562ed73f Mon Sep 17 00:00:00 2001 From: Marcel R Date: Wed, 22 Feb 2023 12:40:43 +0100 Subject: [PATCH] Forward cache workflow between branches. --- law/contrib/tasks/__init__.py | 12 ++---- law/task/base.py | 31 +++++++++++++++- law/workflow/base.py | 70 +++++++++++++++++++++++------------ 3 files changed, 80 insertions(+), 33 deletions(-) diff --git a/law/contrib/tasks/__init__.py b/law/contrib/tasks/__init__.py index a551129d..fabb8c6b 100644 --- a/law/contrib/tasks/__init__.py +++ b/law/contrib/tasks/__init__.py @@ -224,18 +224,14 @@ def is_root(self): def is_leaf(self): return not self.is_forest() and self.tree_depth == self.max_tree_depth - def _create_workflow_task(self): + def req_workflow(self, **kwargs): # since the forest counts as a branch, as_workflow should point the tree_index 0 # which is only used to compute the overall merge tree if self.is_forest(): - return self._req_tree( - self, - branch=-1, - tree_index=0, - _exclude=self.exclude_params_workflow, - ) + kwargs["tree_index"] = 0 + kwargs["_skip_task_excludes"] = False - return super(ForestMerge, self)._create_workflow_task() + return super(ForestMerge, self).req_workflow(**kwargs) @property def max_tree_depth(self): diff --git a/law/task/base.py b/law/task/base.py index 3e0e5737..19dbf37f 100644 --- a/law/task/base.py +++ b/law/task/base.py @@ -77,6 +77,9 @@ def __new__(metacls, classname, bases, classdict): class BaseTask(six.with_metaclass(BaseRegister, luigi.Task)): + # whether to cache the result of requires() for input() and potentially also other calls + cache_requirements = False + exclude_index = True exclude_params_index = set() exclude_params_req = set() @@ -229,6 +232,9 @@ def __init__(self, *args, **kwargs): # task level logger, created lazily self._task_logger = None + # attribute for cached requirements if enabled + self._cached_requirements = no_value + def complete(self): outputs = [t for t in flatten(self.output()) if not t.optional] @@ -239,6 +245,19 @@ def complete(self): return all(t.exists() for t in outputs) + def input(self): + # get potentially cached requirements + if self.cache_requirements: + if self._cached_requirements is no_value: + self._cached_requirements = self.requires() + else: + print("BASETASK.INPUT() TAKING REQS FROM CACHE BITCHES") + reqs = self._cached_requirements + else: + reqs = self.requires() + + return luigi.task.getpaths(reqs) + @abstractmethod def run(self): return @@ -690,7 +709,17 @@ def _repr_flags(self): return super(WrapperTask, self)._repr_flags() + ["wrapper"] def complete(self): - return all(task.complete() for task in flatten(self.requires())) + # get potentially cached requirements + if self.cache_requirements: + if self._cached_requirements is no_value: + self._cached_requirements = self.requires() + else: + print("WRAPPER.COMPLETE() TAKING REQS FROM CACHE BITCHES") + reqs = self._cached_requirements + else: + reqs = self.requires() + + return all(task.complete() for task in flatten(reqs)) def output(self): return self.input() diff --git a/law/workflow/base.py b/law/workflow/base.py index 2c92f25b..77b6bef4 100644 --- a/law/workflow/base.py +++ b/law/workflow/base.py @@ -377,6 +377,7 @@ class BaseWorkflow(six.with_metaclass(WorkflowRegister, Task)): reset_branch_map_before_run = False create_branch_map_before_repr = False workflow_run_decorators = None + cache_workflow_requirements = False # accessible properties workflow_property = None @@ -437,6 +438,9 @@ def __init__(self, *args, **kwargs): self._workflow_cls = None self._workflow_proxy = None + # attribute for cached requirements if enabled + self._cached_workflow_requirements = no_value + def _initialize_workflow(self, force=False): if self._workflow_initialized and not force: return @@ -493,6 +497,34 @@ def _repr_params(self, *args, **kwargs): return params + def req_branch(self, branch, **kwargs): + if branch == -1: + raise ValueError( + "branch must not be -1 when creating a new branch task via req_branch(), " + "but got {}".format(branch), + ) + + # default kwargs + kwargs.setdefault("_skip_task_excludes", True) + if self.is_workflow(): + kwargs.setdefault("_exclude", self.exclude_params_workflow) + + # create the task + task = self.req(self, branch=branch, **kwargs) + + # set the _workflow_task attribute if known + task._workflow_task = self if self.is_workflow() else self._workflow_task + + return task + + def req_workflow(self, **kwargs): + # default kwargs + kwargs.setdefault("_skip_task_excludes", True) + if self.is_branch(): + kwargs.setdefault("_exclude", self._exclude_params_workflow) + + return self.req(self, branch=-1, **kwargs) + def is_branch(self): """ Returns whether or not this task refers to a *branch*. @@ -516,17 +548,10 @@ def as_branch(self, branch=None): if branch == -1: raise ValueError("branch must not be -1 when selecting a branch task") - if self.is_branch(): - if branch is None or branch == self.branch: - return self - return self.req(self, branch=branch, _skip_task_excludes=True) - - return self.req( - self, - branch=branch or 0, - _exclude=self.exclude_params_branch, - _skip_task_excludes=True, - ) + if self.is_branch() and branch in (None, self.branch): + return self + + return self.req_branch(branch or 0) def as_workflow(self): """ @@ -537,21 +562,10 @@ def as_workflow(self): return self if self._workflow_task is None: - self._workflow_task = self._create_workflow_task() + self._workflow_task = self.req_workflow() return self._workflow_task - def _create_workflow_task(self): - """ - Implements how the workflow task is created as used internally by :py:meth:`as_workflow`. - """ - return self.req( - self, - branch=-1, - _exclude=self.exclude_params_workflow, - _skip_task_excludes=True, - ) - @abstractmethod def create_branch_map(self): """ @@ -794,7 +808,15 @@ def workflow_input(self): if self.is_branch(): raise Exception("calls to workflow_input are forbidden for branch tasks") - return luigi.task.getpaths(self.workflow_proxy.requires()) + # get potentially cached workflow requirements + if self.cache_workflow_requirements: + if self._cached_workflow_requirements is no_value: + self._cached_workflow_requirements = self.workflow_proxy.requires() + reqs = self._cached_workflow_requirements + else: + reqs = self.workflow_proxy.requires() + + return luigi.task.getpaths(reqs) def requires_from_branch(self): """