Accelerate local workflow completion checks.
riga committed Dec 17, 2024
1 parent f61f2ae commit cf41ed6
Showing 6 changed files with 127 additions and 18 deletions.
4 changes: 4 additions & 0 deletions bin/githooks/post-commit
@@ -4,6 +4,10 @@
# In case a variable LAW_GITHOOKS_SKIP is set, the hook is skipped.

action() {
# original lfs post-commit hook
command -v git-lfs >/dev/null 2>&1 || { printf >&2 "\n%s\n\n" "This repository is configured for Git LFS but 'git-lfs' was not found on your path. If you no longer wish to use Git LFS, remove this hook by deleting the 'post-commit' file in the hooks directory (set by 'core.hookspath'; usually '.git/hooks')."; exit 2; }
git lfs post-commit "$@" || return "$?"

[ ! -z "${LAW_GITHOOKS_SKIP}" ] && return "0"

local shell_is_zsh="$( [ -z "${ZSH_VERSION}" ] && echo "false" || echo "true" )"
6 changes: 3 additions & 3 deletions law/logger.py
@@ -230,7 +230,7 @@ def setup_logging() -> None:

# setup the main law logger first and set its handler which is propagated to subloggers
logger = get_logger("law", skip_setup=True)
logger = setup_logger(logger, add_console_handler=False)
logger = setup_logger(logger, add_console_handler=False) # type: ignore[assignment]
logger.addHandler(create_stream_handler())

# set levels for all loggers and add the console handler for all non-law loggers
Expand All @@ -246,7 +246,7 @@ def _logger_setup(logger: logging.Logger, value: bool | None = None) -> bool:
return getattr(logger, attr, False)


def get_logger(*args, skip_setup: bool = False, **kwargs) -> logging.Logger:
def get_logger(*args, skip_setup: bool = False, **kwargs) -> Logger:
"""
Replacement for *logging.getLogger* that makes sure that the custom :py:class:`Logger` class is
used when new loggers are created and that the logger is properly set up by
@@ -261,7 +261,7 @@ def get_logger(*args, skip_setup: bool = False, **kwargs) -> logging.Logger:
if not skip_setup:
setup_logger(logger)

return logger
return logger # type: ignore[return-value]
finally:
logging.setLoggerClass(orig_cls)

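For context, get_logger is a drop-in replacement for logging.getLogger that swaps in law's custom Logger class and sets it up; a minimal usage sketch (the warning_once call relies on that custom class, as used further below in law/workflow/local.py):

from law.logger import get_logger

# returns an instance of law's custom Logger class, already set up with
# law's stream handler and configured level
logger = get_logger(__name__)

# standard logging calls work as usual
logger.debug("setup done")

# law's Logger de-duplicates warnings by key, so this message is emitted only once
logger.warning_once("some_key", "this warning is shown a single time")
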
45 changes: 45 additions & 0 deletions law/patches.py
@@ -12,10 +12,12 @@
import re
import functools
import copy
import multiprocessing
import logging

import luigi # type: ignore[import-untyped]
import law
from law.task.base import BaseTask
from law.logger import get_logger
from law._types import Callable

@@ -63,6 +65,7 @@ def patch_all() -> None:
patch_interface_logging()
patch_parameter_copy()
patch_parameter_parse_or_no_value()
patch_worker_check_complete_cached()

logger.debug("applied all law-specific luigi patches")

@@ -512,3 +515,45 @@ def _parse_or_no_value(self, x):
luigi.parameter.Parameter._parse_or_no_value = _parse_or_no_value

logger.debug("patched luigi.parameter.Parameter._parse_or_no_value")


def patch_worker_check_complete_cached() -> None:
"""
Patches the ``luigi.worker.check_complete_cached`` function to treat cached task completeness
decisions slightly differently. The original implementation skips the completeness check and
uses the cached value only if a task was already marked as complete; missing and *False*
entries are both ignored and the completeness check is performed. With this patch, *False*
entries also cause the check to be skipped and the task to be considered incomplete. The cache
entry is then removed so that subsequent checks are performed as usual.
"""
def check_complete_cached(
task: BaseTask,
completion_cache: multiprocessing.managers.DictProxy | None = None,
) -> bool:
# no caching behavior when no cache is given
if completion_cache is None:
return task.complete()

# get the cached state
cache_key = task.task_id
complete = completion_cache.get(cache_key)

# stop when already complete
if complete:
return True

# consider as incomplete when the cache entry is falsy, yet not None
if not complete and complete is not None:
completion_cache.pop(cache_key, None)
return False

# check the status and tell the cache when complete
complete = task.complete()
if complete:
completion_cache[cache_key] = complete

return complete

luigi.worker.check_complete_cached = check_complete_cached

logger.debug("patched luigi.worker.check_complete_cached")
26 changes: 25 additions & 1 deletion law/workflow/base.py
@@ -77,6 +77,9 @@ def __init__(self, *args, **kwargs) -> None:

self._workflow_has_reset_branch_map = False

# cached outputs
self._cached_output: dict[str, Any] | None = None

def _get_task_attribute(self, name: str | Sequence[str], *, fallback: bool = False) -> Any:
"""
Return an attribute of the actual task named ``<workflow_type>_<name>``. When the attribute
@@ -126,7 +129,7 @@ def requires(self) -> Any:
reqs.update(workflow_reqs)
return reqs

def output(self) -> Any:
def output(self) -> dict[str, Any]:
"""
Returns the default workflow outputs in an ordered dictionary. At the moment this is just
the collection of outputs of the branch tasks, stored with the key ``"collection"``.
@@ -145,6 +148,27 @@ def output(self) -> Any:
collection = cls(targets, threshold=self.threshold(len(targets)))
return DotDict([("collection", collection)])

def get_cached_output(self, update: bool = False) -> dict[str, Any]:
"""
Returns the previously computed output if already cached, and otherwise computes it via
:py:meth:`output`, caching it for subsequent calls if :py:attr:`cache_branch_map` of the
task is *True*.
"""
# invalidate cache
if update:
self._cached_output = None

# return from cache if present
if self._cached_output is not None:
return self._cached_output

# get output and cache it if possible
output = self.output()
if self.task.cache_branch_map: # type: ignore[attr-defined]
self._cached_output = output

return output

def threshold(self, n: int | None = None) -> float | int:
"""
Returns the threshold number of tasks that need to be complete in order to consider the
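A short sketch of the caching behavior from the workflow proxy's perspective (MyWorkflow and the workflow_proxy access are illustrative assumptions; the memoization only applies when the task's cache_branch_map attribute is True):

import law


class MyWorkflow(law.LocalWorkflow):
    # hypothetical minimal workflow with a trivial branch map
    def create_branch_map(self):
        return {i: i for i in range(3)}

    def output(self):
        return law.LocalFileTarget(f"data_{self.branch}.json")


proxy = MyWorkflow().workflow_proxy

out1 = proxy.get_cached_output()       # computes output() and caches the dict
out2 = proxy.get_cached_output()       # served from self._cached_output
assert out1 is out2                    # same object, output() not evaluated again

proxy.get_cached_output(update=True)   # invalidates the cache and re-computes
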
50 changes: 47 additions & 3 deletions law/workflow/local.py
@@ -12,9 +12,15 @@

import luigi # type: ignore[import-untyped]

from law.task.base import BaseTask
from law.workflow.base import BaseWorkflow, BaseWorkflowProxy
from law.target.collection import SiblingFileCollectionBase
from law.logger import get_logger
from law.util import DotDict
from law._types import Any, Iterator
from law._types import Any, Iterator, Callable


logger = get_logger(__name__)


class LocalWorkflowProxy(BaseWorkflowProxy):
@@ -61,10 +67,48 @@ def run(self) -> None | Iterator[Any]:
self._local_workflow_has_yielded = True

# use branch tasks as requirements
reqs = list(task.get_branch_tasks().values())
branch_tasks = task.get_branch_tasks()
reqs = list(branch_tasks.values())

# helper to get the output collection
get_col = lambda: self.get_cached_output().get("collection")

# in case the workflow creates a sibling file collection, per-branch completion
# checks are possible in advance and can be stored in luigi's completion cache
def custom_complete(complete_fn: Callable[[BaseTask], bool]) -> bool:
# get the cache (stored as a specified keyword of a partial'ed function)
cache = getattr(complete_fn, "keywords", {}).get("completion_cache")
if cache is None:
if complete_fn(self):
return True
# show a warning for large workflows that use sibling file collections and
# that could profit from the cache_task_completion feature
if len(reqs) >= 100 and isinstance(get_col(), SiblingFileCollectionBase):
url = "https://luigi.readthedocs.io/en/stable/configuration.html#worker"
logger.warning_once(
"cache_task_completion_hint",
"detected SiblingFileCollection for LocalWorkflow with {} branches "
"whose completness checks will be performed manually by luigi; "
"consider enabling luigi's cache_task_completion feature to speed "
"up these checks; fore more info, see {}".format(len(reqs), url),
)
return False

# the output collection must be a sibling file collection
col = get_col()
if not isinstance(col, SiblingFileCollectionBase):
return complete_fn(self)

# get existing branches and populate the cache with completeness states
existing_branches = set(col.count(keys=True)[1]) # type: ignore[index]
for b, task in branch_tasks.items():
cache[task.task_id] = b in existing_branches

# finally, evaluate the normal completeness check on the workflow
return complete_fn(self)

# wrap into DynamicRequirements
yield luigi.DynamicRequirements(reqs, lambda complete_fn: complete_fn(self))
yield luigi.DynamicRequirements(reqs, custom_complete)

return None

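The warning above points to luigi's worker-level completion cache; a hedged sketch of how the pieces fit together (MyWorkflow is a hypothetical local workflow like the one sketched earlier, and the programmatic configuration call is an assumption; the canonical way is the [worker] section of luigi.cfg referenced by the linked docs):

# luigi.cfg equivalent:
#   [worker]
#   cache_task_completion = true
import luigi
from luigi import configuration

from law.patches import patch_all

# make sure law's check_complete_cached patch (see law/patches.py above) is applied
patch_all()

# enable luigi's completion cache; with it enabled, luigi passes DynamicRequirements a
# functools.partial of check_complete_cached carrying the completion_cache keyword,
# which custom_complete() reads via complete_fn.keywords and pre-populates with
# per-branch states derived from the sibling file collection
cfg = configuration.get_config()
if not cfg.has_section("worker"):
    cfg.add_section("worker")
cfg.set("worker", "cache_task_completion", "True")

luigi.build([MyWorkflow()], local_scheduler=True)
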
14 changes: 3 additions & 11 deletions law/workflow/remote.py
@@ -233,9 +233,6 @@ def __init__(self, *args, **kwargs) -> None:
# retry counts per job num
self._job_retries: dict[int, int] = defaultdict(int)

# cached output() return value
self._cached_output: dict | None = None

# flag that denotes whether a submission was done before, set in run()
self._submitted = False

@@ -394,11 +391,6 @@ def _cleanup_jobs(self) -> bool:
task: BaseRemoteWorkflow = self.task # type: ignore[assignment]
return isinstance(getattr(task, "cleanup_jobs", None), bool) and task.cleanup_jobs # type: ignore[return-value] # noqa

def _get_cached_output(self) -> dict:
if self._cached_output is None:
self._cached_output = self.output()
return self._cached_output

def _get_existing_branches(
self,
sync: bool = False,
@@ -412,7 +404,7 @@ def _get_existing_branches(
self._existing_branches = set()
# add initial branches existing in output collection
if collection is None:
collection = self._get_cached_output().get("collection")
collection = self.get_cached_output().get("collection")
if collection is not None:
keys = collection.count(existing=True, keys=True)[1] # type: ignore[index]
self._existing_branches |= set(keys)
@@ -691,7 +683,7 @@ def dump_job_data(self) -> None:
self.job_data["dashboard_config"] = self.dashboard.get_persistent_config()

# write the job data to the output file
output = self._get_cached_output()
output = self.get_cached_output()
if output is not None:
with self._dump_lock:
output["jobs"].dump(self.job_data, formatter="json", indent=4)
Expand All @@ -712,7 +704,7 @@ def _run_impl(self) -> None:
"""
task: BaseRemoteWorkflow = self.task # type: ignore[assignment]

output = self._get_cached_output()
output = self.get_cached_output()
if not isinstance(output, dict):
raise TypeError(f"workflow output must be a dict, got '{output}'")

