Skip to content

Commit

Permalink
Add dynamic_workflow_condition.
Browse files Browse the repository at this point in the history
  • Loading branch information
riga committed Oct 7, 2023
1 parent 6fb41c1 commit 5173f91
Show file tree
Hide file tree
Showing 3 changed files with 242 additions and 10 deletions.
4 changes: 2 additions & 2 deletions docs/api/workflow/base.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ Class ``BaseWorkflowProxy``
Functions
---------

.. autofunction:: workflow_property
.. autofunction:: dynamic_workflow_condition

.. autofunction:: cached_workflow_property
.. autofunction:: workflow_property
5 changes: 4 additions & 1 deletion law/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"Register", "Task", "WrapperTask", "ExternalTask",
"SandboxTask",
"BaseWorkflow", "WorkflowParameter", "LocalWorkflow", "workflow_property",
"dynamic_workflow_condition",
"FileSystemTarget", "FileSystemFileTarget", "FileSystemDirectoryTarget",
"LocalFileSystem", "LocalTarget", "LocalFileTarget", "LocalDirectoryTarget",
"TargetCollection", "FileCollection", "SiblingFileCollection", "NestedSiblingFileCollection",
Expand Down Expand Up @@ -76,7 +77,9 @@
)
import law.decorator
from law.task.base import Register, Task, WrapperTask, ExternalTask
from law.workflow.base import BaseWorkflow, WorkflowParameter, workflow_property
from law.workflow.base import (
BaseWorkflow, WorkflowParameter, workflow_property, dynamic_workflow_condition,
)
from law.workflow.local import LocalWorkflow
from law.sandbox.base import Sandbox, SandboxTask
from law.sandbox.bash import BashSandbox
Expand Down
243 changes: 236 additions & 7 deletions law/workflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Workflow and workflow proxy base class definitions.
"""

__all__ = ["BaseWorkflow", "WorkflowParameter", "workflow_property"]
__all__ = ["BaseWorkflow", "WorkflowParameter", "workflow_property", "dynamic_workflow_condition"]


import re
Expand All @@ -20,6 +20,7 @@
from law.task.base import Task, Register
from law.task.proxy import ProxyTask, get_proxy_attribute
from law.target.collection import TargetCollection
from law.target.local import LocalFileTarget
from law.parameter import NO_STR, MultiRangeParameter, CSVParameter
from law.util import (
no_value, make_list, make_set, iter_chunks, range_expand, range_join, create_hash,
Expand Down Expand Up @@ -234,14 +235,222 @@ def serialize(self, value):
return super(WorkflowParameter, self).serialize(value)


class dynamic_workflow_condition(object):
    """
    Decorator for a workflow method that defines whether the branch map can be dynamically
    constructed or whether a placeholder should be used until the condition is met. Similar to
    Python's ``property``, the decorated object provides additional attributes for decorating other
    methods that usually depend on the branch map, such as branch requirements or outputs.

    Example:

    .. code-block:: python

        class MyWorkflow(law.LocalWorkflow):

            def workflow_requires(self):
                # define requirements for the full workflow to start
                reqs = super().workflow_requires()
                reqs["files"] = OtherTask.req(self)
                return reqs

            @law.dynamic_workflow_condition
            def workflow_condition(self):
                # declare that the branch map can be built if the workflow requirement exists
                # note: self.input() refers to the outputs of tasks defined in workflow_requires()
                return self.input()["files"].exists()

            @workflow_condition.create_branch_map
            def create_branch_map(self):
                # let's assume that OtherTask produces a json file containing a list of objects
                # that _this_ workflows iterates over, so we can simply return this list here
                return self.input()["files"].load(formatter="json")

            def requires(self):
                # branch-level requirement
                # note: this is not really necessary, since the branch requirements are only
                # evaluated _after_ a branch map is built, so OtherTask must have been completed
                return OtherTask.req(self)

            @workflow_condition.output
            def output(self):
                # define the output
                return law.LocalFileTarget("file_{}.txt".format(self.branch))

            def run(self):
                # trivial run implementation
                self.output().touch()

    The condition is defined by ``workflow_condition`` which is decorated by *this* object. Once it
    is met, the branch map is fully created and cached (as usual) for subsequent calls.

    In addition, both ``create_branch_map()`` and ``output()`` are decorated with corresponding
    attributes of the initially decorated object. As a result, both methods will return placeholder
    objects as long as the condition is not met - ``create_branch_map()`` returns the placeholder
    branch map ``[None]`` and the output will refer to a temporary placeholder target that is never
    created. Note that a third decorator for ``requires`` exists as well, which returns an empty
    list as long as the condition is not met.

    As a consequence, the amended workflow is fully dynamic with its exact shape potentially
    depending heavily on conditions that are only known at runtime.
    """

    # sentinel that replaces methods decorated via create_branch_map / requires / output in the
    # class body; it marks the attributes that should later be swapped for the guarded wrappers
    _decorator_result = object()

    def __init__(
        self,
        condition_fn,
        create_branch_map_fn=None,
        requires_fn=None,
        output_fn=None,
    ):
        # explicit two-argument super, consistent with the rest of this six-compatible module
        super(dynamic_workflow_condition, self).__init__()

        # attributes
        self._condition_fn = condition_fn
        self._create_branch_map_fn = create_branch_map_fn
        self._requires_fn = requires_fn
        self._output_fn = output_fn

    @staticmethod
    def _check_fn_name(fn, expected_name):
        # the decorated method must carry the exact name of the attribute it is guarding
        fn_name = getattr(fn, "__name__", None)
        if fn_name != expected_name:
            raise NameError(
                "the method decorated by dynamic_workflow_condition.{0} should be "
                "named '{0}', but got '{1}'".format(expected_name, fn_name),
            )

    @staticmethod
    def _wrap_guarded(fn, bound_condition_fn, create_placeholder):
        # build a wrapper around *fn* that returns the result of *create_placeholder* as long as
        # *bound_condition_fn* evaluates to a false value, and that forwards to *fn* otherwise
        if fn is None:
            return None

        @functools.wraps(fn)
        def wrapper(inst, *args, **kwargs):
            if not bound_condition_fn():
                return create_placeholder()

            # enable branch map caching since the condition is met
            inst.cache_branch_map = True

            return fn(inst, *args, **kwargs)

        return wrapper

    def _wrap_condition_fn(self):
        """
        Returns a wrapper around the condition function that evaluates it on the object returned by
        ``as_workflow()`` of the instance it is called with, or *None* if no condition is set.
        """
        if self._condition_fn is None:
            return None

        @functools.wraps(self._condition_fn)
        def condition(inst, *args, **kwargs):
            return self._condition_fn(inst.as_workflow(), *args, **kwargs)

        return condition

    def create_branch_map(self, create_branch_map_fn):
        """
        Decorator registering *create_branch_map_fn*, which must be named ``create_branch_map``. A
        *NameError* is raised otherwise. Returns the placeholder sentinel.
        """
        self._check_fn_name(create_branch_map_fn, "create_branch_map")

        # store the function
        self._create_branch_map_fn = create_branch_map_fn

        return self._decorator_result

    def _wrap_create_branch_map(self, bound_condition_fn):
        # placeholder branch map while the condition is not met: a fresh ``[None]`` per call
        return self._wrap_guarded(
            self._create_branch_map_fn,
            bound_condition_fn,
            lambda: [None],
        )

    def requires(self, requires_fn):
        """
        Decorator registering *requires_fn*, which must be named ``requires``. A *NameError* is
        raised otherwise. Returns the placeholder sentinel.
        """
        self._check_fn_name(requires_fn, "requires")

        # store the function
        self._requires_fn = requires_fn

        return self._decorator_result

    def _wrap_requires(self, bound_condition_fn):
        # no requirements while the condition is not met: a fresh empty list per call
        return self._wrap_guarded(
            self._requires_fn,
            bound_condition_fn,
            lambda: [],
        )

    def output(self, output_fn):
        """
        Decorator registering *output_fn*, which must be named ``output``. A *NameError* is raised
        otherwise. Returns the placeholder sentinel.
        """
        self._check_fn_name(output_fn, "output")

        # store the function
        self._output_fn = output_fn

        return self._decorator_result

    def _wrap_output(self, bound_condition_fn):
        # placeholder target while the condition is not met, created lazily on each call
        return self._wrap_guarded(
            self._output_fn,
            bound_condition_fn,
            lambda: LocalFileTarget(is_tmp="DYNAMIC_WORKFLOW_PLACEHOLDER"),
        )

    def _iter_wrappers(self, bound_condition_fn):
        """
        Yields 2-tuples *(attribute name, wrapper)* for every registered function, in the order
        ``create_branch_map``, ``requires``, ``output``. Unregistered functions are skipped.
        """
        if self._create_branch_map_fn is not None:
            yield "create_branch_map", self._wrap_create_branch_map(bound_condition_fn)

        if self._requires_fn is not None:
            yield "requires", self._wrap_requires(bound_condition_fn)

        if self._output_fn is not None:
            yield "output", self._wrap_output(bound_condition_fn)


class WorkflowRegister(Register):
    """
    Metaclass of workflow classes. At class creation it records which attribute of the class body
    (if any) holds a :py:class:`dynamic_workflow_condition` and whether the class body defines its
    own ``workflow_proxy_cls``.
    """

    def __new__(metacls, name, bases, classdict):
        # handle dynamic workflow conditions
        condition_attr = metacls.check_dynamic_workflow_conditions(name, classdict)
        if condition_attr:
            # store the attribute when found and disable the branch map caching by default,
            # unless the class body configures cache_branch_map_default itself
            classdict["_condition_attr"] = condition_attr
            classdict.setdefault("cache_branch_map_default", False)

        # store a flag on the created class whether it defined a new workflow_proxy_cls
        # this flag will define the classes in the mro to consider for instantiating the proxy
        # note: this must go through classdict since the class object does not exist yet
        classdict["_defined_workflow_proxy"] = "workflow_proxy_cls" in classdict

        # create and return the class
        return super(WorkflowRegister, metacls).__new__(metacls, name, bases, classdict)

    @classmethod
    def check_dynamic_workflow_conditions(metacls, name, classdict):
        """
        Returns the name of the attribute in *classdict* whose value is a
        :py:class:`dynamic_workflow_condition` instance, or *None* if there is none. An exception
        is raised if more than one such attribute is found. *name* is the name of the class being
        created and is only used in the error message.
        """
        # collect all condition attributes in definition order of classdict
        condition_attrs = [
            attr
            for attr, value in classdict.items()
            if isinstance(value, dynamic_workflow_condition)
        ]

        # check that only one condition is present in classdict
        if len(condition_attrs) > 1:
            raise Exception(
                "class '{}' defined with more than one dynamic_workflow_condition, found "
                "'{}' after previously registered '{}'".format(
                    name, condition_attrs[1], condition_attrs[0],
                ),
            )

        return condition_attrs[0] if condition_attrs else None


class BaseWorkflow(six.with_metaclass(WorkflowRegister, Task)):
Expand Down Expand Up @@ -407,6 +616,9 @@ class BaseWorkflow(six.with_metaclass(WorkflowRegister, Task)):
"'start:end' (end not included as per Python) to support range syntax; default: empty",
)

# caches
_cls_branch_map_cache = {}

# configuration members
workflow_proxy_cls = BaseWorkflowProxy
output_collection_cls = None
Expand All @@ -418,9 +630,6 @@ class BaseWorkflow(six.with_metaclass(WorkflowRegister, Task)):
passthrough_requested_workflow = True
workflow_run_decorators = None

# caches
_cls_branch_map_cache = {}

# skip from indexing
exclude_index = True

Expand All @@ -431,6 +640,26 @@ class BaseWorkflow(six.with_metaclass(WorkflowRegister, Task)):
exclude_params_branch = {"acceptance", "tolerance", "pilot", "branches"}
exclude_params_workflow = {"branch"}

def __new__(cls, *args, **kwargs):
    """
    Creates the instance and, if the class registered a dynamic workflow condition under
    ``_condition_attr``, binds the wrapped condition function and all guarded wrappers to the
    new instance. Only attributes that still hold the placeholder sentinel of
    :py:class:`dynamic_workflow_condition` are replaced; everything else is left untouched.
    """
    inst = super(BaseWorkflow, cls).__new__(cls)

    # nothing to bind when no condition attribute is registered on the class
    condition_attr = getattr(cls, "_condition_attr", None)
    if not condition_attr:
        return inst

    # nothing to bind either when the attribute does not hold a condition object (anymore)
    condition = getattr(inst, condition_attr, None)
    if not isinstance(condition, dynamic_workflow_condition):
        return inst

    # bind the condition method itself
    bound_condition_fn = condition._wrap_condition_fn().__get__(inst)
    setattr(inst, condition_attr, bound_condition_fn)

    # bind wrapped methods that currently correspond to placeholders
    # the placeholder is a unique sentinel object, so compare by identity rather than equality
    placeholder = dynamic_workflow_condition._decorator_result
    for attr, wrapper in condition._iter_wrappers(bound_condition_fn):
        if getattr(inst, attr, None) is placeholder:
            setattr(inst, attr, wrapper.__get__(inst))

    return inst

@classmethod
def modify_param_values(cls, params):
params = super(BaseWorkflow, cls).modify_param_values(params)
Expand Down

0 comments on commit 5173f91

Please sign in to comment.