From cf41ed68ea70110b6f6353c5512162d2e3faa6cc Mon Sep 17 00:00:00 2001
From: "Marcel R."
Date: Tue, 17 Dec 2024 12:46:48 +0100
Subject: [PATCH] Accelerate local workflow completion checks.

---
 bin/githooks/post-commit |  4 ++++
 law/logger.py            |  6 ++---
 law/patches.py           | 45 ++++++++++++++++++++++++++++++++++++
 law/workflow/base.py     | 26 ++++++++++++++++++++-
 law/workflow/local.py    | 50 +++++++++++++++++++++++++++++++++++++---
 law/workflow/remote.py   | 14 +++--------
 6 files changed, 127 insertions(+), 18 deletions(-)

diff --git a/bin/githooks/post-commit b/bin/githooks/post-commit
index f1973375..5f202161 100755
--- a/bin/githooks/post-commit
+++ b/bin/githooks/post-commit
@@ -4,6 +4,10 @@
 # In case a variable LAW_GITHOOKS_SKIP is set, the hook is skipped.
 
 action() {
+    # original lfs post-commit hook
+    command -v git-lfs >/dev/null 2>&1 || { printf >&2 "\n%s\n\n" "This repository is configured for Git LFS but 'git-lfs' was not found on your path. If you no longer wish to use Git LFS, remove this hook by deleting the 'post-commit' file in the hooks directory (set by 'core.hookspath'; usually '.git/hooks')."; exit 2; }
+    git lfs post-commit "$@" || return "$?"
+
     [ ! -z "${LAW_GITHOOKS_SKIP}" ] && return "0"
 
     local shell_is_zsh="$( [ -z "${ZSH_VERSION}" ] && echo "false" || echo "true" )"
diff --git a/law/logger.py b/law/logger.py
index e82e7715..3fa41045 100644
--- a/law/logger.py
+++ b/law/logger.py
@@ -230,7 +230,7 @@ def setup_logging() -> None:
 
     # setup the main law logger first and set its handler which is propagated to subloggers
     logger = get_logger("law", skip_setup=True)
-    logger = setup_logger(logger, add_console_handler=False)
+    logger = setup_logger(logger, add_console_handler=False)  # type: ignore[assignment]
     logger.addHandler(create_stream_handler())
 
     # set levels for all loggers and add the console handler for all non-law loggers
@@ -246,7 +246,7 @@ def _logger_setup(logger: logging.Logger, value: bool | None = None) -> bool:
     return getattr(logger, attr, False)
 
 
-def get_logger(*args, skip_setup: bool = False, **kwargs) -> logging.Logger:
+def get_logger(*args, skip_setup: bool = False, **kwargs) -> Logger:
     """
     Replacement for *logging.getLogger* that makes sure that the custom :py:class:`Logger` class
     is used when new loggers are created and that the logger is properly set up by
@@ -261,7 +261,7 @@ def get_logger(*args, skip_setup: bool = False, **kwargs) -> logging.Logger:
         if not skip_setup:
             setup_logger(logger)
 
-        return logger
+        return logger  # type: ignore[return-value]
     finally:
         logging.setLoggerClass(orig_cls)
diff --git a/law/patches.py b/law/patches.py
index 3e4ddf1d..ea42720d 100644
--- a/law/patches.py
+++ b/law/patches.py
@@ -12,10 +12,12 @@
 import re
 import functools
 import copy
+import multiprocessing
 import logging
 
 import luigi  # type: ignore[import-untyped]
 
 import law
+from law.task.base import BaseTask
 from law.logger import get_logger
 from law._types import Callable
@@ -63,6 +65,7 @@ def patch_all() -> None:
     patch_interface_logging()
     patch_parameter_copy()
     patch_parameter_parse_or_no_value()
+    patch_worker_check_complete_cached()
 
     logger.debug("applied all law-specific luigi patches")
 
@@ -512,3 +515,45 @@ def _parse_or_no_value(self, x):
     luigi.parameter.Parameter._parse_or_no_value = _parse_or_no_value
 
     logger.debug("patched luigi.parameter.Parameter._parse_or_no_value")
+
+
+def patch_worker_check_complete_cached() -> None:
+    """
+    Patches the ``luigi.worker.check_complete_cached`` function to treat cached task completeness
+    decisions slightly differently. The original implementation only skips the completeness check
+    and uses the cached value if, and only if, a task was already marked as complete; missing and
+    *False* entries are both ignored and the check is performed anyway. Now, *False* entries also
+    cause the check to be skipped, considering the task as incomplete. However, the cache entry is
+    then removed so that subsequent checks are performed as usual.
+    """
+    def check_complete_cached(
+        task: BaseTask,
+        completion_cache: multiprocessing.managers.DictProxy | None = None,
+    ) -> bool:
+        # no caching behavior when no cache is given
+        if completion_cache is None:
+            return task.complete()
+
+        # get the cached state
+        cache_key = task.task_id
+        complete = completion_cache.get(cache_key)
+
+        # stop when already complete
+        if complete:
+            return True
+
+        # consider as incomplete when the cache entry is falsy, yet not None
+        if not complete and complete is not None:
+            completion_cache.pop(cache_key, None)
+            return False
+
+        # check the status and tell the cache when complete
+        complete = task.complete()
+        if complete:
+            completion_cache[cache_key] = complete
+
+        return complete
+
+    luigi.worker.check_complete_cached = check_complete_cached
+
+    logger.debug("patched luigi.worker.check_complete_cached")
diff --git a/law/workflow/base.py b/law/workflow/base.py
index 15ba9ce8..2ac4b2d0 100644
--- a/law/workflow/base.py
+++ b/law/workflow/base.py
@@ -77,6 +77,9 @@ def __init__(self, *args, **kwargs) -> None:
 
         self._workflow_has_reset_branch_map = False
 
+        # cached outputs
+        self._cached_output: dict[str, Any] | None = None
+
     def _get_task_attribute(self, name: str | Sequence[str], *, fallback: bool = False) -> Any:
         """
         Return an attribute of the actual task named ``_``. When the attribute
@@ -126,7 +129,7 @@ def requires(self) -> Any:
             reqs.update(workflow_reqs)
         return reqs
 
-    def output(self) -> Any:
+    def output(self) -> dict[str, Any]:
         """
         Returns the default workflow outputs in an ordered dictionary. At the moment this is just
         the collection of outputs of the branch tasks, stored with the key ``"collection"``.
@@ -145,6 +148,27 @@ def output(self) -> Any:
         collection = cls(targets, threshold=self.threshold(len(targets)))
 
         return DotDict([("collection", collection)])
 
+    def get_cached_output(self, update: bool = False) -> dict[str, Any]:
+        """
+        Returns the previously computed output if already cached, and otherwise computes it via
+        :py:meth:`output`, caching it for subsequent calls if :py:attr:`cache_branch_map` of the
+        task is *True*.
+        """
+        # invalidate cache
+        if update:
+            self._cached_output = None
+
+        # return from cache if present
+        if self._cached_output is not None:
+            return self._cached_output
+
+        # get output and cache it if possible
+        output = self.output()
+        if self.task.cache_branch_map:  # type: ignore[attr-defined]
+            self._cached_output = output
+
+        return output
+
     def threshold(self, n: int | None = None) -> float | int:
         """
         Returns the threshold number of tasks that need to be complete in order to consider the
diff --git a/law/workflow/local.py b/law/workflow/local.py
index cc456e3c..c27ce981 100644
--- a/law/workflow/local.py
+++ b/law/workflow/local.py
@@ -12,9 +12,15 @@
 
 import luigi  # type: ignore[import-untyped]
 
+from law.task.base import BaseTask
 from law.workflow.base import BaseWorkflow, BaseWorkflowProxy
+from law.target.collection import SiblingFileCollectionBase
+from law.logger import get_logger
 from law.util import DotDict
-from law._types import Any, Iterator
+from law._types import Any, Iterator, Callable
+
+
+logger = get_logger(__name__)
 
 
 class LocalWorkflowProxy(BaseWorkflowProxy):
@@ -61,10 +67,48 @@ def run(self) -> None | Iterator[Any]:
             self._local_workflow_has_yielded = True
 
             # use branch tasks as requirements
-            reqs = list(task.get_branch_tasks().values())
+            branch_tasks = task.get_branch_tasks()
+            reqs = list(branch_tasks.values())
+
+            # helper to get the output collection
+            get_col = lambda: self.get_cached_output().get("collection")
+
+            # in case the workflow creates a sibling file collection, per-branch completion
+            # checks are possible in advance and can be stored in luigi's completion cache
+            def custom_complete(complete_fn: Callable[[BaseTask], bool]) -> bool:
+                # get the cache (stored as a keyword argument of a partial'ed function)
+                cache = getattr(complete_fn, "keywords", {}).get("completion_cache")
+                if cache is None:
+                    if complete_fn(self):
+                        return True
+                    # show a warning for large workflows that use sibling file collections and
+                    # that could profit from the cache_task_completion feature
+                    if len(reqs) >= 100 and isinstance(get_col(), SiblingFileCollectionBase):
+                        url = "https://luigi.readthedocs.io/en/stable/configuration.html#worker"
+                        logger.warning_once(
+                            "cache_task_completion_hint",
+                            "detected SiblingFileCollection for LocalWorkflow with {} branches "
+                            "whose completeness checks will be performed manually by luigi; "
+                            "consider enabling luigi's cache_task_completion feature to speed "
+                            "up these checks; for more info, see {}".format(len(reqs), url),
+                        )
+                    return False
+
+                # the output collection must be a sibling file collection
+                col = get_col()
+                if not isinstance(col, SiblingFileCollectionBase):
+                    return complete_fn(self)
+
+                # get existing branches and populate the cache with completeness states
+                existing_branches = set(col.count(keys=True)[1])  # type: ignore[index]
+                for b, branch_task in branch_tasks.items():
+                    cache[branch_task.task_id] = b in existing_branches
+
+                # finally, evaluate the normal completeness check on the workflow
+                return complete_fn(self)
 
             # wrap into DynamicRequirements
-            yield luigi.DynamicRequirements(reqs, lambda complete_fn: complete_fn(self))
+            yield luigi.DynamicRequirements(reqs, custom_complete)
 
         return None
diff --git a/law/workflow/remote.py b/law/workflow/remote.py
index 6a145aa2..b9bcc7db 100644
--- a/law/workflow/remote.py
+++ b/law/workflow/remote.py
@@ -233,9 +233,6 @@ def __init__(self, *args, **kwargs) -> None:
         # retry counts per job num
         self._job_retries: dict[int, int] = defaultdict(int)
 
-        # cached output() return value
-        self._cached_output: dict | None = None
-
         # flag that denotes whether a submission was done befire, set in run()
         self._submitted = False
 
@@ -394,11 +391,6 @@ def _cleanup_jobs(self) -> bool:
         task: BaseRemoteWorkflow = self.task  # type: ignore[assignment]
         return isinstance(getattr(task, "cleanup_jobs", None), bool) and task.cleanup_jobs  # type: ignore[return-value]  # noqa
 
-    def _get_cached_output(self) -> dict:
-        if self._cached_output is None:
-            self._cached_output = self.output()
-        return self._cached_output
-
     def _get_existing_branches(
         self,
         sync: bool = False,
@@ -412,7 +404,7 @@ def _get_existing_branches(
             self._existing_branches = set()
             # add initial branches existing in output collection
            if collection is None:
-                collection = self._get_cached_output().get("collection")
+                collection = self.get_cached_output().get("collection")
             if collection is not None:
                 keys = collection.count(existing=True, keys=True)[1]  # type: ignore[index]
                 self._existing_branches |= set(keys)
@@ -691,7 +683,7 @@ def dump_job_data(self) -> None:
         self.job_data["dashboard_config"] = self.dashboard.get_persistent_config()
 
         # write the job data to the output file
-        output = self._get_cached_output()
+        output = self.get_cached_output()
         if output is not None:
             with self._dump_lock:
                 output["jobs"].dump(self.job_data, formatter="json", indent=4)
@@ -712,7 +704,7 @@ def _run_impl(self) -> None:
         """
         task: BaseRemoteWorkflow = self.task  # type: ignore[assignment]
 
-        output = self._get_cached_output()
+        output = self.get_cached_output()
         if not isinstance(output, dict):
             raise TypeError(f"workflow output must be a dict, got '{output}'")
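
Note on the patched caching semantics: the sketch below is not part of the diff above and only illustrates how the patched ``check_complete_cached`` treats pre-populated *False* entries. ``MyTask`` and the plain dict are hypothetical stand-ins for a real task and for luigi's manager-backed completion cache.

```python
import luigi
import law
import law.patches

# ensure law's luigi patches are applied (law usually does this on its own)
law.patches.patch_all()


class MyTask(luigi.Task):
    """Hypothetical task, used only for illustration."""

    def complete(self):
        return False


task = MyTask()

# plain dict standing in for luigi's manager-backed completion cache,
# pre-populated with a False entry as LocalWorkflowProxy does for missing branch outputs
cache = {task.task_id: False}

# the patched function returns False without calling task.complete() and evicts the entry,
# so any later check falls back to the real completeness check again
print(luigi.worker.check_complete_cached(task, cache))  # False
print(task.task_id in cache)  # False

# already-complete entries are still short-circuited, as in the original implementation
cache[task.task_id] = True
print(luigi.worker.check_complete_cached(task, cache))  # True
```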
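
For context, the new ``custom_complete`` path in ``law/workflow/local.py`` only kicks in for local workflows whose branch outputs form a sibling file collection, and the pre-filled per-branch states are only consulted when luigi's ``cache_task_completion`` worker option (see the URL used in the warning above) is enabled, e.g. via ``cache_task_completion = true`` in the ``[worker]`` section of ``luigi.cfg``. A minimal, hypothetical workflow of that kind could look as follows; ``MyWorkflow`` and the ``data/`` directory are illustrative and not part of this patch.

```python
import law


class MyWorkflow(law.LocalWorkflow):
    """Hypothetical local workflow with many branches, used only for illustration."""

    def create_branch_map(self):
        # 1000 branches whose outputs all live in the same directory, so the workflow
        # output can be wrapped into a sibling file collection
        return {b: b for b in range(1000)}

    def output(self):
        # per-branch output target
        return law.LocalFileTarget(f"data/branch_{self.branch}.json")

    def run(self):
        self.output().dump({"branch": self.branch}, formatter="json")
```

With the completion cache pre-filled from a single listing of that directory, luigi can mark already-finished branches as complete without checking each output target again, which is where the speed-up for large local workflows comes from.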