Skip to content

Commit

Permalink
Forward cache workflow between branches.
Browse files Browse the repository at this point in the history
  • Loading branch information
riga committed Feb 22, 2023
1 parent 221306a commit 2f99321
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 33 deletions.
12 changes: 4 additions & 8 deletions law/contrib/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,18 +224,14 @@ def is_root(self):
def is_leaf(self):
return not self.is_forest() and self.tree_depth == self.max_tree_depth

def _create_workflow_task(self):
def req_workflow(self, **kwargs):
# since the forest counts as a branch, as_workflow should point the tree_index 0
# which is only used to compute the overall merge tree
if self.is_forest():
return self._req_tree(
self,
branch=-1,
tree_index=0,
_exclude=self.exclude_params_workflow,
)
kwargs["tree_index"] = 0
kwargs["_skip_task_excludes"] = False

return super(ForestMerge, self)._create_workflow_task()
return super(ForestMerge, self).req_workflow(**kwargs)

@property
def max_tree_depth(self):
Expand Down
31 changes: 30 additions & 1 deletion law/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def __new__(metacls, classname, bases, classdict):

class BaseTask(six.with_metaclass(BaseRegister, luigi.Task)):

# whether to cache the result of requires() for input() and potentially also other calls
cache_requirements = False

exclude_index = True
exclude_params_index = set()
exclude_params_req = set()
Expand Down Expand Up @@ -229,6 +232,9 @@ def __init__(self, *args, **kwargs):
# task level logger, created lazily
self._task_logger = None

# attribute for cached requirements if enabled
self._cached_requirements = no_value

def complete(self):
outputs = [t for t in flatten(self.output()) if not t.optional]

Expand All @@ -239,6 +245,19 @@ def complete(self):

return all(t.exists() for t in outputs)

def input(self):
# get potentially cached requirements
if self.cache_requirements:
if self._cached_requirements is no_value:
self._cached_requirements = self.requires()
else:
print("BASETASK.INPUT() TAKING REQS FROM CACHE BITCHES")
reqs = self._cached_requirements
else:
reqs = self.requires()

return luigi.task.getpaths(reqs)

@abstractmethod
def run(self):
return
Expand Down Expand Up @@ -690,7 +709,17 @@ def _repr_flags(self):
return super(WrapperTask, self)._repr_flags() + ["wrapper"]

def complete(self):
return all(task.complete() for task in flatten(self.requires()))
# get potentially cached requirements
if self.cache_requirements:
if self._cached_requirements is no_value:
self._cached_requirements = self.requires()
else:
print("WRAPPER.COMPLETE() TAKING REQS FROM CACHE BITCHES")
reqs = self._cached_requirements
else:
reqs = self.requires()

return all(task.complete() for task in flatten(reqs))

def output(self):
return self.input()
Expand Down
70 changes: 46 additions & 24 deletions law/workflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ class BaseWorkflow(six.with_metaclass(WorkflowRegister, Task)):
reset_branch_map_before_run = False
create_branch_map_before_repr = False
workflow_run_decorators = None
cache_workflow_requirements = False

# accessible properties
workflow_property = None
Expand Down Expand Up @@ -437,6 +438,9 @@ def __init__(self, *args, **kwargs):
self._workflow_cls = None
self._workflow_proxy = None

# attribute for cached requirements if enabled
self._cached_workflow_requirements = no_value

def _initialize_workflow(self, force=False):
if self._workflow_initialized and not force:
return
Expand Down Expand Up @@ -493,6 +497,34 @@ def _repr_params(self, *args, **kwargs):

return params

def req_branch(self, branch, **kwargs):
if branch == -1:
raise ValueError(
"branch must not be -1 when creating a new branch task via req_branch(), "
"but got {}".format(branch),
)

# default kwargs
kwargs.setdefault("_skip_task_excludes", True)
if self.is_workflow():
kwargs.setdefault("_exclude", self.exclude_params_workflow)

# create the task
task = self.req(self, branch=branch, **kwargs)

# set the _workflow_task attribute if known
task._workflow_task = self if self.is_workflow() else self._workflow_task

return task

def req_workflow(self, **kwargs):
# default kwargs
kwargs.setdefault("_skip_task_excludes", True)
if self.is_branch():
kwargs.setdefault("_exclude", self._exclude_params_workflow)

return self.req(self, branch=-1, **kwargs)

def is_branch(self):
"""
Returns whether or not this task refers to a *branch*.
Expand All @@ -516,17 +548,10 @@ def as_branch(self, branch=None):
if branch == -1:
raise ValueError("branch must not be -1 when selecting a branch task")

if self.is_branch():
if branch is None or branch == self.branch:
return self
return self.req(self, branch=branch, _skip_task_excludes=True)

return self.req(
self,
branch=branch or 0,
_exclude=self.exclude_params_branch,
_skip_task_excludes=True,
)
if self.is_branch() and branch in (None, self.branch):
return self

return self.req_branch(branch or 0)

def as_workflow(self):
"""
Expand All @@ -537,21 +562,10 @@ def as_workflow(self):
return self

if self._workflow_task is None:
self._workflow_task = self._create_workflow_task()
self._workflow_task = self.req_workflow()

return self._workflow_task

def _create_workflow_task(self):
"""
Implements how the workflow task is created as used internally by :py:meth:`as_workflow`.
"""
return self.req(
self,
branch=-1,
_exclude=self.exclude_params_workflow,
_skip_task_excludes=True,
)

@abstractmethod
def create_branch_map(self):
"""
Expand Down Expand Up @@ -794,7 +808,15 @@ def workflow_input(self):
if self.is_branch():
raise Exception("calls to workflow_input are forbidden for branch tasks")

return luigi.task.getpaths(self.workflow_proxy.requires())
# get potentially cached workflow requirements
if self.cache_workflow_requirements:
if self._cached_workflow_requirements is no_value:
self._cached_workflow_requirements = self.workflow_proxy.requires()
reqs = self._cached_workflow_requirements
else:
reqs = self.workflow_proxy.requires()

return luigi.task.getpaths(reqs)

def requires_from_branch(self):
"""
Expand Down

0 comments on commit 2f99321

Please sign in to comment.