diff --git a/law/_types.py b/law/_types.py index 1a88c2d0..f82f6a64 100644 --- a/law/_types.py +++ b/law/_types.py @@ -29,7 +29,6 @@ if sys.version_info[:2] >= (3, 9): from types import GenericAlias # noqa - else: GenericAlias = str diff --git a/law/cli/index.py b/law/cli/index.py index e157ccfc..84b2384c 100644 --- a/law/cli/index.py +++ b/law/cli/index.py @@ -15,10 +15,10 @@ import luigi # type: ignore[import-untyped] from law.config import Config -from law.task.base import Register, Task, ExternalTask +from law.task.base import Task, ExternalTask from law.util import multi_match, colored, abort, makedirs, brace_expand from law.logger import get_logger -from law._types import Sequence +from law._types import Sequence, Type logger = get_logger(__name__) @@ -151,9 +151,9 @@ def execute(args: argparse.Namespace) -> int: # determine tasks to write into the index file seen_families = [] task_classes = [] - lookup: list[Register] = [Task] + lookup: list[Type[Task]] = [Task] while lookup: - cls: Register = lookup.pop(0) # type: ignore + cls: Type[Task] = lookup.pop(0) # type: ignore lookup.extend(cls.__subclasses__()) # skip tasks in __main__ module in interactive sessions @@ -214,7 +214,7 @@ def execute(args: argparse.Namespace) -> int: task_classes.append(cls) - def get_task_params(cls: Register) -> list[str]: + def get_task_params(cls) -> list[str]: params = [] for attr in dir(cls): member = getattr(cls, attr) diff --git a/law/contrib/arc/__init__.py b/law/contrib/arc/__init__.py index c5dcf5f9..6bcd73db 100644 --- a/law/contrib/arc/__init__.py +++ b/law/contrib/arc/__init__.py @@ -12,6 +12,9 @@ "ensure_arcproxy", ] +# dependencies to other contrib modules +import law +law.contrib.load("wlcg") # provisioning imports from law.contrib.arc.util import ( diff --git a/law/contrib/arc/decorator.py b/law/contrib/arc/decorator.py index af1a825c..f7724bf7 100644 --- a/law/contrib/arc/decorator.py +++ b/law/contrib/arc/decorator.py @@ -4,31 +4,41 @@ Decorators for task methods for convenient working with ARC. """ -__all__ = ["ensure_arcproxy"] +from __future__ import annotations +__all__ = ["ensure_arcproxy"] +from law.task.base import Task from law.decorator import factory +from law._types import Any, Callable + from law.contrib.arc import check_arcproxy_validity @factory(accept_generator=True) -def ensure_arcproxy(fn, opts, task, *args, **kwargs): +def ensure_arcproxy( + fn: Callable, + opts: dict[str, Any], + task: Task, + *args, + **kwargs, +) -> tuple[Callable, Callable, Callable]: """ Decorator for law task methods that checks the validity of the arc proxy and throws an exception in case it is invalid. This can prevent late errors on remote worker notes that except arc proxies to be present. Accepts generator functions. """ - def before_call(): + def before_call() -> None: # check the proxy validity if not check_arcproxy_validity(): raise Exception("arc proxy not valid") return None - def call(state): + def call(state: None) -> Any: return fn(task, *args, **kwargs) - def after_call(state): - return + def after_call(state: None) -> None: + return None return before_call, call, after_call diff --git a/law/contrib/arc/job.py b/law/contrib/arc/job.py index f3ca5d20..23d4c4eb 100644 --- a/law/contrib/arc/job.py +++ b/law/contrib/arc/job.py @@ -5,21 +5,24 @@ http://www.nordugrid.org/documents/xrsl.pdf. 
""" -__all__ = ["ARCJobManager", "ARCJobFileFactory"] +from __future__ import annotations +__all__ = ["ARCJobManager", "ARCJobFileFactory"] import os import stat import time import re import random +import pathlib import subprocess from law.config import Config -from law.job.base import BaseJobManager, BaseJobFileFactory, JobInputFile, DeprecatedInputFiles +from law.job.base import BaseJobManager, BaseJobFileFactory, JobInputFile from law.target.file import get_path from law.util import interruptable_popen, make_list, make_unique, quote_cmd from law.logger import get_logger +from law._types import Any, Sequence logger = get_logger(__name__) @@ -39,16 +42,30 @@ class ARCJobManager(BaseJobManager): status_block_cre = re.compile(r"\s*([^:]+): (.*)\n") status_invalid_job_cre = re.compile("^.+: Job not found in job list: (.+)$") status_missing_job_cre = re.compile( - "^.+: Job information not found in the information system: (.+)$") + "^.+: Job information not found in the information system: (.+)$", + ) - def __init__(self, job_list=None, ce=None, threads=1): - super(ARCJobManager, self).__init__() + def __init__( + self, + job_list: str | None = None, + ce: str | None = None, + threads: int = 1, + ) -> None: + super().__init__() self.job_list = job_list self.ce = ce self.threads = threads - def submit(self, job_file, job_list=None, ce=None, retries=0, retry_delay=3, silent=False): + def submit( # type: ignore[override] + self, + job_file: str | pathlib.Path | Sequence[str | pathlib.Path], + job_list: str | None = None, + ce: str | None = None, + retries: int = 0, + retry_delay: float | int = 3, + silent: bool = False, + ) -> str | list[str] | None: # default arguments if job_list is None: job_list = self.job_list @@ -58,7 +75,7 @@ def submit(self, job_file, job_list=None, ce=None, retries=0, retry_delay=3, sil # check arguments if not ce: raise ValueError("ce must not be empty") - ce = make_list(ce) + _ce = make_list(ce) # arc supports multiple jobs to be submitted with a single arcsub call, # so job_file can be a sequence of files @@ -72,16 +89,22 @@ def submit(self, job_file, job_list=None, ce=None, retries=0, retry_delay=3, sil # define the actual submission in a loop to simplify retries while True: # build the command - cmd = ["arcsub", "-c", random.choice(ce)] + cmd = ["arcsub", "-c", random.choice(_ce)] if job_list: cmd += ["-j", job_list] cmd += job_file_names - cmd = quote_cmd(cmd) + cmd_str = quote_cmd(cmd) # run the command - logger.debug("submit arc job(s) with command '{}'".format(cmd)) - code, out, _ = interruptable_popen(cmd, shell=True, executable="/bin/bash", - stdout=subprocess.PIPE, cwd=job_file_dir) + logger.debug(f"submit arc job(s) with command '{cmd_str}'") + out: str + code, out, _ = interruptable_popen( # type: ignore[assignment] + cmd_str, + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + cwd=job_file_dir, + ) # in some cases, the return code is 0 but the ce did not respond valid job ids job_ids = [] @@ -94,18 +117,18 @@ def submit(self, job_file, job_list=None, ce=None, retries=0, retry_delay=3, sil if not job_ids: code = 1 - out = "cannot find job id(s) in output:\n{}".format(out) + out = f"cannot find job id(s) in output:\n{out}" elif len(job_ids) != len(job_files): - raise Exception("number of job ids in output ({}) does not match number of " - "jobs to submit ({}) in output:\n{}".format(len(job_ids), len(job_files), - out)) + raise Exception( + f"number of job ids in output ({len(job_ids)}) does not match number of " + f"jobs to submit 
({len(job_files)}) in output:\n{out}", + ) # retry or done? if code == 0: return job_ids if chunking else job_ids[0] - logger.debug("submission of arc job(s) '{}' failed with code {}:\n{}".format( - job_files, code, out)) + logger.debug(f"submission of arc job(s) '{job_files}' failed with code {code}:\n{out}") if retries > 0: retries -= 1 @@ -115,9 +138,14 @@ def submit(self, job_file, job_list=None, ce=None, retries=0, retry_delay=3, sil if silent: return None - raise Exception("submission of arc job(s) '{}' failed:\n{}".format(job_files, out)) + raise Exception(f"submission of arc job(s) '{job_files}' failed:\n{out}") - def cancel(self, job_id, job_list=None, silent=False): + def cancel( # type: ignore[override] + self, + job_id: str | Sequence[str], + job_list: str | None = None, + silent: bool = False, + ) -> dict[str, None] | None: # default arguments if job_list is None: job_list = self.job_list @@ -130,22 +158,31 @@ def cancel(self, job_id, job_list=None, silent=False): if job_list: cmd += ["-j", job_list] cmd += job_ids - cmd = quote_cmd(cmd) + cmd_str = quote_cmd(cmd) # run it - logger.debug("cancel arc job(s) with command '{}'".format(cmd)) - code, out, _ = interruptable_popen(cmd, shell=True, executable="/bin/bash", - stdout=subprocess.PIPE) + logger.debug(f"cancel arc job(s) with command '{cmd_str}'") + out: str + code, out, _ = interruptable_popen( # type: ignore[assignment] + cmd_str, + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + ) # check success if code != 0 and not silent: # arc prints everything to stdout - raise Exception("cancellation of arc job(s) '{}' failed with code {}:\n{}".format( - job_id, code, out)) + raise Exception(f"cancellation of arc job(s) '{job_id}' failed with code {code}:\n{out}") return {job_id: None for job_id in job_ids} if chunking else None - def cleanup(self, job_id, job_list=None, silent=False): + def cleanup( # type: ignore[override] + self, + job_id: str | Sequence[str], + job_list: str | None = None, + silent: bool = False, + ) -> dict[str, None] | None: # default arguments if job_list is None: job_list = self.job_list @@ -158,22 +195,31 @@ def cleanup(self, job_id, job_list=None, silent=False): if job_list: cmd += ["-j", job_list] cmd += job_ids - cmd = quote_cmd(cmd) + cmd_str = quote_cmd(cmd) # run it - logger.debug("cleanup arc job(s) with command '{}'".format(cmd)) - code, out, _ = interruptable_popen(cmd, shell=True, executable="/bin/bash", - stdout=subprocess.PIPE) + logger.debug(f"cleanup arc job(s) with command '{cmd_str}'") + out: str + code, out, _ = interruptable_popen( # type: ignore[assignment] + cmd_str, + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + ) # check success if code != 0 and not silent: # arc prints everything to stdout - raise Exception("cleanup of arc job(s) '{}' failed with code {}:\n{}".format( - job_id, code, out)) + raise Exception(f"cleanup of arc job(s) '{job_id}' failed with code {code}:\n{out}") return {job_id: None for job_id in job_ids} if chunking else None - def query(self, job_id, job_list=None, silent=False): + def query( # type: ignore[override] + self, + job_id: str | Sequence[str], + job_list: str | None = None, + silent: bool = False, + ) -> dict[int, dict[str, Any]] | dict[str, Any] | None: # default arguments if job_list is None: job_list = self.job_list @@ -186,21 +232,25 @@ def query(self, job_id, job_list=None, silent=False): if job_list: cmd += ["-j", job_list] cmd += job_ids - cmd = quote_cmd(cmd) + cmd_str = quote_cmd(cmd) # run it - 
logger.debug("query arc job(s) with command '{}'".format(cmd)) - code, out, _ = interruptable_popen(cmd, shell=True, executable="/bin/bash", - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + logger.debug(f"query arc job(s) with command '{cmd_str}'") + out: str + code, out, _ = interruptable_popen( # type: ignore[assignment] + cmd_str, + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) # handle errors if code != 0: if silent: return None - else: - # arc prints everything to stdout - raise Exception("status query of arc job(s) '{}' failed with code {}:\n{}".format( - job_id, code, out)) + # arc prints everything to stdout + raise Exception(f"status query of arc job(s) '{job_id}' failed with code {code}:\n{out}") # parse the output and extract the status per job query_data = self.parse_query_output(out) @@ -211,17 +261,18 @@ def query(self, job_id, job_list=None, silent=False): if not chunking: if silent: return None - else: - raise Exception("arc job(s) '{}' not found in query response".format( - job_id)) + raise Exception(f"arc job(s) '{job_id}' not found in query response") else: - query_data[_job_id] = self.job_status_dict(job_id=_job_id, status=self.FAILED, - error="job not found in query response") + query_data[_job_id] = self.job_status_dict( + job_id=_job_id, + status=self.FAILED, + error="job not found in query response", + ) - return query_data if chunking else query_data[job_id] + return query_data if chunking else query_data[job_id] # type: ignore[index] @classmethod - def parse_query_output(cls, out): + def parse_query_output(cls, out: str) -> dict[str, dict[str, Any]]: query_data = {} # first, check for invalid and missing jobs @@ -232,8 +283,12 @@ def parse_query_output(cls, out): m = cls.status_invalid_job_cre.match(line) if m: job_id = m.group(1) - query_data[job_id] = cls.job_status_dict(job_id=job_id, status=cls.FAILED, code=1, - error="job not found") + query_data[job_id] = cls.job_status_dict( + job_id=job_id, + status=cls.FAILED, + code=1, + error="job not found", + ) continue # missing job? 
this means that the job is not yet present in the information system @@ -252,7 +307,7 @@ def parse_query_output(cls, out): blocks = out.split("Job: ", 1)[1].strip().split("\nJob: ") for block in blocks: - data = dict(cls.status_block_cre.findall("Job: {}\n".format(block))) + data = dict(cls.status_block_cre.findall(f"Job: {block}\n")) if not data: continue @@ -262,7 +317,7 @@ def parse_query_output(cls, out): job_id = data["Job"] # interpret data - status = cls.map_status(data.get("State") or None) + status = cls.map_status(data.get("State")) code = data.get("Exit Code") and int(data["Exit Code"]) error = data.get("Job Error") or None @@ -271,25 +326,30 @@ def parse_query_output(cls, out): code = 1 # store it - query_data[job_id] = cls.job_status_dict(job_id=job_id, status=status, code=code, - error=error) + query_data[job_id] = cls.job_status_dict( + job_id=job_id, + status=status, + code=code, + error=error, + ) return query_data @classmethod - def map_status(cls, status): + def map_status(cls, status: str | None) -> str: # see http://www.nordugrid.org/documents/arc-ui.pdf if status in ("Queuing", "Accepted", "Preparing", "Submitting"): return cls.PENDING - elif status in ("Running", "Finishing"): + if status in ("Running", "Finishing"): return cls.RUNNING - elif status in ("Finished",): + if status in ("Finished",): return cls.FINISHED - elif status in ("Failed", "Deleted"): - return cls.FAILED - else: + if status in ("Failed", "Deleted"): return cls.FAILED + logger.debug(f"unknown arc job state '{status}'") + return cls.FAILED + class ARCJobFileFactory(BaseJobFileFactory): @@ -299,29 +359,50 @@ class ARCJobFileFactory(BaseJobFileFactory): "stderr", "custom_content", "absolute_paths", ] - def __init__(self, file_name="arc_job.xrsl", command=None, executable=None, arguments=None, - input_files=None, output_files=None, postfix_output_files=True, output_uri=None, - overwrite_output_files=True, job_name=None, log="log.txt", stdout="stdout.txt", - stderr="stderr.txt", custom_content=None, absolute_paths=True, **kwargs): + def __init__( + self, + file_name: str = "arc_job.xrsl", + command: str | Sequence[str] | None = None, + executable: str | None = None, + arguments: str | Sequence[str] | None = None, + input_files: dict[str, str | pathlib.Path | JobInputFile] | None = None, + output_files: list[str] | None = None, + postfix_output_files: bool = True, + output_uri: str | None = None, + overwrite_output_files: bool = True, + job_name: str | None = None, + log: str = "log.txt", + stdout: str = "stdout.txt", + stderr: str = "stderr.txt", + custom_content: str | Sequence[str] | None = None, + absolute_paths: bool = True, + **kwargs, + ) -> None: # get some default kwargs from the config cfg = Config.instance() if kwargs.get("dir") is None: - kwargs["dir"] = cfg.get_expanded("job", cfg.find_option("job", - "arc_job_file_dir", "job_file_dir")) + kwargs["dir"] = cfg.get_expanded( + "job", + cfg.find_option("job", "arc_job_file_dir", "job_file_dir"), + ) if kwargs.get("mkdtemp") is None: - kwargs["mkdtemp"] = cfg.get_expanded_bool("job", cfg.find_option("job", - "arc_job_file_dir_mkdtemp", "job_file_dir_mkdtemp")) + kwargs["mkdtemp"] = cfg.get_expanded_bool( + "job", + cfg.find_option("job", "arc_job_file_dir_mkdtemp", "job_file_dir_mkdtemp"), + ) if kwargs.get("cleanup") is None: - kwargs["cleanup"] = cfg.get_expanded_bool("job", cfg.find_option("job", - "arc_job_file_dir_cleanup", "job_file_dir_cleanup")) + kwargs["cleanup"] = cfg.get_expanded_bool( + "job", + cfg.find_option("job", 
"arc_job_file_dir_cleanup", "job_file_dir_cleanup"), + ) - super(ARCJobFileFactory, self).__init__(**kwargs) + super().__init__(**kwargs) self.file_name = file_name self.command = command self.executable = executable self.arguments = arguments - self.input_files = DeprecatedInputFiles(input_files or {}) + self.input_files = input_files or {} self.output_files = output_files or [] self.postfix_output_files = postfix_output_files self.output_uri = output_uri @@ -333,7 +414,11 @@ def __init__(self, file_name="arc_job.xrsl", command=None, executable=None, argu self.absolute_paths = absolute_paths self.custom_content = custom_content - def create(self, postfix=None, **kwargs): + def create( + self, + postfix: str | None = None, + **kwargs, + ) -> tuple[str, ARCJobFileFactory.Config]: # merge kwargs and instance attributes c = self.get_config(**kwargs) @@ -363,7 +448,7 @@ def create(self, postfix=None, **kwargs): } # special case: remote input files must never be copied - for f in c.input_files.values: + for f in c.input_files.values(): if f.is_remote: f.copy = False @@ -531,19 +616,18 @@ def prepare_output(path): f.write("&\n") for key, value in content: line = self.create_line(key, value) - f.write(line + "\n") + f.write(f"{line}\n") - logger.debug("created arc job file at '{}'".format(job_file)) + logger.debug(f"created arc job file at '{job_file}'") return job_file, c @classmethod - def create_line(cls, key, value): + def create_line(cls, key: str, value: Any) -> str: def flat_value(value): if isinstance(value, list): return " ".join(flat_value(v) for v in value) if isinstance(value, tuple): - return "({})".format(" ".join(flat_value(v) for v in value)) - else: - return "\"{}\"".format(value) - return "({} = {})".format(key, flat_value(value)) + return f"({' '.join(flat_value(v) for v in value)})" + return f"\"{value}\"" + return f"({key} = {flat_value(value)})" diff --git a/law/contrib/arc/util.py b/law/contrib/arc/util.py index 518afc99..cd1a7f45 100644 --- a/law/contrib/arc/util.py +++ b/law/contrib/arc/util.py @@ -4,24 +4,26 @@ Helpers for working with ARC. """ +from __future__ import annotations + __all__ = [ "get_arcproxy_file", "get_arcproxy_user", "get_arcproxy_lifetime", "get_arcproxy_vo", "check_arcproxy_validity", "renew_arcproxy", ] - import os import re +import pathlib import subprocess -from law.util import interruptable_popen, tmp_file, parse_duration, quote_cmd +from law.util import interruptable_popen, tmp_file, parse_duration, quote_cmd, custom_context from law.logger import get_logger logger = get_logger(__name__) -def get_arcproxy_file(): +def get_arcproxy_file() -> str: """ Returns the path to the arc proxy file. 
""" @@ -35,10 +37,14 @@ def get_arcproxy_file(): tmp = os.environ[v] break - return os.path.join(tmp, "x509up_u{}".format(os.getuid())) + return os.path.join(tmp, f"x509up_u{os.getuid()}") -def _arcproxy_info(args=None, proxy_file=None, silent=False): +def _arcproxy_info( + args: list[str] | None = None, + proxy_file: str | pathlib.Path | None = None, + silent: bool = False, +) -> tuple[int, str, str]: if args is None: args = ["--info"] cmd = ["arcproxy"] + (args or []) @@ -51,20 +57,27 @@ def _arcproxy_info(args=None, proxy_file=None, silent=False): proxy_file = os.path.expandvars(os.path.expanduser(str(proxy_file))) cmd.extend(["--proxy", proxy_file]) - code, out, err = interruptable_popen(quote_cmd(cmd), shell=True, executable="/bin/bash", - stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out: str + err: str + code, out, err = interruptable_popen( # type: ignore[assignment] + quote_cmd(cmd), + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) # arcproxy does not use proper exit codes but writes into stderr in case of an error if err: code = 1 if not silent and code != 0: - raise Exception("arcproxy failed: {}".format(err)) + raise Exception(f"arcproxy failed: {err}") return code, out, err -def get_arcproxy_user(proxy_file=None): +def get_arcproxy_user(proxy_file: str | pathlib.Path | None = None) -> str: """ Returns the owner of the arc proxy. When *proxy_file* is *None*, it defaults to the result of :py:func:`get_arcproxy_file`. Otherwise, when it evaluates to *False*, ``arcproxy`` is queried @@ -72,12 +85,12 @@ def get_arcproxy_user(proxy_file=None): """ out = _arcproxy_info(args=["--infoitem=identity"], proxy_file=proxy_file)[1].strip() try: - return re.match(r".*\/CN\=([^\/]+).*", out.strip()).group(1) + return re.match(r".*\/CN\=([^\/]+).*", out.strip()).group(1) # type: ignore[union-attr] except: - raise Exception("no valid identity found in arc proxy: {}".format(out)) + raise Exception(f"no valid identity found in arc proxy: {out}") -def get_arcproxy_lifetime(proxy_file=None): +def get_arcproxy_lifetime(proxy_file: str | pathlib.Path | None = None) -> int: """ Returns the remaining lifetime of the arc proxy in seconds. When *proxy_file* is *None*, it defaults to the result of :py:func:`get_arcproxy_file`. Otherwise, when it evaluates to @@ -87,10 +100,10 @@ def get_arcproxy_lifetime(proxy_file=None): try: return int(out) except: - raise Exception("no valid lifetime found in arc proxy: {}".format(out)) + raise Exception(f"no valid lifetime found in arc proxy: {out}") -def get_arcproxy_vo(proxy_file=None): +def get_arcproxy_vo(proxy_file: str | pathlib.Path | None = None) -> str: """ Returns the virtual organization name of the arc proxy. When *proxy_file* is *None*, it defaults to the result of :py:func:`get_arcproxy_file`. Otherwise, when it evaluates to *False*, @@ -99,7 +112,7 @@ def get_arcproxy_vo(proxy_file=None): return _arcproxy_info(args=["--infoitem=vomsVO"], proxy_file=proxy_file)[1].strip() -def check_arcproxy_validity(log=False, proxy_file=None): +def check_arcproxy_validity(log=False, proxy_file: str | pathlib.Path | None = None) -> bool: """ Returns *True* when a valid arc proxy exists, *False* otherwise. When *log* is *True*, a warning will be logged. 
When *proxy_file* is *None*, it defaults to the result of @@ -113,7 +126,7 @@ def check_arcproxy_validity(log=False, proxy_file=None): elif err.strip().lower().startswith("error: cannot find file at"): valid = False else: - raise Exception("arcproxy failed: {}".format(err)) + raise Exception(f"arcproxy failed: {err}") if log and not valid: logger.warning("no valid arc proxy found") @@ -121,14 +134,19 @@ def check_arcproxy_validity(log=False, proxy_file=None): return valid -def renew_arcproxy(password="", lifetime="8 days", proxy_file=None): +def renew_arcproxy( + password: str | pathlib.Path = "", + lifetime="8 days", + proxy_file: str | pathlib.Path | None = None, +) -> None: """ Renews the arc proxy using a password *password* and a default *lifetime* of 8 days, which is internally parsed by :py:func:`law.util.parse_duration` where the default input unit is hours. - To ensure that the *password* it is not visible in any process listing, it is written to a - temporary file first and piped into the ``arcproxy`` command. When *proxy_file* is *None*, it - defaults to the result of :py:func:`get_arcproxy_file`. Otherwise, when it evaluates to - *False*, ``arcproxy`` is invoked without a custom proxy file. + To ensure that the *password*, in case it is not passed as a file, is not visible in any process + listing, it is written to a temporary file first and piped into the ``arcproxy`` command. + + When *proxy_file* is *None*, it defaults to the result of :py:func:`get_arcproxy_file`. + Otherwise, when it evaluates to *False*, ``arcproxy`` is invoked without a custom proxy file. """ # convert the lifetime to seconds lifetime_seconds = int(parse_duration(lifetime, input_unit="h", unit="s")) @@ -136,18 +154,27 @@ def renew_arcproxy(password="", lifetime="8 days", proxy_file=None): if proxy_file is None: proxy_file = get_arcproxy_file() - args = "--constraint=validityPeriod={}".format(lifetime_seconds) + args = f"--constraint=validityPeriod={lifetime_seconds}" if proxy_file: proxy_file = os.path.expandvars(os.path.expanduser(str(proxy_file))) - args += " --proxy={}".format(proxy_file) - - with tmp_file() as (_, tmp): - with open(tmp, "w") as f: - f.write(password) - - cmd = "arcproxy --passwordsource=key=file:{} {}".format(tmp, args) - code, out, _ = interruptable_popen(cmd, shell=True, executable="/bin/bash", - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + args += f" --proxy={proxy_file}" + + password = str(password) + password_file_given = os.path.exists(password) + password_context = custom_context((None, password)) if password_file_given else tmp_file + with password_context() as (_, password_file): # type: ignore[operator] + if not password_file_given: + with open(password_file, "w") as f: + f.write(password) + + cmd = f"arcproxy --passwordsource=key=file:{password_file} {args}" + code, out, _ = interruptable_popen( + cmd, + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) if code != 0: - raise Exception("arcproxy failed: {}".format(out)) + raise Exception(f"arcproxy failed: {out}") diff --git a/law/contrib/arc/workflow.py b/law/contrib/arc/workflow.py index b571c3d2..6a24943e 100644 --- a/law/contrib/arc/workflow.py +++ b/law/contrib/arc/workflow.py @@ -4,21 +4,25 @@ ARC remote workflow implementation. See http://www.nordugrid.org/arc/ce. 
""" -__all__ = ["ARCWorkflow"] +from __future__ import annotations +__all__ = ["ARCWorkflow"] import os -from abc import abstractmethod -from collections import OrderedDict +import pathlib +import abc from law.workflow.remote import BaseRemoteWorkflow, BaseRemoteWorkflowProxy -from law.job.base import JobArguments, JobInputFile, DeprecatedInputFiles +from law.job.base import JobArguments, JobInputFile from law.task.proxy import ProxyCommand from law.target.file import get_path +from law.target.local import LocalFileTarget from law.parameter import CSVParameter -from law.util import law_src_path, merge_dicts, DotDict +from law.util import law_src_path, merge_dicts, DotDict, InsertableDict from law.logger import get_logger +from law._types import Type +from law.contrib.wlcg import WLCGDirectoryTarget from law.contrib.arc.job import ARCJobManager, ARCJobFileFactory @@ -27,30 +31,34 @@ class ARCWorkflowProxy(BaseRemoteWorkflowProxy): - workflow_type = "arc" + workflow_type: str = "arc" - def __init__(self, *args, **kwargs): - super(ARCWorkflowProxy, self).__init__(*args, **kwargs) + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) # check if there is at least one ce if not self.task.arc_ce: raise Exception("please set at least one arc computing element (--arc-ce)") - def create_job_manager(self, **kwargs): + def create_job_manager(self, **kwargs) -> ARCJobManager: return self.task.arc_create_job_manager(**kwargs) - def create_job_file_factory(self, **kwargs): + def create_job_file_factory(self, **kwargs) -> ARCJobFileFactory: return self.task.arc_create_job_file_factory(**kwargs) - def create_job_file(self, job_num, branches): + def create_job_file( + self, + job_num: int, + branches: list[int], + ) -> dict[str, str | pathlib.Path | ARCJobFileFactory.Config | None]: task = self.task # the file postfix is pythonic range made from branches, e.g. [0, 1, 2, 4] -> "_0To5" - postfix = "_{}To{}".format(branches[0], branches[-1] + 1) + postfix = f"_{branches[0]}To{branches[-1] + 1}" # create the config - c = self.job_file_factory.get_config() - c.input_files = DeprecatedInputFiles() + c = self.job_file_factory.get_config() # type: ignore[union-attr] + c.input_files = {} c.output_files = [] c.render_variables = {} c.custom_content = [] @@ -80,18 +88,23 @@ def create_job_file(self, job_num, branches): ) if task.arc_use_local_scheduler(): proxy_cmd.add_arg("--local-scheduler", "True", overwrite=True) - for key, value in OrderedDict(task.arc_cmdline_args()).items(): + for key, value in dict(task.arc_cmdline_args()).items(): proxy_cmd.add_arg(key, value, overwrite=True) # job script arguments + dashboard_data = None + if self.dashboard is not None: + dashboard_data = self.dashboard.remote_hook_data( + job_num, + self.job_data.attempts.get(job_num, 0), + ) job_args = JobArguments( task_cls=task.__class__, task_params=proxy_cmd.build(skip_run=True), branches=branches, workers=task.job_workers, auto_retry=False, - dashboard_data=self.dashboard.remote_hook_data( - job_num, self.job_data.attempts.get(job_num, 0)), + dashboard_data=dashboard_data, ) c.arguments = job_args.join() @@ -106,9 +119,10 @@ def create_job_file(self, job_num, branches): c.input_files["stageout_file"] = stageout_file # does the dashboard have a hook file? 
- dashboard_file = self.dashboard.remote_hook_file() - if dashboard_file: - c.input_files["dashboard_file"] = dashboard_file + if self.dashboard is not None: + dashboard_file = self.dashboard.remote_hook_file() + if dashboard_file: + c.input_files["dashboard_file"] = dashboard_file # log files c.log = None @@ -121,14 +135,14 @@ def create_job_file(self, job_num, branches): c.custom_log_file = log_file # meta infos - c.job_name = "{}{}".format(task.live_task_id, postfix) + c.job_name = f"{task.live_task_id}{postfix}" c.output_uri = task.arc_output_uri() # task hook c = task.arc_job_config(c, job_num, branches) # build the job file and get the sanitized config - job_file, c = self.job_file_factory(postfix=postfix, **c.__dict__) + job_file, c = self.job_file_factory(postfix=postfix, **c.__dict__) # type: ignore[misc] # determine the custom log file uri if set abs_log_file = None @@ -138,10 +152,10 @@ def create_job_file(self, job_num, branches): # return job and log files return {"job": job_file, "config": c, "log": abs_log_file} - def destination_info(self): - info = super(ARCWorkflowProxy, self).destination_info() + def destination_info(self) -> InsertableDict: + info = super().destination_info() - info["ce"] = "ce: {}".format(",".join(self.task.arc_ce)) + info["ce"] = f"ce: {','.join(self.task.arc_ce)}" info = self.task.arc_destination_info(info) @@ -162,7 +176,7 @@ class ARCWorkflow(BaseRemoteWorkflow): description="target arc computing element(s); default: empty", ) - arc_job_kwargs = [] + arc_job_kwargs: list[str] = [] arc_job_kwargs_submit = ["arc_ce"] arc_job_kwargs_cancel = None arc_job_kwargs_cleanup = None @@ -170,65 +184,69 @@ class ARCWorkflow(BaseRemoteWorkflow): exclude_params_branch = {"arc_ce"} - exclude_params_arc_workflow = set() + exclude_params_arc_workflow: set[str] = set() exclude_index = True - @abstractmethod - def arc_output_directory(self): - return None + @abc.abstractmethod + def arc_output_directory(self) -> WLCGDirectoryTarget: + ... 
- @abstractmethod - def arc_bootstrap_file(self): + def arc_workflow_requires(self) -> DotDict: + return DotDict() + + def arc_bootstrap_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile | None: return None - def arc_wrapper_file(self): + def arc_wrapper_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile | None: return None - def arc_job_file(self): + def arc_job_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile: return JobInputFile(law_src_path("job", "law_job.sh")) - def arc_stageout_file(self): + def arc_stageout_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile | None: return None - def arc_workflow_requires(self): - return DotDict() - - def arc_output_postfix(self): + def arc_output_postfix(self) -> str: return "" - def arc_output_uri(self): - return self.arc_output_directory().uri() + def arc_output_uri(self) -> str: + return self.arc_output_directory().uri(return_all=False) # type: ignore[return-value] - def arc_job_manager_cls(self): + def arc_job_manager_cls(self) -> Type[ARCJobManager]: return ARCJobManager - def arc_create_job_manager(self, **kwargs): + def arc_create_job_manager(self, **kwargs) -> ARCJobManager: kwargs = merge_dicts(self.arc_job_manager_defaults, kwargs) return self.arc_job_manager_cls()(**kwargs) - def arc_job_file_factory_cls(self): + def arc_job_file_factory_cls(self) -> Type[ARCJobFileFactory]: return ARCJobFileFactory - def arc_create_job_file_factory(self, **kwargs): + def arc_create_job_file_factory(self, **kwargs) -> ARCJobFileFactory: # job file fectory config priority: kwargs > class defaults kwargs = merge_dicts({}, self.arc_job_file_factory_defaults, kwargs) return self.arc_job_file_factory_cls()(**kwargs) - def arc_job_config(self, config, job_num, branches): + def arc_job_config( + self, + config: ARCJobFileFactory.Config, + job_num: int, + branches: list[int], + ) -> ARCJobFileFactory.Config: return config - def arc_check_job_completeness(self): + def arc_check_job_completeness(self) -> bool: return False - def arc_check_job_completeness_delay(self): + def arc_check_job_completeness_delay(self) -> float | int: return 0.0 - def arc_use_local_scheduler(self): + def arc_use_local_scheduler(self) -> bool: return True - def arc_cmdline_args(self): + def arc_cmdline_args(self) -> dict[str, str]: return {} - def arc_destination_info(self, info): + def arc_destination_info(self, info: InsertableDict) -> InsertableDict: return info diff --git a/law/contrib/cms/__init__.py b/law/contrib/cms/__init__.py index d7b06fd8..01618484 100644 --- a/law/contrib/cms/__init__.py +++ b/law/contrib/cms/__init__.py @@ -12,6 +12,9 @@ "Site", "lfn_to_pfn", "renew_vomsproxy", "delegate_myproxy", ] +# dependencies to other contrib modules +import law +law.contrib.load("wlcg") # provisioning imports from law.contrib.cms.sandbox import CMSSWSandbox diff --git a/law/contrib/cms/bin/apmon b/law/contrib/cms/bin/apmon index 2c7f4c9e..acf9df13 100755 --- a/law/contrib/cms/bin/apmon +++ b/law/contrib/cms/bin/apmon @@ -18,7 +18,6 @@ request. Usage: Further optional arguments: log_level, site, code, event """ - import sys import shlex import traceback diff --git a/law/contrib/cms/crab/PSet.py b/law/contrib/cms/crab/PSet.py index d89f9c82..2150eccd 100644 --- a/law/contrib/cms/crab/PSet.py +++ b/law/contrib/cms/crab/PSet.py @@ -4,7 +4,7 @@ Minimal valid configuration. 
""" -import FWCore.ParameterSet.Config as cms +import FWCore.ParameterSet.Config as cms # type: ignore[import-untyped, import-not-found] process = cms.Process("LAW") diff --git a/law/contrib/cms/job.py b/law/contrib/cms/job.py index b0280693..93356151 100644 --- a/law/contrib/cms/job.py +++ b/law/contrib/cms/job.py @@ -4,32 +4,36 @@ Crab job manager and CMS-related job helpers. """ -__all__ = ["CrabJobManager", "CrabJobFileFactory", "CMSJobDashboard"] +from __future__ import annotations +__all__ = ["CrabJobManager", "CrabJobFileFactory", "CMSJobDashboard"] import os import stat import time +import pathlib import socket import threading +import queue import re import json import subprocess import shutil -from collections import OrderedDict, namedtuple - -import six +import collections import law from law.config import Config +from law.task.base import Task from law.sandbox.base import Sandbox -from law.job.base import BaseJobManager, BaseJobFileFactory, JobInputFile, DeprecatedInputFiles +from law.job.base import BaseJobManager, BaseJobFileFactory, JobInputFile from law.job.dashboard import BaseJobDashboard +from law.workflow.remote import JobData from law.target.file import get_path from law.util import ( DotDict, interruptable_popen, make_list, make_unique, quote_cmd, no_value, rel_path, ) from law.logger import get_logger +from law._types import Any, MutableMapping, Callable, Type, Hashable, T, Sequence import law.contrib.cms.sandbox @@ -58,10 +62,16 @@ class CrabJobManager(BaseJobManager): job_grouping = True - JobId = namedtuple("JobId", ["crab_num", "task_name", "proj_dir"]) + JobId = collections.namedtuple("JobId", ["crab_num", "task_name", "proj_dir"]) - def __init__(self, sandbox_name=None, proxy=None, instance=None, threads=1): - super(CrabJobManager, self).__init__() + def __init__( + self, + sandbox_name: str | None = None, + proxy: str | None = None, + instance: str | None = None, + threads: int = 1, + ) -> None: + super().__init__() # default sandbox name if sandbox_name is None: @@ -72,7 +82,7 @@ def __init__(self, sandbox_name=None, proxy=None, instance=None, threads=1): self.cmssw_sandbox = Sandbox.new( sandbox_name if sandbox_name.startswith("cmssw::") - else "cmssw::{}".format(sandbox_name), + else f"cmssw::{sandbox_name}", ) # store attributes @@ -81,19 +91,25 @@ def __init__(self, sandbox_name=None, proxy=None, instance=None, threads=1): self.threads = threads @classmethod - def cast_job_id(cls, job_id): + def cast_job_id(cls, job_id: tuple[str]) -> CrabJobManager.JobId: """ Converts a *job_id*, for instance after json deserialization, into a :py:class:`JobId` object. 
""" - return cls.JobId(*job_id) if isinstance(job_id, (list, tuple)) else job_id + if isinstance(job_id, cls.JobId): + return job_id + + if isinstance(job_id, (list, tuple)): + return cls.JobId(*job_id) # type: ignore[call-arg] + + raise ValueError(f"cannot cast to {cls.JobId.__name__}: '{job_id!r}'") @property - def cmssw_env(self): + def cmssw_env(self) -> MutableMapping[str, Any]: return self.cmssw_sandbox.env - def group_job_ids(self, job_ids): - groups = OrderedDict() + def group_job_ids(self, job_ids: list[JobId]) -> dict[str, list[JobId]]: # type: ignore[override] # noqa + groups: dict[str, list[CrabJobManager.JobId]] = {} # group by project directory for job_id in job_ids: @@ -103,40 +119,64 @@ def group_job_ids(self, job_ids): return groups - def _apply_group(self, func, result_type, group_func, job_objs, *args, **kwargs): + def _apply_group( + self, + func: Callable, + result_type: Type[T], + group_func: Callable[[list[Any]], dict[Hashable, list[Any]]], + job_objs: list[Any], + threads: int | None = None, + callback: Callable[[int, Any], Any] | None = None, + **kwargs, + ) -> T: # when job_objs is a string or a sequence of strings, interpret them as project dirs, read # their log files to extract task names, build actual job ids and forward them if func != self.submit: job_ids = [] for i, job_id in enumerate(make_list(job_objs)): - if not isinstance(job_id, six.string_types): + if not isinstance(job_id, (str, pathlib.Path)): job_ids.append(job_id) continue # get n_jobs and task_name from log file proj_dir = job_id log_file = os.path.join(proj_dir, "crab.log") + if not os.path.exists(log_file): + job_ids.append(job_id) + continue log_data = self._parse_log_file(log_file) - if not log_data or "n_jobs" not in log_data or "task_name" not in log_data: + if "n_jobs" not in log_data or "task_name" not in log_data: job_ids.append(job_id) continue # expand ids + log_data: dict[str, str | None] for crab_num in range(1, int(log_data["n_jobs"]) + 1): job_ids.append(self.JobId(crab_num, log_data["task_name"], proj_dir)) job_objs = job_ids - return super(CrabJobManager, self)._apply_group( + return super()._apply_group( func, result_type, group_func, job_objs, - *args, - **kwargs # noqa + threads=threads, + callback=callback, + **kwargs, ) - def submit(self, job_file, job_files=None, proxy=None, instance=None, myproxy_username=None, - retries=0, retry_delay=3, silent=False): + def submit( # type: ignore[override] + self, + job_file: str | pathlib.Path, + *, + job_files: Sequence[str | pathlib.Path] | None = None, + proxy: str | None = None, + instance: str | None = None, + myproxy_username: str | None = None, + retries: int = 0, + retry_delay: int | float = 3, + silent: bool = False, + ) -> list[JobId] | None: # default arguments if proxy is None: proxy = self.proxy @@ -144,7 +184,7 @@ def submit(self, job_file, job_files=None, proxy=None, instance=None, myproxy_us instance = self.instance # get the job file location as the submission command is run it the same directory - job_file_dir, job_file_name = os.path.split(os.path.abspath(str(job_file))) + job_file_dir, job_file_name = os.path.split(os.path.abspath(get_path(job_file))) # define the actual submission in a loop to simplify retries while True: @@ -154,24 +194,29 @@ def submit(self, job_file, job_files=None, proxy=None, instance=None, myproxy_us cmd += ["--proxy", proxy] if instance: cmd += ["--instance", instance] - cmd = quote_cmd(cmd) + cmd_str = quote_cmd(cmd) # run the command # crab prints everything to stdout - 
logger.debug("submit crab jobs with command '{}'".format(cmd)) - code, out, _ = interruptable_popen(cmd, shell=True, executable="/bin/bash", - stdout=subprocess.PIPE, cwd=job_file_dir, env=self.cmssw_env) + logger.debug(f"submit crab jobs with command '{cmd_str}'") + out: str + code, out, _ = interruptable_popen( # type: ignore[assignment] + cmd_str, + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + cwd=job_file_dir, + env=self.cmssw_env, + ) # handle errors if code != 0: - logger.debug("submission of crab job '{}' failed with code {}:\n{}".format( - job_file, code, out)) + logger.debug(f"submission of crab job '{job_file}' failed with code {code}:\n{out}") # remove the project directory proj_dir = self._proj_dir_from_job_file(job_file, self.cmssw_env) if proj_dir and os.path.isdir(proj_dir): - logger.debug("removing crab project '{}' from previous attempt".format( - proj_dir)) + logger.debug(f"removing crab project '{proj_dir}' from previous attempt") shutil.rmtree(proj_dir) if retries > 0: @@ -182,8 +227,9 @@ def submit(self, job_file, job_files=None, proxy=None, instance=None, myproxy_us if silent: return None - raise Exception("submission of crab job '{}' failed with code {}:\n{}".format( - job_file, code, out)) + raise Exception( + f"submission of crab job '{job_file}' failed with code {code}:\n{out}", + ) # parse outputs task_name, log_file = None, None @@ -202,27 +248,33 @@ def submit(self, job_file, job_files=None, proxy=None, instance=None, myproxy_us break if not task_name: - raise Exception("no valid task name found in submission output:\n\n{}".format(out)) + raise Exception(f"no valid task name found in submission output:\n\n{out}") if not log_file: - raise Exception("no valid log file found in submission output:\n\n{}".format(out)) + raise Exception(f"no valid log file found in submission output:\n\n{out}") # create job ids with log data proj_dir = os.path.dirname(log_file) job_ids = self._job_ids_from_proj_dir(proj_dir) # checks - if not job_ids: - raise Exception("number of jobs not extractable from log file {}".format(log_file)) if job_files is not None and len(job_files) != len(job_ids): raise Exception( - "number of submited jobs ({}) does not match number of job files ({})".format( - len(job_ids), len(job_file)), + f"number of submited jobs ({len(job_ids)}) does not match number of job files " + f"({len(job_files)})", ) return job_ids - def cancel(self, proj_dir, job_ids=None, proxy=None, instance=None, myproxy_username=None, - silent=False): + def cancel( # type: ignore[override] + self, + proj_dir: str | pathlib.Path, + *, + job_ids: list[JobId] | None = None, + proxy: str | None = None, + instance: str | None = None, + myproxy_username: str | None = None, + silent: bool = False, + ) -> dict[JobId, None]: if job_ids is None: job_ids = self._job_ids_from_proj_dir(proj_dir) @@ -232,25 +284,38 @@ def cancel(self, proj_dir, job_ids=None, proxy=None, instance=None, myproxy_user cmd += ["--proxy", proxy] if instance: cmd += ["--instance", instance] - cmd = quote_cmd(cmd) + cmd_str = quote_cmd(cmd) # run it - logger.debug("cancel crab job(s) with command '{}'".format(cmd)) - code, out, _ = interruptable_popen(cmd, shell=True, executable="/bin/bash", - stdout=subprocess.PIPE, env=self.cmssw_env) + logger.debug(f"cancel crab job(s) with command '{cmd_str}'") + code, out, _ = interruptable_popen( + cmd_str, + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + env=self.cmssw_env, + ) # check success if code != 0 and not silent: # crab prints 
everything to stdout raise Exception( - "cancellation of crab jobs from project '{}' failed with code {}:\n{}".format( - proj_dir, code, out), + f"cancellation of crab jobs from project '{proj_dir}' failed with code {code}:\n" + f"{out}", ) return {job_id: None for job_id in job_ids} - def cleanup(self, proj_dir, job_ids=None, proxy=None, instance=None, myproxy_username=None, - silent=False): + def cleanup( # type: ignore[override] + self, + proj_dir: str | pathlib.Path, + *, + job_ids: list[JobId] | None = None, + proxy: str | None = None, + instance: str | None = None, + myproxy_username: str | None = None, + silent: bool = False, + ) -> dict[JobId, None]: if job_ids is None: job_ids = self._job_ids_from_proj_dir(proj_dir) @@ -261,8 +326,17 @@ def cleanup(self, proj_dir, job_ids=None, proxy=None, instance=None, myproxy_use return {job_id: None for job_id in job_ids} - def query(self, proj_dir, job_ids=None, proxy=None, instance=None, myproxy_username=None, - skip_transfers=None, silent=False): + def query( # type: ignore[override] + self, + proj_dir: str | pathlib.Path, + *, + job_ids: list[JobId] | None = None, + proxy: str | None = None, + instance: str | None = None, + myproxy_username: str | None = None, + skip_transfers: bool | None = None, + silent: bool = False, + ) -> dict[JobId, dict[str, Any]] | None: proj_dir = str(proj_dir) log_data = self._parse_log_file(os.path.join(proj_dir, "crab.log")) if job_ids is None: @@ -278,12 +352,18 @@ def query(self, proj_dir, job_ids=None, proxy=None, instance=None, myproxy_usern cmd += ["--proxy", proxy] if instance: cmd += ["--instance", instance] - cmd = quote_cmd(cmd) + cmd_str = quote_cmd(cmd) # run it - logger.debug("query crab job(s) with command '{}'".format(cmd)) - code, out, _ = interruptable_popen(cmd, shell=True, executable="/bin/bash", - stdout=subprocess.PIPE, env=self.cmssw_env) + logger.debug(f"query crab job(s) with command '{cmd_str}'") + out: str + code, out, _ = interruptable_popen( # type: ignore[assignment] + cmd_str, + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + env=self.cmssw_env, + ) # handle errors if code != 0: @@ -291,8 +371,8 @@ def query(self, proj_dir, job_ids=None, proxy=None, instance=None, myproxy_usern return None # crab prints everything to stdout raise Exception( - "status query of crab jobs from project '{}' failed with code {}:\n{}".format( - proj_dir, code, out), + f"status query of crab jobs from project '{proj_dir}' failed with code {code}:\n" + f"{out}", ) # parse the output and extract the status per job @@ -310,7 +390,13 @@ def query(self, proj_dir, job_ids=None, proxy=None, instance=None, myproxy_usern return query_data @classmethod - def parse_query_output(cls, out, proj_dir, job_ids, skip_transfers=False): + def parse_query_output( + cls, + out: str, + proj_dir: str | pathlib.Path, + job_ids: list[JobId], + skip_transfers: bool = False, + ) -> dict[JobId, dict[str, Any]]: # parse values using compiled regexps cres = [ cls.query_user_cre, @@ -320,7 +406,7 @@ def parse_query_output(cls, out, proj_dir, job_ids, skip_transfers=False): cls.query_json_line_cre, cls.query_monitoring_url_cre, ] - values = len(cres) * [None] + values: list[str | None] = len(cres) * [None] # type: ignore[assignment] for line in out.replace("\r", "").split("\n"): for i, (cre, value) in enumerate(zip(cres, values)): if value: @@ -330,12 +416,21 @@ def parse_query_output(cls, out, proj_dir, job_ids, skip_transfers=False): values[i] = m.group(1) if all(values): break - # unpack - username, server_status, 
scheduler_id, scheduler_status, json_line, monitoring_url = values + ( + username, + server_status, + scheduler_id, + scheduler_status, + json_line, + monitoring_url, + ) = values # helper to build extra info - def extra(job_id, job_data=None): + def extra( + job_id: CrabJobManager.JobId, + job_data: dict[str, Any] | None = None, + ) -> dict[str, Any]: extra = {} if username and scheduler_id and job_data: extra["log_file"] = cls.log_file_pattern.format( @@ -360,8 +455,8 @@ def extra(job_id, job_data=None): if server_status not in accepted_server_states: s = ",".join(map("'{}'".format, accepted_server_states)) raise Exception( - "no per-job information available (yet?), which is only accepted if the crab " + - "server status is any of {}, but got '{}'".format(s, server_status), + "no per-job information available (yet?), which is only accepted if the crab " + f"server status is any of {s}, but got '{server_status}'", ) # interpret all jobs as pending return { @@ -372,8 +467,8 @@ def extra(job_id, job_data=None): # parse json data if not json_line: raise Exception( - "no per-job information available in status response, crab server " + - "status '{}', scheduler status '{}'".format(server_status, scheduler_status), + "no per-job information available in status response, crab server " + f"status '{server_status}', scheduler status '{scheduler_status}'", ) # map of crab job numbers to full ids for faster lookup @@ -410,7 +505,11 @@ def extra(job_id, job_data=None): return query_data @classmethod - def _proj_dir_from_job_file(cls, job_file, cmssw_env): + def _proj_dir_from_job_file( + cls, + job_file: str | pathlib.Path, + cmssw_env: MutableMapping[str, Any], + ) -> str | None: work_area = None request_name = None @@ -442,15 +541,21 @@ def _proj_dir_from_job_file(cls, job_file, cmssw_env): m = re.match(r"^CMSSW(_.+|)_(\d)+_\d+_\d+.*$", cmssw_env["CMSSW_VERSION"]) cmssw_major = int(m.group(2)) if m else None py_exec = "python3" if cmssw_major is None or cmssw_major >= 11 else "python" - cmd = """{} -c ' + cmd = f"""{py_exec} -c ' from os.path import join -with open("{}", "r") as f: +with open("{job_file}", "r") as f: mod = dict() exec(f.read(), mod) cfg = mod["cfg"] -print(join(cfg.General.workArea, "crab_" + cfg.General.requestName))'""".format(py_exec, job_file) - code, out, _ = interruptable_popen(cmd, shell=True, executable="/bin/bash", - stdout=subprocess.PIPE, env=cmssw_env) +print(join(cfg.General.workArea, "crab_" + cfg.General.requestName))'""" + out: str + code, out, _ = interruptable_popen( # type: ignore[assignment] + cmd, + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + env=cmssw_env, + ) if code == 0: path = out.strip().replace("\r\n", "\n").split("\n")[-1] path = os.path.expandvars(os.path.expanduser(path)) @@ -460,36 +565,41 @@ def _proj_dir_from_job_file(cls, job_file, cmssw_env): return None @classmethod - def _parse_log_file(cls, log_file): + def _parse_log_file(cls, log_file: str | pathlib.Path) -> dict[str, str | int]: log_file = os.path.expandvars(os.path.expanduser(str(log_file))) if not os.path.exists(log_file): - return None + raise FileNotFoundError(f"log file '{log_file}' does not exist") cres = [cls.log_n_jobs_cre, cls.log_task_name_cre, cls.log_disable_output_collection_cre] names = ["n_jobs", "task_name", "disable_output_collection"] - values = len(cres) * [None] + values: list[str | int | None] = len(cres) * [None] # type: ignore[assignment] + types = [int, str, bool] with open(log_file, "r") as f: for line in f.readlines(): - for i, (cre, value) 
in enumerate(zip(cres, values)): - if value: + for i, (cre, value, t) in enumerate(zip(cres, values, types)): + if value is not None: continue m = cre.match(line) if m: - values[i] = m.group(1) + values[i] = t(m.group(1)) if all(values): break - return dict(zip(names, values)) + return {n: v for n, v in zip(names, values) if v is not None} @classmethod - def _job_ids_from_proj_dir(cls, proj_dir, log_data=None): + def _job_ids_from_proj_dir( + cls, + proj_dir: str | pathlib.Path, + log_data: dict[str, str | int] | None = None, + ) -> list[JobId]: # read log data proj_dir = str(proj_dir) - if not log_data: + if log_data is None: log_data = cls._parse_log_file(os.path.join(proj_dir, "crab.log")) - if not log_data or "n_jobs" not in log_data or "task_name" not in log_data: - return None + if "n_jobs" not in log_data or "task_name" not in log_data: + raise ValueError(f"log data does not contain 'n_jobs' or 'task_name': {log_data}") # build and return ids return [ @@ -498,7 +608,7 @@ def _job_ids_from_proj_dir(cls, proj_dir, log_data=None): ] @classmethod - def map_status(cls, status, skip_transfers=False): + def map_status(cls, status: str | None, skip_transfers: bool = False) -> str: # see https://twiki.cern.ch/twiki/bin/view/CMSPublic/Crab3HtcondorStates if status in ("cooloff", "unsubmitted", "idle"): return cls.PENDING @@ -510,6 +620,8 @@ def map_status(cls, status, skip_transfers=False): return cls.FINISHED if status in ("killing", "failed", "held"): return cls.FAILED + + logger.debug(f"unknown crab job state '{status}'") return cls.FAILED @@ -521,10 +633,24 @@ class CrabJobFileFactory(BaseJobFileFactory): "custom_content", "absolute_paths", ] - def __init__(self, file_name="crab_job.py", executable=None, arguments=None, work_area=None, - request_name=None, input_files=None, output_files=None, storage_site=None, - output_lfn_base=None, vo_group=None, vo_role=None, custom_content=None, - absolute_paths=False, **kwargs): + def __init__( + self, + *, + file_name: str = "crab_job.py", + executable: str | None = None, + arguments: Sequence[str] | None = None, + work_area: str | None = None, + request_name: str | None = None, + input_files: dict[str, str | pathlib.Path | JobInputFile] | None = None, + output_files: list[str] | None = None, + storage_site: str | None = None, + output_lfn_base: str | None = None, + vo_group: str | None = None, + vo_role: str | None = None, + custom_content: str | Sequence[str] | None = None, + absolute_paths: bool = False, + **kwargs, + ) -> None: # get some default kwargs from the config cfg = Config.instance() default_dir = cfg.get_expanded( @@ -547,14 +673,14 @@ def __init__(self, file_name="crab_job.py", executable=None, arguments=None, wor cfg.find_option("job", "crab_job_file_dir_mkdtemp", "job_file_dir_mkdtemp"), ) - super(CrabJobFileFactory, self).__init__(**kwargs) + super().__init__(**kwargs) self.file_name = file_name self.executable = executable self.arguments = arguments self.work_area = work_area self.request_name = request_name - self.input_files = DeprecatedInputFiles(input_files or {}) + self.input_files = input_files or {} self.output_files = output_files or [] self.storage_site = storage_site self.output_lfn_base = output_lfn_base @@ -608,7 +734,7 @@ def __init__(self, file_name="crab_job.py", executable=None, arguments=None, wor ])), ]) - def create(self, **kwargs): + def create(self, **kwargs) -> tuple[str, CrabJobFileFactory.Config]: # merge kwargs and instance attributes c = self.get_config(**kwargs) @@ -620,18 +746,18 @@ def create(self, 
**kwargs): if not c.request_name: raise ValueError("request_name must not be empty") if "." in c.request_name: - raise ValueError("request_name should not contain '.', got {}".format(c.request_name)) + raise ValueError(f"request_name should not contain '.', got {c.request_name}") if len(c.request_name) > 100: raise ValueError( - "request_name must be less then 100 characters long, got {}: {}".format( - len(c.request_name), c.request_name), + f"request_name must be less then 100 characters long, got {len(c.request_name)}: " + f"{c.request_name}", ) if not c.output_lfn_base: raise ValueError("output_lfn_base must not be empty") if not c.storage_site: raise ValueError("storage_site must not be empty") if not isinstance(c.arguments, (list, tuple)): - raise ValueError("arguments must be a list, got '{}'".format(c.arguments)) + raise ValueError(f"arguments must be a list, got '{c.arguments}'") if "job_file" not in c.input_files: raise ValueError("an input file with key 'job_file' is required") @@ -722,7 +848,7 @@ def prepare_input(f): # inject arguments into the crab wrapper via render variables c.render_variables["crab_job_arguments_map"] = ("\n" + 8 * " ").join( - "['{}']=\"{}\"".format(i + 1, str(args)) + f"['{i + 1}']=\"{args}\"" for i, args in enumerate(c.arguments) ) @@ -777,7 +903,7 @@ def prepare_input(f): # note: they don't have to exist but crab requires a list of length totalUnits if not c.crab.Data.inputDataset: c.crab.Data.userInputFiles = [ - "input_{}.root".format(i + 1) + f"input_{i + 1}.root" for i in range(c.crab.Data.totalUnits) ] @@ -793,13 +919,18 @@ def prepare_input(f): # write the job file self.write_crab_config(job_file, c.crab, custom_content=c.custom_content) - logger.debug("created crab job file at '{}'".format(job_file)) + logger.debug(f"created crab job file at '{job_file}'") return job_file, c @classmethod - def write_crab_config(cls, job_file, crab_config, custom_content=None): - fmt_flat = lambda s: "\"{}\"".format(s) if isinstance(s, six.string_types) else str(s) + def write_crab_config( + cls, + job_file: str | pathlib.Path, + crab_config: DotDict, + custom_content: str | Sequence[str] | None = None, + ) -> None: + fmt_flat = lambda s: "\"{}\"".format(s) if isinstance(s, str) else str(s) with open(job_file, "w") as f: # header @@ -812,13 +943,11 @@ def write_crab_config(cls, job_file, crab_config, custom_content=None): # sections for section, cfg in crab_config.items(): - f.write("cfg.section_(\"{}\")\n".format(section)) + f.write(f"cfg.section_(\"{section}\")\n") # options for option, value in cfg.items(): if value == no_value: - raise Exception( - "cannot assign {} to crab config {}.{}".format(value, section, option), - ) + raise Exception(f"cannot assign {value} to crab config {section}.{option}") if value is None: continue value_str = ( @@ -826,15 +955,13 @@ def write_crab_config(cls, job_file, crab_config, custom_content=None): if isinstance(value, (list, tuple)) else fmt_flat(value) ) - f.write("cfg.{}.{} = {}\n".format(section, option, value_str)) + f.write(f"cfg.{section}.{option} = {value_str}\n") f.write("\n") # custom content - if isinstance(custom_content, six.string_types): - f.write(custom_content + "\n") - elif isinstance(custom_content, (list, tuple)): - for line in custom_content: - f.write(str(line) + "\n") + if custom_content is not None: + for line in make_list(custom_content or []): + f.write(f"{line}\n") class CMSJobDashboard(BaseJobDashboard): @@ -851,23 +978,38 @@ class CMSJobDashboard(BaseJobDashboard): SUCCESS = "success" FAILED = 
"failed" - tracking_url = "http://dashb-cms-job.cern.ch/dashboard/templates/task-analysis/#" + \ + tracking_url = ( + "http://dashb-cms-job.cern.ch/dashboard/templates/task-analysis/#" "table=Jobs&p=1&activemenu=2&refresh=60&tid={dashboard_task_id}" + ) persistent_attributes = ["task_id", "cms_user", "voms_user", "init_timestamp"] - def __init__(self, task, cms_user, voms_user, apmon_config=None, log_level="WARNING", - max_rate=20, task_type="analysis", site=None, executable="law", application=None, - application_version=None, submission_tool="law", submission_type="direct", - submission_ui=None, init_timestamp=None): - super(CMSJobDashboard, self).__init__(max_rate=max_rate) + def __init__( + self, + task: Task, + cms_user: str, + voms_user: str, + apmon_config: dict[str, Any] | None = None, + log_level: str = "WARNING", + max_rate: int = 20, + task_type: str = "analysis", + site: str | None = None, + executable: str = "law", + application: str | None = None, + application_version: str | int | None = None, + submission_tool: str = "law", + submission_type: str = "direct", + submission_ui: str | None = None, + init_timestamp: str | None = None, + ) -> None: + super().__init__(max_rate=max_rate) # setup the apmon thread try: self.apmon = Apmon(apmon_config, self.max_rate, log_level) except ImportError as e: - e.message += " (required for {})".format(self.__class__.__name__) - e.args = (e.message,) + e.args[1:] + e.args = (f"{e} (required for {self.__class__.__name__})",) + e.args[1:] raise e # get the task family for use as default application name @@ -893,49 +1035,56 @@ def __init__(self, task, cms_user, voms_user, apmon_config=None, log_level="WARN self.apmon.daemon = True self.apmon.start() - def __del__(self): - if getattr(self, "apmon", None) and self.apmon.is_alive(): - self.apmon.stop() - self.apmon.join() + def __del__(self) -> None: + if getattr(self, "apmon", None) is None or not self.apmon.is_alive(): + return + + self.apmon.stop() + self.apmon.join() @classmethod - def create_timestamp(cls): + def create_timestamp(cls) -> str: return time.strftime("%y%m%d_%H%M%S") @classmethod - def create_dashboard_task_id(cls, task_id, cms_user, timestamp=None): - if not timestamp: + def create_dashboard_task_id( + cls, + task_id: str, + cms_user: str, + timestamp: str | None = None, + ) -> str: + if timestamp is None: timestamp = cls.create_timestamp() - return "{}:{}_{}".format(timestamp, cms_user, task_id) + return f"{timestamp}:{cms_user}_{task_id}" @classmethod - def create_dashboard_job_id(cls, job_num, job_id, attempt=0): - return "{}_{}_{}".format(job_num, job_id, attempt) + def create_dashboard_job_id(cls, job_num: str, job_id: str, attempt: int = 0) -> str: + return f"{job_num}_{job_id}_{attempt}" @classmethod - def params_from_status(cls, dashboard_status, fail_code=1): + def params_from_status(cls, dashboard_status: str, fail_code: int = 1) -> dict[str, Any]: if dashboard_status == cls.PENDING: return {"StatusValue": "pending", "SyncCE": None} - elif dashboard_status == cls.RUNNING: + if dashboard_status == cls.RUNNING: return {"StatusValue": "running"} - elif dashboard_status == cls.CANCELLED: + if dashboard_status == cls.CANCELLED: return {"StatusValue": "cancelled", "SyncCE": None} - elif dashboard_status == cls.POSTPROC: + if dashboard_status == cls.POSTPROC: return {"StatusValue": "running", "JobExitCode": 0} - elif dashboard_status == cls.SUCCESS: + if dashboard_status == cls.SUCCESS: return {"StatusValue": "success", "JobExitCode": 0} - elif dashboard_status == cls.FAILED: + 
if dashboard_status == cls.FAILED: return {"StatusValue": "failed", "JobExitCode": fail_code} - else: - raise ValueError("invalid dashboard status '{}'".format(dashboard_status)) + + raise ValueError(f"invalid dashboard status '{dashboard_status}'") @classmethod - def map_status(cls, job_status, event): + def map_status(cls, job_status: str, event: str) -> str | None: # when starting with "status.", event must end with the job status if event.startswith("status.") and event.split(".", 1)[-1] != job_status: - raise ValueError("event '{}' does not match job status '{}'".format(event, job_status)) + raise ValueError(f"event '{event}' does not match job status '{job_status}'") - status = lambda attr: "status.{}".format(getattr(BaseJobManager, attr)) + status = lambda attr: f"status.{getattr(BaseJobManager, attr)}" return { "action.submit": cls.PENDING, @@ -946,47 +1095,61 @@ def map_status(cls, job_status, event): status("FINISHED"): cls.SUCCESS, }.get(event) - def remote_hook_file(self): + def remote_hook_file(self) -> str: return law.util.rel_path(__file__, "scripts", "cmsdashb_hooks.sh") - def remote_hook_data(self, job_num, attempt): - data = [ - "task_id='{}'".format(self.task_id), - "cms_user='{}'".format(self.cms_user), - "voms_user='{}'".format(self.voms_user), - "init_timestamp='{}'".format(self.init_timestamp), - "job_num={}".format(job_num), - "attempt={}".format(attempt), - ] + def remote_hook_data(self, job_num: int, attempt: int) -> dict[str, Any]: + data = { + "task_id": self.task_id, + "cms_user": self.cms_user, + "voms_user": self.voms_user, + "init_timestamp": self.init_timestamp, + "job_num": job_num, + "attempt": attempt, + } if self.site: - data.append("site='{}'".format(self.site)) + data["site"] = self.site return data - def create_tracking_url(self): + def create_tracking_url(self) -> str: dashboard_task_id = self.create_dashboard_task_id(self.task_id, self.cms_user, self.init_timestamp) return self.tracking_url.format(dashboard_task_id=dashboard_task_id) - def create_message(self, job_data, event, job_num, attempt=0, custom_params=None, **kwargs): + def create_message( + self, + job_data, + event, + job_num, + attempt=0, + custom_params=None, + **kwargs, + ) -> tuple[str, str, dict[str, Any]] | None: # we need the voms user, which must start with "/CN=" voms_user = self.voms_user if not voms_user: - return + return None if not voms_user.startswith("/CN="): voms_user = "/CN=" + voms_user # map to job status to a valid dashboard status dashboard_status = self.map_status(job_data.get("status"), event) if not dashboard_status: - return + return None # build the dashboard task id - dashboard_task_id = self.create_dashboard_task_id(self.task_id, self.cms_user, - self.init_timestamp) + dashboard_task_id = self.create_dashboard_task_id( + self.task_id, + self.cms_user, + self.init_timestamp, + ) # build the id of the particular job - dashboard_job_id = self.create_dashboard_job_id(job_num, job_data["job_id"], - attempt=attempt) + dashboard_job_id = self.create_dashboard_job_id( + job_num, + job_data["job_id"], + attempt=attempt, + ) # build the parameters to send params = { @@ -1017,13 +1180,13 @@ def create_message(self, job_data, event, job_num, attempt=0, custom_params=None params.update(custom_params) # finally filter None's and convert everything to strings - params = {key: str(value) for key, value in six.iteritems(params) if value is not None} + params = {key: str(value) for key, value in params.items() if value is not None} return (dashboard_task_id, dashboard_job_id, 
params) - @BaseJobDashboard.cache_by_status - def publish(self, *args, **kwargs): - message = self.create_message(*args, **kwargs) + @BaseJobDashboard.cache_by_status # type: ignore[misc] + def publish(self, job_data: JobData, event: str, job_num: int, *args, **kwargs) -> None: # type: ignore[override] # noqa + message = self.create_message(job_data, event, job_num, *args, **kwargs) if message: self.apmon.send(*message) @@ -1041,10 +1204,15 @@ class Apmon(threading.Thread): }, } - def __init__(self, config=None, max_rate=20, log_level="INFO"): - super(Apmon, self).__init__() + def __init__( + self, + config: dict[str, dict[str, Any]] | None = None, + max_rate: int = 20, + log_level: str = "INFO", + ) -> None: + super().__init__() - import apmon + import apmon # type: ignore[import-untyped, import-not-found] log_level = getattr(apmon.Logger, log_level.upper()) self._apmon = apmon.ApMon(config or self.default_config, log_level) self._apmon.maxMsgRate = int(max_rate * 1.5) @@ -1054,19 +1222,19 @@ def __init__(self, config=None, max_rate=20, log_level="INFO"): value["INSTANCE_ID"] = value["INSTANCE_ID"] & 0x7fffffff self._max_rate = max_rate - self._queue = six.moves.queue.Queue() + self._queue: queue.Queue = queue.Queue() self._stop_event = threading.Event() - def send(self, *args, **kwargs): + def send(self, *args, **kwargs) -> None: self._queue.put((args, kwargs)) - def _send(self, *args, **kwargs): + def _send(self, *args, **kwargs) -> None: self._apmon.sendParameters(*args, **kwargs) - def stop(self): + def stop(self) -> None: self._stop_event.set() - def run(self): + def run(self) -> None: while True: # handling stopping self._stop_event.wait(0.5) diff --git a/law/contrib/cms/sandbox.py b/law/contrib/cms/sandbox.py index de57d700..996f552b 100644 --- a/law/contrib/cms/sandbox.py +++ b/law/contrib/cms/sandbox.py @@ -4,25 +4,28 @@ CMS related sandbox implementations. 
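# The Apmon helper above follows a common worker-thread pattern: messages are queued
# by send(), a background loop drains the queue, and a threading.Event signals
# shutdown. A stripped-down, generic sketch of that pattern (not the exact Apmon
# logic, whose send loop is only partially shown here):
import queue
import threading

class Worker(threading.Thread):

    def __init__(self) -> None:
        super().__init__()
        self._queue: queue.Queue = queue.Queue()
        self._stop_event = threading.Event()

    def send(self, *args, **kwargs) -> None:
        self._queue.put((args, kwargs))

    def stop(self) -> None:
        self._stop_event.set()

    def run(self) -> None:
        while True:
            self._stop_event.wait(0.5)
            # drain queued messages
            while not self._queue.empty():
                args, kwargs = self._queue.get()
                print("sending", args, kwargs)  # stand-in for the real transport
            if self._stop_event.is_set():
                break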
""" -__all__ = ["CMSSWSandbox"] +from __future__ import annotations +__all__ = ["CMSSWSandbox"] import os +import pathlib +import pickle import collections -import six - +from law.task.proxy import ProxyCommand from law.sandbox.base import _current_sandbox from law.sandbox.bash import BashSandbox from law.util import ( tmp_file, interruptable_popen, quote_cmd, flatten, makedirs, rel_path, law_home_path, create_hash, ) +from law._types import Any class CMSSWSandbox(BashSandbox): - sandbox_type = "cmssw" + sandbox_type: str = "cmssw" # type for sandbox variables # (names corresond to variables used in setup_cmssw.sh script) @@ -32,10 +35,10 @@ class CMSSWSandbox(BashSandbox): ) @classmethod - def create_variables(cls, s): + def create_variables(cls, s: str) -> Variables: # input format: [::[::...]] if not s: - raise ValueError("cannot create {} variables from input '{}'".format(cls.__name__, s)) + raise ValueError(f"cannot create {cls.__name__} variables from input '{s}'") # split values values = {} @@ -44,12 +47,10 @@ def create_variables(cls, s): values["version"] = part continue if "=" not in part: - raise ValueError( - "wrong format, part '{}' at index {} does not contain a '='".format(part, i), - ) + raise ValueError(f"wrong format, part '{part}' at index {i} does not contain a '='") field, value = part.split("=", 1) if field not in cls.Variables._fields: - raise KeyError("unknown variable name '{}' at index {}".format(field, i)) + raise KeyError(f"unknown variable name '{field}' at index {i}") values[field] = value # special treatments @@ -60,12 +61,12 @@ def create_variables(cls, s): values["dir"] = expand(values["dir"]) else: h = create_hash((cls.sandbox_type, values["version"], values.get("setup"))) - values["dir"] = law_home_path("cms", "cmssw", "{}_{}".format(values["version"], h)) + values["dir"] = law_home_path("cms", "cmssw", f"{values['version']}_{h}") return cls.Variables(*[values.get(field, "") for field in cls.Variables._fields]) - def __init__(self, *args, **kwargs): - super(CMSSWSandbox, self).__init__(*args, **kwargs) + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) # parse name into variables self.variables = self.create_variables(self.name) @@ -75,11 +76,11 @@ def __init__(self, *args, **kwargs): h = create_hash((self.sandbox_type, self.env_cache_key)) self.env_cache_path = law_home_path( "cms", - "{}_cache".format(self.sandbox_type), - "{}_{}.pkl".format(self.variables.version, h), + f"{self.sandbox_type}_cache", + f"{self.variables.version}_{h}.pkl", ) - def is_active(self): + def is_active(self) -> bool: # check if any current sandbox matches the version, setup and dir of this one for key in _current_sandbox: _type, name = self.split_key(key) @@ -91,23 +92,23 @@ def is_active(self): return False - def get_custom_config_section_postfix(self): + def get_custom_config_section_postfix(self) -> str: return self.variables.version @property - def env_cache_key(self): + def env_cache_key(self) -> tuple[str, str, str]: # type: ignore[override] return self.variables[:3] @property - def script(self): + def script(self) -> str: return rel_path(__file__, "scripts", "setup_cmssw.sh") - def create_env(self): + def create_env(self) -> dict[str, Any]: # strategy: create a tempfile, let python dump its full env in a subprocess and load the # env file again afterwards # helper to write the env - def write_env(path): + def write_env(path: str | pathlib.Path) -> None: # get the bash command bash_cmd = self._bash_cmd() @@ -116,18 +117,20 @@ def 
write_env(path): # build script variable exports export_cmds = self._build_setup_cmds(collections.OrderedDict( - ("LAW_CMSSW_{}".format(attr.upper()), value) + (f"LAW_CMSSW_{attr.upper()}", value) for attr, value in zip(self.variables._fields, self.variables) )) # build the python command that dumps the environment - py_cmd = "import os,pickle;" \ - + "pickle.dump(dict(os.environ),open('{}','wb'),protocol=2)".format(path) + py_cmd = ( + "import os,pickle;" + f"pickle.dump(dict(os.environ),open('{path}','wb'),protocol=2)" + ) # build the full command cmd = quote_cmd(bash_cmd + ["-c", " && ".join(flatten( export_cmds, - "source \"{}\" \"\"".format(self.script), + f"source \"{self.script}\" \"\"", setup_cmds, quote_cmd(["python", "-c", py_cmd]), ))]) @@ -135,19 +138,15 @@ def write_env(path): # run it returncode = interruptable_popen(cmd, shell=True, executable="/bin/bash")[0] if returncode != 0: - raise Exception("bash sandbox env loading failed with exit code {}".format( - returncode)) + raise Exception(f"bash sandbox env loading failed with exit code {returncode}") # helper to load the env - def load_env(path): - pickle_kwargs = {"encoding": "utf-8"} if six.PY3 else {} + def load_env(path: str | pathlib.Path) -> dict[str, Any]: with open(path, "rb") as f: try: - return collections.OrderedDict(six.moves.cPickle.load(f, **pickle_kwargs)) + return dict(pickle.load(f, encoding="utf-8")) except Exception as e: - raise Exception( - "env deserialization of sandbox {} failed: {}".format(self, e), - ) + raise Exception(f"env deserialization of sandbox {self} failed: {e}") # use the cache path if set if self.env_cache_path: @@ -171,7 +170,7 @@ def load_env(path): return env - def cmd(self, proxy_cmd): + def cmd(self, proxy_cmd: ProxyCommand) -> str: # environment variables to set env = self._get_env() @@ -189,7 +188,7 @@ def cmd(self, proxy_cmd): # build script variable exports export_cmds = self._build_setup_cmds(collections.OrderedDict( - ("LAW_CMSSW_{}".format(attr.upper()), value) + (f"LAW_CMSSW_{attr.upper()}", value) for attr, value in zip(self.variables._fields, self.variables) )) @@ -200,7 +199,7 @@ def cmd(self, proxy_cmd): # build the final command cmd = quote_cmd(bash_cmd + ["-c", " && ".join(flatten( export_cmds, - "source \"{}\" \"\"".format(self.script), + f"source \"{self.script}\" \"\"", setup_cmds, proxy_cmd.build(), ))]) diff --git a/law/contrib/cms/tasks.py b/law/contrib/cms/tasks.py index a883cfef..01e7797e 100644 --- a/law/contrib/cms/tasks.py +++ b/law/contrib/cms/tasks.py @@ -5,18 +5,19 @@ https://home.cern/about/experiments/cms """ -__all__ = ["BundleCMSSW"] +from __future__ import annotations +__all__ = ["BundleCMSSW"] import os import subprocess -from abc import abstractmethod - -import luigi +import pathlib +import abc +import luigi # type: ignore[import-untyped] from law.task.base import Task -from law.target.file import get_path +from law.target.file import get_path, FileSystemFileTarget from law.target.local import LocalFileTarget from law.parameter import NO_STR, CSVParameter from law.decorator import log @@ -46,17 +47,17 @@ class BundleCMSSW(Task): cmssw_checksumming = True - def __init__(self, *args, **kwargs): - super(BundleCMSSW, self).__init__(*args, **kwargs) + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) - self._checksum = None + self._checksum: str | None = None - @abstractmethod - def get_cmssw_path(self): - return + @abc.abstractmethod + def get_cmssw_path(self) -> str | pathlib.Path | LocalFileTarget: + ... 
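# The create_env logic above captures a subprocess environment by running a tiny
# "python -c" snippet that pickles os.environ to a file and unpickling it in the
# parent. A minimal standalone sketch of that round trip; unlike the sandbox, it
# runs the current interpreter directly instead of sourcing the setup scripts first.
import pickle
import subprocess
import sys
import tempfile

with tempfile.NamedTemporaryFile(suffix=".pkl") as tmp:
    py_cmd = (
        "import os,pickle;"
        f"pickle.dump(dict(os.environ),open('{tmp.name}','wb'),protocol=2)"
    )
    subprocess.run([sys.executable, "-c", py_cmd], check=True)
    with open(tmp.name, "rb") as f:
        env = dict(pickle.load(f, encoding="utf-8"))

print(len(env), "variables captured")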
@property - def checksum(self): + def checksum(self) -> None | str: if not self.cmssw_checksumming: return None @@ -70,10 +71,15 @@ def checksum(self): ] if self.exclude != NO_STR: cmd += [self.exclude] - cmd = quote_cmd(cmd) - - code, out, _ = interruptable_popen(cmd, shell=True, executable="/bin/bash", - stdout=subprocess.PIPE) + _cmd = quote_cmd(cmd) + + out: str + code, out, _ = interruptable_popen( # type: ignore[assignment] + _cmd, + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + ) if code != 0: raise Exception("cmssw checksum calculation failed") @@ -81,20 +87,20 @@ def checksum(self): return self._checksum - def output(self): + def output(self) -> FileSystemFileTarget: base = os.path.basename(get_path(self.get_cmssw_path())) if self.checksum: - base += "{}.".format(self.checksum) + base += f"{self.checksum}." base = os.path.abspath(os.path.expandvars(os.path.expanduser(base))) - return LocalFileTarget("{}.tgz".format(base)) + return LocalFileTarget(f"{base}.tgz") @log - def run(self): + def run(self) -> None: with self.output().localize("w") as tmp: with self.publish_step("bundle CMSSW ..."): self.bundle(tmp.path) - def get_cmssw_bundle_command(self, dst_path): + def get_cmssw_bundle_command(self, dst_path: str | pathlib.Path | LocalFileTarget) -> list[str]: return [ rel_path(__file__, "scripts", "bundle_cmssw.sh"), get_path(self.get_cmssw_path()), @@ -103,7 +109,7 @@ def get_cmssw_bundle_command(self, dst_path): " ".join(self.include), ] - def bundle(self, dst_path): + def bundle(self, dst_path: str | pathlib.Path | LocalFileTarget) -> None: cmd = self.get_cmssw_bundle_command(dst_path) code = interruptable_popen(quote_cmd(cmd), shell=True, executable="/bin/bash")[0] if code != 0: diff --git a/law/contrib/cms/util.py b/law/contrib/cms/util.py index 3fb56b97..796dc758 100644 --- a/law/contrib/cms/util.py +++ b/law/contrib/cms/util.py @@ -4,13 +4,15 @@ CMS-related utilities. """ -__all__ = ["Site", "lfn_to_pfn", "renew_vomsproxy", "delegate_myproxy"] +from __future__ import annotations +__all__ = ["Site", "lfn_to_pfn", "renew_vomsproxy", "delegate_myproxy"] import os import law + law.contrib.load("wlcg") @@ -52,14 +54,14 @@ class Site(object): "us": "cmsxrootd.fnal.gov", } - def __init__(self, name=None): - super(Site, self).__init__() + def __init__(self, name: str | None = None) -> None: + super().__init__() # site name cache - self.name = name or self.get_name_from_env() + self.name: str | None = self.get_name_from_env() if name is None else name @classmethod - def get_name_from_env(cls): + def get_name_from_env(cls) -> str | None: """ Tries to extract the site name from environment variables. Returns the name on succcess and *None* otherwise. @@ -71,63 +73,75 @@ def get_name_from_env(cls): return None @property - def info(self): + def info(self) -> tuple[str, str, str] | tuple[None, None, None]: """ Tier, country and locality information in a 3-tuple, e.g. ``("T2", "DE", "RWTH")``. """ - return self.name and self.name.split("", 2) + if self.name is not None: + info = self.name.split("_", 2) + if len(info) != 3: + raise ValueError(f"invalid site name: {self.name}") + return tuple(info) # type: ignore[return-value] + + return (None, None, None) @property - def tier(self): + def tier(self) -> str | None: """ The tier of the site, e.g. ``T2``. """ - return self.name and self.info[0] + info = self.info + return None if self.info is None else info[0] @property - def country(self): + def country(self) -> str | None: """ The country of the site, e.g. ``DE``. 
""" - return self.name and self.info[1] + info = self.info + return None if self.info is None else info[1] @property def locality(self): """ The locality of the site, e.g. ``RWTH``. """ - return self.name and self.info[2] + info = self.info + return None if self.info is None else info[2] @property - def redirector(self): + def redirector(self) -> str: """ The XRD redirector that should be used on this site. For more information on XRD, see `this link `_. """ - return self.redirectors.get(self.country.lower(), self.redirectors["global"]) + country = self.country + if country in self.redirectors: + return self.redirectors[country] + return self.redirectors["global"] -def lfn_to_pfn(lfn, redirector="global"): +def lfn_to_pfn(lfn: str, redirector: str = "global") -> str: """ Converts a logical file name *lfn* to a physical file name *pfn* using a *redirector*. Valid values for *redirector* are defined by :py:attr:`Site.redirectors`. """ if redirector not in Site.redirectors: - raise ValueError("unknown redirector: {}".format(redirector)) + raise ValueError(f"unknown redirector: {redirector}") - return "root://{}/{}".format(Site.redirectors[redirector], lfn) + return f"root://{Site.redirectors[redirector]}/{lfn}" -def renew_vomsproxy(**kwargs): +def renew_vomsproxy(**kwargs) -> str | None: """ Renews a VOMS proxy in the exact same way that :py:func:`law.wlcg.renew_vomsproxy` does, but with the *vo* attribute set to ``"cms"`` by default. """ kwargs.setdefault("vo", "cms") - return law.wlcg.renew_vomsproxy(**kwargs) + return law.wlcg.renew_vomsproxy(**kwargs) # type: ignore[attr-defined] -def delegate_myproxy(**kwargs): +def delegate_myproxy(**kwargs) -> str | None: """ Delegates a X509 proxy to a myproxy server in the exact same way that :py:func:`law.wlcg.delegate_myproxy` does, but with the *retrievers* argument set to a value @@ -138,4 +152,4 @@ def delegate_myproxy(**kwargs): "/DC=ch/DC=cern/OU=computers/CN=crab-(preprod|prod|dev)-tw(01|02|03).cern.ch|/DC=ch/DC=cern/OU=computers/CN=stefanov(m|m2).cern.ch|/DC=ch/DC=cern/OU=computers/CN=dciangot-tw.cern.ch", # noqa ) kwargs.setdefault("vo", "cms") - return law.wlcg.delegate_myproxy(**kwargs) + return law.wlcg.delegate_myproxy(**kwargs) # type: ignore[attr-defined] diff --git a/law/contrib/cms/workflow.py b/law/contrib/cms/workflow.py index a17837a5..867ebe2b 100644 --- a/law/contrib/cms/workflow.py +++ b/law/contrib/cms/workflow.py @@ -5,23 +5,25 @@ https://twiki.cern.ch/twiki/bin/view/CMSPublic/SWGuideCrab. 
""" -__all__ = ["CrabWorkflow"] +from __future__ import annotations +__all__ = ["CrabWorkflow"] +import pathlib import uuid -from abc import abstractmethod -from collections import OrderedDict +import abc -import law from law.config import Config -from law.workflow.remote import BaseRemoteWorkflow, BaseRemoteWorkflowProxy +from law.workflow.remote import BaseRemoteWorkflow, BaseRemoteWorkflowProxy, JobData from law.job.base import JobArguments, JobInputFile from law.target.file import get_path, get_scheme, remove_scheme, FileSystemDirectoryTarget -from law.target.local import LocalDirectoryTarget +from law.target.local import LocalDirectoryTarget, LocalFileTarget from law.task.proxy import ProxyCommand -from law.util import no_value, law_src_path, merge_dicts, DotDict, human_duration +from law.util import no_value, law_src_path, merge_dicts, human_duration, DotDict, InsertableDict from law.logger import get_logger +from law._types import Any, Type +from law.contrib.wlcg import check_vomsproxy_validity, get_myproxy_info from law.contrib.cms.job import CrabJobManager, CrabJobFileFactory from law.contrib.cms.util import renew_vomsproxy, delegate_myproxy @@ -31,25 +33,25 @@ class CrabWorkflowProxy(BaseRemoteWorkflowProxy): - workflow_type = "crab" + workflow_type: str = "crab" # job script error codes are not transferred, so disable them job_error_messages = {} - def create_job_manager(self, **kwargs): + def create_job_manager(self, **kwargs) -> CrabJobManager: return self.task.crab_create_job_manager(**kwargs) - def setup_job_manager(self): + def setup_job_manager(self) -> dict[str, Any]: cfg = Config.instance() password_file = cfg.get_expanded("job", "crab_password_file") # ensure a VOMS proxy exists - if not law.wlcg.check_vomsproxy_validity(): + if not check_vomsproxy_validity(): print("renew voms-proxy") renew_vomsproxy(password_file=password_file) # ensure that it has been delegated to the myproxy server - info = law.wlcg.get_myproxy_info(silent=True) + info = get_myproxy_info(silent=True) delegate = False if not info: delegate = True @@ -59,10 +61,9 @@ def setup_job_manager(self): elif "timeleft" not in info: logger.warning("field 'timeleft' not in myproxy info") delegate = True - elif info["timeleft"] < 86400: - logger.warning("myproxy lifetime below 24h ({})".format( - human_duration(seconds=info["timeleft"]), - )) + elif info["timeleft"] < 86400: # type: ignore[operator] + timeleft = human_duration(seconds=info["timeleft"]) + logger.warning(f"myproxy lifetime below 24h ({timeleft})") delegate = True # actual delegation @@ -70,18 +71,21 @@ def setup_job_manager(self): print("delegate to myproxy server") myproxy_username = delegate_myproxy(password_file=password_file) else: - myproxy_username = info["username"] + myproxy_username = info["username"] # type: ignore[index, assignment] return {"myproxy_username": myproxy_username} - def create_job_file_factory(self, **kwargs): + def create_job_file_factory(self, **kwargs) -> CrabJobFileFactory: return self.task.crab_create_job_file_factory(**kwargs) - def create_job_file(self, submit_jobs): + def create_job_file_group( + self, + submit_jobs: dict[int, list[int]], + ) -> dict[str, str | pathlib.Path | CrabJobFileFactory.Config | None]: task = self.task # create the config - c = self.job_file_factory.get_config() + c = self.job_file_factory.get_config() # type: ignore[union-attr] c.input_files = {} c.output_files = [] c.render_variables = {} @@ -107,20 +111,25 @@ def create_job_file(self, submit_jobs): exclude_global_args=["workers"], ) 
proxy_cmd.add_arg("--local-scheduler", "True", overwrite=True) - for key, value in OrderedDict(task.crab_cmdline_args()).items(): + for key, value in dict(task.crab_cmdline_args()).items(): proxy_cmd.add_arg(key, value, overwrite=True) # job script arguments per job number c.arguments = [] for job_num, branches in submit_jobs.items(): + dashboard_data = None + if self.dashboard: + dashboard_data = self.dashboard.remote_hook_data( + job_num, + self.job_data.attempts.get(job_num, 0), + ) job_args = JobArguments( task_cls=task.__class__, task_params=proxy_cmd.build(skip_run=True), branches=branches, workers=task.job_workers, auto_retry=False, - dashboard_data=self.dashboard.remote_hook_data( - job_num, self.job_data.attempts.get(job_num, 0)), + dashboard_data=dashboard_data, ) c.arguments.append(job_args.join()) @@ -135,7 +144,7 @@ def create_job_file(self, submit_jobs): if not isinstance(stageout_location, (list, tuple)) or len(stageout_location) != 2: raise ValueError( "the return value of crab_stageout_location() is expected to be a 2-tuple, got " - "'{}'".format(stageout_location), + f"'{stageout_location}'", ) c.storage_site, c.output_lfn_base = stageout_location @@ -150,7 +159,7 @@ def create_job_file(self, submit_jobs): c.input_files["stageout_file"] = stageout_file # does the dashboard have a hook file? - dashboard_file = self.dashboard.remote_hook_file() + dashboard_file = self.dashboard.remote_hook_file() if self.dashboard else None if dashboard_file: c.input_files["dashboard_file"] = dashboard_file @@ -163,22 +172,22 @@ def create_job_file(self, submit_jobs): c = task.crab_job_config(c, submit_jobs) # build the job file and get the sanitized config - job_file, c = self.job_file_factory(**c.__dict__) + job_file, c = self.job_file_factory(**c.__dict__) # type: ignore[misc] # return job and log file entry # (the latter is None but will be synced from query data) return {"job": job_file, "config": c, "log": None} - def _status_error_pairs(self, job_num, job_data): - pairs = super(CrabWorkflowProxy, self)._status_error_pairs(job_num, job_data) + def _status_error_pairs(self, job_num: int, job_data: JobData) -> InsertableDict: + pairs = super()._status_error_pairs(job_num, job_data) # add site history pairs.insert_before("log", "site history", job_data["extra"].get("site_history", no_value)) return pairs - def destination_info(self): - info = super(CrabWorkflowProxy, self).destination_info() + def destination_info(self) -> InsertableDict: + info = super().destination_info() info = self.task.crab_destination_info(info) @@ -193,19 +202,19 @@ class CrabWorkflow(BaseRemoteWorkflow): crab_job_manager_defaults = None crab_job_file_factory_defaults = None - crab_job_kwargs = [] + crab_job_kwargs: list[str] = [] crab_job_kwargs_submit = None crab_job_kwargs_cancel = None crab_job_kwargs_cleanup = None crab_job_kwargs_query = None exclude_params_branch = set() - exclude_params_crab_workflow = set() + exclude_params_crab_workflow: set[str] = set() exclude_index = True - @abstractmethod - def crab_stageout_location(self): + @abc.abstractmethod + def crab_stageout_location(self) -> tuple[str, str]: """ Hook to define both the "Site.storageSite" and "Data.outLFNDirBase" settings in a 2-tuple, i.e., the name of the storage site to use and the base directory for crab's own output @@ -214,24 +223,24 @@ def crab_stageout_location(self): In case this is not used, the choice of the output base has no affect, but is still required for crab's job submission to work. """ - return + ... 
- @abstractmethod - def crab_output_directory(self): + @abc.abstractmethod + def crab_output_directory(self) -> FileSystemDirectoryTarget: """ Hook to define the location of submission output files, such as the json files containing job data. This method should return a :py:class:`FileSystemDirectoryTarget`. """ - return + ... - def crab_request_name(self, submit_jobs): + def crab_request_name(self, submit_jobs: dict[int, list[int]]) -> str: """ Returns a random name for a request, i.e., the project directory inside the crab job working area. """ - return "{}_{}".format(self.live_task_id, str(uuid.uuid4())[:8]) + return f"{self.live_task_id}_{str(uuid.uuid4())[:8]}" - def crab_work_area(self): + def crab_work_area(self) -> str | LocalDirectoryTarget: """ Returns the location of the crab working area, defaulting to the value of :py:meth:`crab_output_directory` in case it refers to a local directory. When *None*, the @@ -253,13 +262,13 @@ def crab_work_area(self): # relative to the job file directory return "" - def crab_job_file(self): + def crab_job_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile: """ Hook to return the location of the job file that is executed on job nodes. """ return JobInputFile(law_src_path("job", "law_job.sh")) - def crab_bootstrap_file(self): + def crab_bootstrap_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile | None: """ Hook to define the location of an optional, so-called bootstrap file that is sent alongside jobs and called prior to the actual job payload. It is meant to run a custom setup routine @@ -267,7 +276,7 @@ def crab_bootstrap_file(self): """ return None - def crab_stageout_file(self): + def crab_stageout_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile | None: """ Hook to define the location of an optional, so-called stageout file that is sent alongside jobs and called after to the actual job payload. It is meant to run a custom output stageout @@ -275,46 +284,46 @@ def crab_stageout_file(self): """ return None - def crab_workflow_requires(self): + def crab_workflow_requires(self) -> DotDict: """ Hook to define requirements for the workflow itself and that need to be resolved before any submission can happen. """ return DotDict() - def crab_output_postfix(self): + def crab_output_postfix(self) -> str: """ Hook to define the postfix of outputs, for instance such that workflows with different parameters do not write their intermediate job status information into the same json file. """ return "" - def crab_output_uri(self): + def crab_output_uri(self) -> str: """ Hook to return the URI of the remote crab output directory. """ - return self.crab_output_directory().uri() + return self.crab_output_directory().uri(return_all=False) # type: ignore[return-value] - def crab_job_manager_cls(self): + def crab_job_manager_cls(self) -> Type[CrabJobManager]: """ Hook to define a custom job managet class to use. """ return CrabJobManager - def crab_create_job_manager(self, **kwargs): + def crab_create_job_manager(self, **kwargs) -> CrabJobManager: """ Hook to configure how the underlying job manager is instantiated and configured. """ kwargs = merge_dicts(self.crab_job_manager_defaults, kwargs) return self.crab_job_manager_cls()(**kwargs) - def crab_job_file_factory_cls(self): + def crab_job_file_factory_cls(self) -> Type[CrabJobFileFactory]: """ Hook to define a custom job file factory class to use. 
""" return CrabJobFileFactory - def crab_create_job_file_factory(self, **kwargs): + def crab_create_job_file_factory(self, **kwargs) -> CrabJobFileFactory: """ Hook to configure how the underlying job file factory is instantiated and configured. """ @@ -322,20 +331,25 @@ def crab_create_job_file_factory(self, **kwargs): kwargs = merge_dicts({}, self.crab_job_file_factory_defaults, kwargs) return self.crab_job_file_factory_cls()(**kwargs) - def crab_job_config(self, config, submit_jobs): + def crab_job_config( + self, + config: CrabJobFileFactory.Config, + submit_jobs: dict[int, list[int]], + ) -> CrabJobFileFactory.Config: """ Hook to inject custom settings into the job *config*, which is an instance of the :py:attr:`Config` class defined inside the job manager. """ return config - def crab_check_job_completeness(self): + def crab_check_job_completeness(self) -> bool: """ - Hook to define whether + Hook to define whether after job report successful completion, the job manager should check + the completion status of the branch tasks run by the finished jobs. """ return False - def crab_check_job_completeness_delay(self): + def crab_check_job_completeness_delay(self) -> float | int: """ Grace period before :py:meth:`crab_check_job_completeness` is called to ensure that output files are accessible. Especially useful on distributed file systems with possibly @@ -343,13 +357,13 @@ def crab_check_job_completeness_delay(self): """ return 0.0 - def crab_cmdline_args(self): + def crab_cmdline_args(self) -> dict[str, str]: """ Hook to add additional cli parameters to "law run" commands executed on job nodes. """ return {} - def crab_destination_info(self, info): + def crab_destination_info(self, info: InsertableDict) -> InsertableDict: """ Hook to add additional information behind each job status query line by extending an *info* dictionary whose values will be shown separated by comma. diff --git a/law/contrib/glite/__init__.py b/law/contrib/glite/__init__.py index 83bb85a8..64ffe698 100644 --- a/law/contrib/glite/__init__.py +++ b/law/contrib/glite/__init__.py @@ -7,6 +7,9 @@ __all__ = ["GLiteJobManager", "GLiteJobFileFactory", "GLiteWorkflow"] +# dependencies to other contrib modules +import law +law.contrib.load("wlcg") # provisioning imports from law.contrib.glite.job import GLiteJobManager, GLiteJobFileFactory diff --git a/law/contrib/glite/job.py b/law/contrib/glite/job.py index 60d9410c..354c8299 100644 --- a/law/contrib/glite/job.py +++ b/law/contrib/glite/job.py @@ -4,6 +4,8 @@ Simple gLite job manager. See https://wiki.italiangrid.it/twiki/bin/view/CREAM/UserGuide. 
""" +from __future__ import annotations + __all__ = ["GLiteJobManager", "GLiteJobFileFactory"] import os @@ -11,13 +13,15 @@ import time import re import random +import pathlib import subprocess from law.config import Config -from law.job.base import BaseJobManager, BaseJobFileFactory, JobInputFile, DeprecatedInputFiles +from law.job.base import BaseJobManager, BaseJobFileFactory, JobInputFile from law.target.file import get_path from law.util import interruptable_popen, make_list, make_unique, quote_cmd from law.logger import get_logger +from law._types import Any, Sequence logger = get_logger(__name__) @@ -36,14 +40,27 @@ class GLiteJobManager(BaseJobManager): submission_job_id_cre = re.compile(r"^https?\:\/\/.+\:\d+\/.+") status_block_cre = re.compile(r"(\w+)\s*\=\s*\[([^\]]*)\]") - def __init__(self, ce=None, delegation_id=None, threads=1): - super(GLiteJobManager, self).__init__() + def __init__( + self, + ce: str | None = None, + delegation_id: str | None = None, + threads: int = 1, + ) -> None: + super().__init__() self.ce = ce self.delegation_id = delegation_id self.threads = threads - def submit(self, job_file, ce=None, delegation_id=None, retries=0, retry_delay=3, silent=False): + def submit( # type: ignore[override] + self, + job_file: str | pathlib.Path, + ce: str | None = None, + delegation_id: str | None = None, + retries: int = 0, + retry_delay: float | int = 3, + silent: bool = False, + ) -> str | None: # default arguments if ce is None: ce = self.ce @@ -55,45 +72,52 @@ def submit(self, job_file, ce=None, delegation_id=None, retries=0, retry_delay=3 raise ValueError("ce must not be empty") # prepare round robin for ces and delegations - ce = make_list(ce) - if delegation_id: - delegation_id = make_list(delegation_id) - if len(ce) != len(delegation_id): - raise Exception("numbers of CEs ({}) and delegation ids ({}) do not match".format( - len(ce), len(delegation_id))) + _ce = make_list(ce) + _delegation_id = make_list(delegation_id) if delegation_id else None + if _delegation_id: + if len(_ce) != len(_delegation_id): + raise Exception( + f"numbers of CEs ({len(_ce)}) and delegation ids ({len(_delegation_id)}) " + "do not match", + ) # get the job file location as the submission command is run it the same directory - job_file_dir, job_file_name = os.path.split(os.path.abspath(str(job_file))) + job_file_dir, job_file_name = os.path.split(os.path.abspath(get_path(job_file))) # define the actual submission in a loop to simplify retries while True: # build the command - i = random.randint(0, len(ce) - 1) - cmd = ["glite-ce-job-submit", "-r", ce[i]] - if delegation_id: - cmd += ["-D", delegation_id[i]] + i = random.randint(0, len(_ce) - 1) + cmd = ["glite-ce-job-submit", "-r", _ce[i]] + if _delegation_id: + cmd += ["-D", _delegation_id[i]] cmd += [job_file_name] - cmd = quote_cmd(cmd) + cmd_str = quote_cmd(cmd) # run the command # glite prints everything to stdout - logger.debug("submit glite job with command '{}'".format(cmd)) - code, out, _ = interruptable_popen(cmd, shell=True, executable="/bin/bash", - stdout=subprocess.PIPE, cwd=job_file_dir) + logger.debug(f"submit glite job with command '{cmd_str}'") + out: str + code, out, _ = interruptable_popen( # type: ignore[assignment] + cmd_str, + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + cwd=job_file_dir, + ) # in some cases, the return code is 0 but the ce did not respond with a valid id if code == 0: job_id = out.strip().split("\n")[-1].strip() if not self.submission_job_id_cre.match(job_id): code = 1 - out 
= "bad job id '{}' from output:\n{}".format(job_id, out) + out = f"bad job id '{job_id}' from output:\n{out}" # retry or done? if code == 0: return job_id - logger.debug("submission of glite job '{}' failed with code {}:\n{}".format( - job_file, code, out)) + logger.debug(f"submission of glite job '{job_file}' failed with code {code}:\n{out}") if retries > 0: retries -= 1 @@ -103,71 +127,98 @@ def submit(self, job_file, ce=None, delegation_id=None, retries=0, retry_delay=3 if silent: return None - raise Exception("submission of glite job '{}' failed:\n{}".format(job_file, out)) + raise Exception(f"submission of glite job '{job_file}' failed:\n{out}") - def cancel(self, job_id, silent=False): + def cancel( # type: ignore[override] + self, + job_id: str | Sequence[str], + silent: bool = False, + ) -> dict[str, Any] | None: chunking = isinstance(job_id, (list, tuple)) job_ids = make_list(job_id) # build the command cmd = ["glite-ce-job-cancel", "-N"] + job_ids - cmd = quote_cmd(cmd) + cmd_str = quote_cmd(cmd) # run it - logger.debug("cancel glite job(s) with command '{}'".format(cmd)) - code, out, _ = interruptable_popen(cmd, shell=True, executable="/bin/bash", - stdout=subprocess.PIPE) + logger.debug(f"cancel glite job(s) with command '{cmd_str}'") + out: str + code, out, _ = interruptable_popen( # type: ignore[assignment] + cmd, + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + ) # check success if code != 0 and not silent: # glite prints everything to stdout - raise Exception("cancellation of glite job(s) '{}' failed with code {}:\n{}".format( - job_id, code, out)) + raise Exception( + f"cancellation of glite job(s) '{job_id}' failed with code {code}:\n{out}", + ) return {job_id: None for job_id in job_ids} if chunking else None - def cleanup(self, job_id, silent=False): + def cleanup( # type: ignore[override] + self, + job_id: str | Sequence[str], + silent: bool = False, + ) -> dict[str, Any] | None: chunking = isinstance(job_id, (list, tuple)) job_ids = make_list(job_id) # build the command cmd = ["glite-ce-job-purge", "-N"] + job_ids - cmd = quote_cmd(cmd) + cmd_str = quote_cmd(cmd) # run it - logger.debug("cleanup glite job(s) with command '{}'".format(cmd)) - code, out, _ = interruptable_popen(cmd, shell=True, executable="/bin/bash", - stdout=subprocess.PIPE) + logger.debug(f"cleanup glite job(s) with command '{cmd_str}'") + out: str + code, out, _ = interruptable_popen( # type: ignore[assignment] + cmd_str, + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + ) # check success if code != 0 and not silent: # glite prints everything to stdout - raise Exception("cleanup of glite job(s) '{}' failed with code {}:\n{}".format( - job_id, code, out)) + raise Exception(f"cleanup of glite job(s) '{job_id}' failed with code {code}:\n{out}") return {job_id: None for job_id in job_ids} if chunking else None - def query(self, job_id, silent=False): + def query( # type: ignore[override] + self, + job_id: str | Sequence[str], + silent: bool = False, + ) -> dict[int, dict[str, Any]] | dict[str, Any] | None: chunking = isinstance(job_id, (list, tuple)) job_ids = make_list(job_id) # build the command cmd = ["glite-ce-job-status", "-n", "-L", "0"] + job_ids - cmd = quote_cmd(cmd) + cmd_str = quote_cmd(cmd) # run it - logger.debug("query glite job(s) with command '{}'".format(cmd)) - code, out, _ = interruptable_popen(cmd, shell=True, executable="/bin/bash", - stdout=subprocess.PIPE) + logger.debug(f"query glite job(s) with command '{cmd_str}'") + out: str + code, out, _ 
= interruptable_popen( # type: ignore[assignment] + cmd, + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + ) # handle errors if code != 0: if silent: return None - else: - # glite prints everything to stdout - raise Exception("status query of glite job(s) '{}' failed with code {}:\n{}".format( - job_id, code, out)) + # glite prints everything to stdout + raise Exception( + f"status query of glite job(s) '{job_id}' failed with code {code}:\n{out}", + ) # parse the output and extract the status per job query_data = self.parse_query_output(out) @@ -178,21 +229,22 @@ def query(self, job_id, silent=False): if not chunking: if silent: return None - else: - raise Exception("glite job(s) '{}' not found in query response".format( - job_id)) + raise Exception(f"glite job(s) '{job_id}' not found in query response") else: - query_data[_job_id] = self.job_status_dict(job_id=_job_id, status=self.FAILED, - error="job not found in query response") + query_data[_job_id] = self.job_status_dict( + job_id=_job_id, + status=self.FAILED, + error="job not found in query response", + ) - return query_data if chunking else query_data[job_id] + return query_data if chunking else query_data[job_id] # type: ignore[index] @classmethod - def parse_query_output(cls, out): + def parse_query_output(cls, out: str) -> dict[int, dict[str, Any]]: # blocks per job are separated by ****** blocks = [] - for block in out.split("******"): - block = dict(cls.status_block_cre.findall(block)) + for block_str in out.split("******"): + block = dict(cls.status_block_cre.findall(block_str)) if block: blocks.append(block) @@ -240,19 +292,20 @@ def parse_query_output(cls, out): return query_data @classmethod - def map_status(cls, status): + def map_status(cls, status: str | None) -> str: # see https://wiki.italiangrid.it/twiki/bin/view/CREAM/UserGuide#4_CREAM_job_states if status in ("REGISTERED", "PENDING", "IDLE", "HELD"): return cls.PENDING - elif status in ("RUNNING", "REALLY-RUNNING"): + if status in ("RUNNING", "REALLY-RUNNING"): return cls.RUNNING - elif status in ("DONE-OK",): + if status in ("DONE-OK",): return cls.FINISHED - elif status in ("CANCELLED", "DONE-FAILED", "ABORTED"): - return cls.FAILED - else: + if status in ("CANCELLED", "DONE-FAILED", "ABORTED"): return cls.FAILED + logger.debug(f"unknown glite job state '{status}'") + return cls.FAILED + class GLiteJobFileFactory(BaseJobFileFactory): @@ -262,29 +315,49 @@ class GLiteJobFileFactory(BaseJobFileFactory): "absolute_paths", ] - def __init__(self, file_name="glite_job.jdl", command=None, executable=None, arguments=None, - input_files=None, output_files=None, postfix_output_files=True, output_uri=None, - stdout="stdout.txt", stderr="stderr.txt", vo=None, custom_content=None, - absolute_paths=False, **kwargs): + def __init__( + self, + *, + file_name: str = "glite_job.jdl", + command: str | Sequence[str] | None = None, + executable: str | None = None, + arguments: str | Sequence[str] | None = None, + input_files: dict[str, str | pathlib.Path | JobInputFile] | None = None, + output_files: list[str] | None = None, + postfix_output_files: bool = True, + output_uri: str | None = None, + stdout: str = "stdout.txt", + stderr: str = "stderr.txt", + vo: str | None = None, + custom_content: str | Sequence[str] | None = None, + absolute_paths: bool = False, + **kwargs, + ) -> None: # get some default kwargs from the config cfg = Config.instance() if kwargs.get("dir") is None: - kwargs["dir"] = cfg.get_expanded("job", cfg.find_option("job", - "glite_job_file_dir", 
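# Sketch of the block parsing done in parse_query_output above: job blocks are
# separated by "******" and "key = [value]" pairs are pulled out with the
# status_block_cre regex. The sample output is fabricated.
import re

status_block_cre = re.compile(r"(\w+)\s*\=\s*\[([^\]]*)\]")

out = """
******  JobID=[https://ce1.example.org:8443/CREAM123]
        Status = [RUNNING]
"""

blocks = []
for block_str in out.split("******"):
    block = dict(status_block_cre.findall(block_str))
    if block:
        blocks.append(block)

print(blocks)  # [{'JobID': 'https://ce1.example.org:8443/CREAM123', 'Status': 'RUNNING'}]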
"job_file_dir")) + kwargs["dir"] = cfg.get_expanded( + "job", + cfg.find_option("job", "glite_job_file_dir", "job_file_dir"), + ) if kwargs.get("mkdtemp") is None: - kwargs["mkdtemp"] = cfg.get_expanded_bool("job", cfg.find_option("job", - "glite_job_file_dir_mkdtemp", "job_file_dir_mkdtemp")) + kwargs["mkdtemp"] = cfg.get_expanded_bool( + "job", + cfg.find_option("job", "glite_job_file_dir_mkdtemp", "job_file_dir_mkdtemp"), + ) if kwargs.get("cleanup") is None: - kwargs["cleanup"] = cfg.get_expanded_bool("job", cfg.find_option("job", - "glite_job_file_dir_cleanup", "job_file_dir_cleanup")) + kwargs["cleanup"] = cfg.get_expanded_bool( + "job", + cfg.find_option("job", "glite_job_file_dir_cleanup", "job_file_dir_cleanup"), + ) - super(GLiteJobFileFactory, self).__init__(**kwargs) + super().__init__(**kwargs) self.file_name = file_name self.command = command self.executable = executable self.arguments = arguments - self.input_files = DeprecatedInputFiles(input_files or {}) + self.input_files = input_files or {} self.output_files = output_files or [] self.postfix_output_files = postfix_output_files self.output_uri = output_uri @@ -294,7 +367,11 @@ def __init__(self, file_name="glite_job.jdl", command=None, executable=None, arg self.custom_content = custom_content self.absolute_paths = absolute_paths - def create(self, postfix=None, render_variables=None, **kwargs): + def create( + self, + postfix: str | None = None, + **kwargs, + ) -> tuple[str, GLiteJobFileFactory.Config]: # merge kwargs and instance attributes c = self.get_config(**kwargs) @@ -468,17 +545,17 @@ def prepare_input(f): with open(job_file, "w") as f: f.write("[\n") for key, value in content: - f.write(self.create_line(key, value) + "\n") + f.write(f"{self.create_line(key, value)}\n") f.write("]\n") - logger.debug("created glite job file at '{}'".format(job_file)) + logger.debug(f"created glite job file at '{job_file}'") return job_file, c @classmethod - def create_line(cls, key, value): + def create_line(cls, key: str, value: Any) -> str: if isinstance(value, (list, tuple)): - value = "{{{}}}".format(", ".join("\"{}\"".format(v) for v in value)) + value = "{{{}}}".format(", ".join(f"\"{v}\"" for v in value)) else: - value = "\"{}\"".format(value) - return "{} = {};".format(key, value) + value = f"\"{value}\"" + return f"{key} = {value};".format(key, value) diff --git a/law/contrib/glite/workflow.py b/law/contrib/glite/workflow.py index d38963ae..b45b6345 100644 --- a/law/contrib/glite/workflow.py +++ b/law/contrib/glite/workflow.py @@ -5,23 +5,27 @@ https://wiki.italiangrid.it/twiki/bin/view/CREAM/UserGuide. 
""" -__all__ = ["GLiteWorkflow"] +from __future__ import annotations +__all__ = ["GLiteWorkflow"] import os import sys -from abc import abstractmethod -from collections import OrderedDict +import abc +import pathlib import law from law.workflow.remote import BaseRemoteWorkflow, BaseRemoteWorkflowProxy -from law.job.base import JobArguments, JobInputFile, DeprecatedInputFiles +from law.job.base import JobArguments, JobInputFile from law.task.proxy import ProxyCommand from law.target.file import get_path +from law.target.local import LocalFileTarget from law.parameter import CSVParameter -from law.util import law_src_path, merge_dicts, DotDict +from law.util import law_src_path, merge_dicts, DotDict, InsertableDict from law.logger import get_logger +from law._types import Type, Any +from law.contrib.wlcg import WLCGDirectoryTarget, delegate_vomsproxy_glite from law.contrib.glite.job import GLiteJobManager, GLiteJobFileFactory @@ -30,10 +34,10 @@ class GLiteWorkflowProxy(BaseRemoteWorkflowProxy): - workflow_type = "glite" + workflow_type: str = "glite" - def __init__(self, *args, **kwargs): - super(GLiteWorkflowProxy, self).__init__(*args, **kwargs) + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) # check if there is at least one ce if not self.task.glite_ce: @@ -41,34 +45,38 @@ def __init__(self, *args, **kwargs): self.delegation_ids = None - def create_job_manager(self, **kwargs): + def create_job_manager(self, **kwargs) -> GLiteJobManager: return self.task.glite_create_job_manager(**kwargs) - def setup_job_mananger(self): + def setup_job_mananger(self) -> dict[str, Any]: kwargs = {} # delegate the voms proxy to all endpoints if callable(self.task.glite_delegate_proxy): delegation_ids = [] for ce in self.task.glite_ce: - endpoint = law.wlcg.get_ce_endpoint(ce) + endpoint = law.wlcg.get_ce_endpoint(ce) # type: ignore[attr-defined] delegation_ids.append(self.task.glite_delegate_proxy(endpoint)) kwargs["delegation_id"] = delegation_ids return kwargs - def create_job_file_factory(self, **kwargs): + def create_job_file_factory(self, **kwargs) -> GLiteJobFileFactory: return self.task.glite_create_job_file_factory(**kwargs) - def create_job_file(self, job_num, branches): + def create_job_file( + self, + job_num: int, + branches: list[int], + ) -> dict[str, str | pathlib.Path | GLiteJobFileFactory.Config | None]: task = self.task # the file postfix is pythonic range made from branches, e.g. 
[0, 1, 2, 4] -> "_0To5" - postfix = "_{}To{}".format(branches[0], branches[-1] + 1) + postfix = f"_{branches[0]}To{branches[-1] + 1}" # create the config - c = self.job_file_factory.get_config() - c.input_files = DeprecatedInputFiles() + c = self.job_file_factory.get_config() # type: ignore[union-attr] + c.input_files = {} c.output_files = [] c.render_variables = {} c.custom_content = [] @@ -98,18 +106,23 @@ def create_job_file(self, job_num, branches): ) if task.glite_use_local_scheduler(): proxy_cmd.add_arg("--local-scheduler", "True", overwrite=True) - for key, value in OrderedDict(task.glite_cmdline_args()).items(): + for key, value in dict(task.glite_cmdline_args()).items(): proxy_cmd.add_arg(key, value, overwrite=True) # job script arguments + dashboard_data = None + if self.dashboard is not None: + dashboard_data = self.dashboard.remote_hook_data( + job_num, + self.job_data.attempts.get(job_num, 0), + ) job_args = JobArguments( task_cls=task.__class__, task_params=proxy_cmd.build(skip_run=True), branches=branches, workers=task.job_workers, auto_retry=False, - dashboard_data=self.dashboard.remote_hook_data( - job_num, self.job_data.attempts.get(job_num, 0)), + dashboard_data=dashboard_data, ) c.arguments = job_args.join() @@ -124,9 +137,10 @@ def create_job_file(self, job_num, branches): c.input_files["stageout_file"] = stageout_file # does the dashboard have a hook file? - dashboard_file = self.dashboard.remote_hook_file() - if dashboard_file: - c.input_files["dashboard_file"] = dashboard_file + if self.dashboard is not None: + dashboard_file = self.dashboard.remote_hook_file() + if dashboard_file: + c.input_files["dashboard_file"] = dashboard_file # log file c.stdout = None @@ -144,7 +158,7 @@ def create_job_file(self, job_num, branches): c = task.glite_job_config(c, job_num, branches) # build the job file and get the sanitized config - job_file, c = self.job_file_factory(postfix=postfix, **c.__dict__) + job_file, c = self.job_file_factory(postfix=postfix, **c.__dict__) # type: ignore[misc] # determine the custom log file uri if set abs_log_file = None @@ -154,10 +168,10 @@ def create_job_file(self, job_num, branches): # return job and log files return {"job": job_file, "config": c, "log": abs_log_file} - def destination_info(self): - info = super(GLiteWorkflowProxy, self).destination_info() + def destination_info(self) -> InsertableDict: + info = super().destination_info() - info["ce"] = "ce: {}".format(",".join(self.task.glite_ce)) + info["ce"] = f"ce: {','.join(self.task.glite_ce)}" info = self.task.glite_destination_info(info) @@ -178,7 +192,7 @@ class GLiteWorkflow(BaseRemoteWorkflow): description="target glite computing element(s); default: empty", ) - glite_job_kwargs = [] + glite_job_kwargs: list[str] = [] glite_job_kwargs_submit = ["glite_ce"] glite_job_kwargs_cancel = None glite_job_kwargs_cleanup = None @@ -186,51 +200,55 @@ class GLiteWorkflow(BaseRemoteWorkflow): exclude_params_branch = {"glite_ce"} - exclude_params_glite_workflow = set() + exclude_params_glite_workflow: set[str] = set() exclude_index = True - @abstractmethod - def glite_output_directory(self): - return None + @abc.abstractmethod + def glite_output_directory(self) -> WLCGDirectoryTarget: + ... - @abstractmethod - def glite_bootstrap_file(self): - return None + @abc.abstractmethod + def glite_bootstrap_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile: + ... 
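# The job file postfix above encodes the branch range; quick check of the f-string
# used there against the example given in the comment:
branches = [0, 1, 2, 4]
postfix = f"_{branches[0]}To{branches[-1] + 1}"
assert postfix == "_0To5"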
- def glite_wrapper_file(self): + def glite_wrapper_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile | None: return None - def glite_job_file(self): + def glite_job_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile: return JobInputFile(law_src_path("job", "law_job.sh")) - def glite_stageout_file(self): + def glite_stageout_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile | None: return None - def glite_workflow_requires(self): + def glite_workflow_requires(self) -> DotDict: return DotDict() - def glite_output_postfix(self): + def glite_output_postfix(self) -> str: return "" - def glite_output_uri(self): - return self.glite_output_directory().uri() + def glite_output_uri(self) -> str: + return self.glite_output_directory().uri(return_all=False) # type: ignore[return-value] - def glite_delegate_proxy(self, endpoint): - return law.wlcg.delegate_vomsproxy_glite(endpoint, stdout=sys.stdout, stderr=sys.stderr, - cache=True) + def glite_delegate_proxy(self, endpoint: str) -> str: + return delegate_vomsproxy_glite( # type: ignore[attr-defined] + endpoint, + stdout=sys.stdout, + stderr=sys.stderr, + cache=True, + ) - def glite_job_manager_cls(self): + def glite_job_manager_cls(self) -> Type[GLiteJobManager]: return GLiteJobManager - def glite_create_job_manager(self, **kwargs): + def glite_create_job_manager(self, **kwargs) -> GLiteJobManager: kwargs = merge_dicts(self.glite_job_manager_defaults, kwargs) return self.glite_job_manager_cls()(**kwargs) - def glite_job_file_factory_cls(self): + def glite_job_file_factory_cls(self) -> Type[GLiteJobFileFactory]: return GLiteJobFileFactory - def glite_create_job_file_factory(self, **kwargs): + def glite_create_job_file_factory(self, **kwargs) -> GLiteJobFileFactory: # job file fectory config priority: kwargs > class defaults kwargs = merge_dicts({}, self.glite_job_file_factory_defaults, kwargs) return self.glite_job_file_factory_cls()(**kwargs) @@ -238,17 +256,17 @@ def glite_create_job_file_factory(self, **kwargs): def glite_job_config(self, config, job_num, branches): return config - def glite_check_job_completeness(self): + def glite_check_job_completeness(self) -> bool: return False - def glite_check_job_completeness_delay(self): + def glite_check_job_completeness_delay(self) -> float | int: return 0.0 - def glite_use_local_scheduler(self): + def glite_use_local_scheduler(self) -> bool: return True - def glite_cmdline_args(self): + def glite_cmdline_args(self) -> dict[str, str]: return {} - def glite_destination_info(self, info): + def glite_destination_info(self, info: InsertableDict) -> InsertableDict: return info diff --git a/law/contrib/htcondor/__init__.py b/law/contrib/htcondor/__init__.py index c72d57c3..5ba07d21 100644 --- a/law/contrib/htcondor/__init__.py +++ b/law/contrib/htcondor/__init__.py @@ -11,7 +11,6 @@ "HTCondorWorkflow", ] - # provisioning imports from law.contrib.htcondor.util import get_htcondor_version from law.contrib.htcondor.job import HTCondorJobManager, HTCondorJobFileFactory diff --git a/law/contrib/htcondor/job.py b/law/contrib/htcondor/job.py index 2ed3df19..4c57f0e0 100644 --- a/law/contrib/htcondor/job.py +++ b/law/contrib/htcondor/job.py @@ -4,21 +4,24 @@ HTCondor job manager. See https://research.cs.wisc.edu/htcondor. 
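# Minimal sketch of a workflow implementing the two abstract glite hooks defined
# above; the remote base path and bootstrap script name are placeholders (the glite
# contrib package loads wlcg itself, so WLCGDirectoryTarget is available).
import law

law.contrib.load("glite")


class MyGLiteWorkflow(law.glite.GLiteWorkflow):  # task parameters omitted

    def glite_output_directory(self):
        return law.wlcg.WLCGDirectoryTarget("/jdoe/law_outputs")

    def glite_bootstrap_file(self):
        return law.util.rel_path(__file__, "bootstrap.sh")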
""" -__all__ = ["HTCondorJobManager", "HTCondorJobFileFactory"] +from __future__ import annotations +__all__ = ["HTCondorJobManager", "HTCondorJobFileFactory"] import os import stat import time import re +import pathlib import tempfile import subprocess from law.config import Config -from law.job.base import BaseJobManager, BaseJobFileFactory, JobInputFile, DeprecatedInputFiles +from law.job.base import BaseJobManager, BaseJobFileFactory, JobInputFile from law.target.file import get_path from law.util import interruptable_popen, make_list, make_unique, quote_cmd from law.logger import get_logger +from law._types import Any, Sequence from law.contrib.htcondor.util import get_htcondor_version @@ -45,8 +48,14 @@ class HTCondorJobManager(BaseJobManager): submission_job_id_cre = re.compile(r"^(\d+) job\(s\) submitted to cluster (\d+)\.$") long_block_cre = re.compile(r"(\w+) \= \"?([^\"\n]*)\"?\n") - def __init__(self, pool=None, scheduler=None, user=None, threads=1): - super(HTCondorJobManager, self).__init__() + def __init__( + self, + pool: str | None = None, + scheduler: str | None = None, + user: str | None = None, + threads: int = 1, + ) -> None: + super().__init__() self.pool = pool self.scheduler = scheduler @@ -60,13 +69,21 @@ def __init__(self, pool=None, scheduler=None, user=None, threads=1): self.htcondor_ge_v833 = self.htcondor_version and self.htcondor_version >= (8, 3, 3) self.htcondor_ge_v856 = self.htcondor_version and self.htcondor_version >= (8, 5, 6) - def cleanup(self, *args, **kwargs): + def cleanup(self, *args, **kwargs) -> None: # type: ignore[override] raise NotImplementedError("HTCondorJobManager.cleanup is not implemented") - def cleanup_batch(self, *args, **kwargs): + def cleanup_batch(self, *args, **kwargs) -> None: # type: ignore[override] raise NotImplementedError("HTCondorJobManager.cleanup_batch is not implemented") - def submit(self, job_file, pool=None, scheduler=None, retries=0, retry_delay=3, silent=False): + def submit( # type: ignore[override] + self, + job_file: str | pathlib.Path | Sequence[str | pathlib.Path], + pool: str | None = None, + scheduler: str | None = None, + retries: int = 0, + retry_delay: float | int = 3, + silent: bool = False, + ) -> str | Sequence[str] | None: # default arguments if pool is None: pool = self.pool @@ -84,7 +101,7 @@ def has_initialdir(job_file): return False chunking = isinstance(job_file, (list, tuple)) - job_files = list(map(str, make_list(job_file))) + job_files = list(map(get_path, make_list(job_file))) job_file_dir = None for i, job_file in enumerate(job_files): dirname, basename = os.path.split(job_file) @@ -94,10 +111,9 @@ def has_initialdir(job_file): elif dirname != job_file_dir: if not has_initialdir(job_file): raise Exception( - "cannot performed chunked submission as job file '{}' is not located in a " - "previously seen directory '{}' and has no initialdir".format( - job_file, job_file_dir, - ), + f"cannot performed chunked submission as job file '{job_file}' is not " + f"located in a previously seen directory '{job_file_dir}' and has no " + "initialdir", ) # define a single, merged job file if necessary @@ -106,7 +122,7 @@ def has_initialdir(job_file): with open(_job_file, "w") as f: for job_file in job_files: with open(job_file, "r") as _f: - f.write(_f.read() + "\n") + f.write(f"{_f.read()}\n") job_files = [_job_file] # build the command @@ -116,14 +132,22 @@ def has_initialdir(job_file): if scheduler: cmd += ["-name", scheduler] cmd += list(map(os.path.basename, job_files)) - cmd = quote_cmd(cmd) + 
cmd_str = quote_cmd(cmd) # define the actual submission in a loop to simplify retries while True: # run the command - logger.debug("submit htcondor job with command '{}'".format(cmd)) - code, out, err = interruptable_popen(cmd, shell=True, executable="/bin/bash", - stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=os.path.dirname(job_files[0])) + logger.debug(f"submit htcondor job with command '{cmd_str}'") + out: str + err: str + code, out, err = interruptable_popen( # type: ignore[assignment] + cmd_str, + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=os.path.dirname(job_files[0]), + ) # get the job id(s) if code == 0: @@ -132,21 +156,19 @@ def has_initialdir(job_file): for line in out.strip().split("\n"): m = self.submission_job_id_cre.match(line.strip()) if m: - job_ids.extend([ - "{}.{}".format(m.group(2), i) - for i in range(int(m.group(1))) - ]) + job_ids.extend([f"{m.group(2)}.{i}" for i in range(int(m.group(1)))]) if not job_ids: code = 1 - err = "cannot parse htcondor job id(s) from output:\n{}".format(out) + err = f"cannot parse htcondor job id(s) from output:\n{out}" # retry or done? if code == 0: return job_ids if chunking else job_ids[0] job_files_repr = ",".join(map(os.path.basename, job_files)) - logger.debug("submission of htcondor job(s) '{}' failed with code {}:\n{}".format( - job_files_repr, code, err)) + logger.debug( + f"submission of htcondor job(s) '{job_files_repr}' failed with code {code}:\n{err}", + ) if retries > 0: retries -= 1 @@ -156,10 +178,15 @@ def has_initialdir(job_file): if silent: return None - raise Exception("submission of htcondor job(s) '{}' failed:\n{}".format( - job_files_repr, err)) + raise Exception(f"submission of htcondor job(s) '{job_files_repr}' failed:\n{err}") - def cancel(self, job_id, pool=None, scheduler=None, silent=False): + def cancel( # type: ignore[override] + self, + job_id: str | Sequence[str], + pool: str | None = None, + scheduler: str | None = None, + silent: bool = False, + ) -> dict[str, None] | None: # default arguments if pool is None: pool = self.pool @@ -176,21 +203,36 @@ def cancel(self, job_id, pool=None, scheduler=None, silent=False): if scheduler: cmd += ["-name", scheduler] cmd += job_ids - cmd = quote_cmd(cmd) + cmd_str = quote_cmd(cmd) # run it - logger.debug("cancel htcondor job(s) with command '{}'".format(cmd)) - code, out, err = interruptable_popen(cmd, shell=True, executable="/bin/bash", - stdout=subprocess.PIPE, stderr=subprocess.PIPE) + logger.debug(f"cancel htcondor job(s) with command '{cmd_str}'") + out: str + err: str + code, out, err = interruptable_popen( # type: ignore[assignment] + cmd_str, + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) # check success if code != 0 and not silent: - raise Exception("cancellation of htcondor job(s) '{}' failed with code {}:\n{}".format( - job_id, code, err)) + raise Exception( + f"cancellation of htcondor job(s) '{job_id}' failed with code {code}:\n{err}", + ) return {job_id: None for job_id in job_ids} if chunking else None - def query(self, job_id, pool=None, scheduler=None, user=None, silent=False): + def query( # type: ignore[override] + self, + job_id: str | Sequence[str], + pool: str | None = None, + scheduler: str | None = None, + user: str | None = None, + silent: bool = False, + ) -> dict[int, dict[str, Any]] | dict[str, Any] | None: # default arguments if pool is None: pool = self.pool @@ -218,18 +260,26 @@ def query(self, job_id, pool=None, scheduler=None, 
user=None, silent=False):
         # since v8.5.6 one can define the attributes to fetch
         if self.htcondor_ge_v856:
             cmd += ["-attributes", ads]
-        cmd = quote_cmd(cmd)
-
-        logger.debug("query htcondor job(s) with command '{}'".format(cmd))
-        code, out, err = interruptable_popen(cmd, shell=True, executable="/bin/bash",
-            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        cmd_str = quote_cmd(cmd)
+
+        logger.debug(f"query htcondor job(s) with command '{cmd_str}'")
+        out: str
+        err: str
+        code, out, err = interruptable_popen(  # type: ignore[assignment]
+            cmd_str,
+            shell=True,
+            executable="/bin/bash",
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+        )
 
         # handle errors
         if code != 0:
             if silent:
                 return None
-            raise Exception("queue query of htcondor job(s) '{}' failed with code {}:"
-                "\n{}".format(job_id, code, err))
+            raise Exception(
+                f"queue query of htcondor job(s) '{job_id}' failed with code {code}:\n{err}",
+            )
 
         # parse the output and extract the status per job
         query_data = self.parse_long_output(out)
@@ -250,18 +300,24 @@ def query(self, job_id, pool=None, scheduler=None, user=None, silent=False):
             # since v8.5.6 one can define the attributes to fetch
             if self.htcondor_ge_v856:
                 cmd += ["-attributes", ads]
-            cmd = quote_cmd(cmd)
-
-            logger.debug("query htcondor job history with command '{}'".format(cmd))
-            code, out, err = interruptable_popen(cmd, shell=True, executable="/bin/bash",
-                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+            cmd_str = quote_cmd(cmd)
+
+            logger.debug(f"query htcondor job history with command '{cmd_str}'")
+            code, out, err = interruptable_popen(  # type: ignore[assignment]
+                cmd_str,
+                shell=True,
+                executable="/bin/bash",
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+            )
 
             # handle errors
             if code != 0:
                 if silent:
                     return None
-                raise Exception("history query of htcondor job(s) '{}' failed with code {}:"
-                    "\n{}".format(job_id, code, err))
+                raise Exception(
+                    f"history query of htcondor job(s) '{job_id}' failed with code {code}:\n{err}",
+                )
 
             # parse the output and update query data
             query_data.update(self.parse_long_output(out))
@@ -272,16 +328,18 @@ def query(self, job_id, pool=None, scheduler=None, user=None, silent=False):
             if not chunking:
                 if silent:
                     return None
-                raise Exception("htcondor job(s) '{}' not found in query response".format(
-                    job_id))
+                raise Exception(f"htcondor job(s) '{job_id}' not found in query response")
 
-            query_data[_job_id] = self.job_status_dict(job_id=_job_id, status=self.FAILED,
-                error="job not found in query response")
+            query_data[_job_id] = self.job_status_dict(
+                job_id=_job_id,
+                status=self.FAILED,
+                error="job not found in query response",
+            )
 
-        return query_data if chunking else query_data[job_id]
+        return query_data if chunking else query_data[job_id]  # type: ignore[index]
 
     @classmethod
-    def parse_long_output(cls, out):
+    def parse_long_output(cls, out: str) -> dict[str, dict[str, Any]]:
         # retrieve information per block mapped to the job id
         query_data = {}
         for block in out.strip().split("\n\n"):
@@ -316,29 +374,33 @@ def parse_long_output(cls, out):
             if status != cls.FAILED:
                 status = cls.FAILED
                 if not error:
-                    error = "job status set to '{}' due to non-zero exit code {}".format(
-                        cls.FAILED, code)
+                    error = f"job status set to '{cls.FAILED}' due to non-zero exit code {code}"
 
             # store it
-            query_data[job_id] = cls.job_status_dict(job_id=job_id, status=status, code=code,
-                error=error)
+            query_data[job_id] = cls.job_status_dict(
+                job_id=job_id,
+                status=status,
+                code=code,
+                error=error,
+            )
 
         return query_data
 
     @classmethod
-    def 
map_status(cls, status_flag): + def map_status(cls, status_flag: str | None) -> str: # see http://pages.cs.wisc.edu/~adesmet/status.html if status_flag in ("0", "1", "U", "I"): return cls.PENDING - elif status_flag in ("2", "R"): + if status_flag in ("2", "R"): return cls.RUNNING - elif status_flag in ("4", "C"): + if status_flag in ("4", "C"): return cls.FINISHED - elif status_flag in ("5", "6", "H", "E"): - return cls.FAILED - else: + if status_flag in ("5", "6", "H", "E"): return cls.FAILED + logger.debug(f"unknown htcondor job state '{status_flag}'") + return cls.FAILED + class HTCondorJobFileFactory(BaseJobFileFactory): @@ -348,29 +410,49 @@ class HTCondorJobFileFactory(BaseJobFileFactory): "absolute_paths", ] - def __init__(self, file_name="htcondor_job.jdl", command=None, executable=None, arguments=None, - input_files=None, output_files=None, log="log.txt", stdout="stdout.txt", - stderr="stderr.txt", postfix_output_files=True, universe="vanilla", - notification="Never", custom_content=None, absolute_paths=False, **kwargs): + def __init__( + self, + file_name: str = "htcondor_job.jdl", + command: str | Sequence[str] | None = None, + executable: str | None = None, + arguments: str | Sequence[str] | None = None, + input_files: dict[str, str | pathlib.Path | JobInputFile] | None = None, + output_files: Sequence[str] | None = None, + log: str = "log.txt", + stdout: str = "stdout.txt", + stderr: str = "stderr.txt", + postfix_output_files: bool = True, + universe: str = "vanilla", + notification: str = "Never", + custom_content: str | Sequence[str] | None = None, + absolute_paths: bool = False, + **kwargs, + ) -> None: # get some default kwargs from the config cfg = Config.instance() if kwargs.get("dir") is None: - kwargs["dir"] = cfg.get_expanded("job", cfg.find_option("job", - "htcondor_job_file_dir", "job_file_dir")) + kwargs["dir"] = cfg.get_expanded( + "job", + cfg.find_option("job", "htcondor_job_file_dir", "job_file_dir"), + ) if kwargs.get("mkdtemp") is None: - kwargs["mkdtemp"] = cfg.get_expanded_bool("job", cfg.find_option("job", - "htcondor_job_file_dir_mkdtemp", "job_file_dir_mkdtemp")) + kwargs["mkdtemp"] = cfg.get_expanded_bool( + "job", + cfg.find_option("job", "htcondor_job_file_dir_mkdtemp", "job_file_dir_mkdtemp"), + ) if kwargs.get("cleanup") is None: - kwargs["cleanup"] = cfg.get_expanded_bool("job", cfg.find_option("job", - "htcondor_job_file_dir_cleanup", "job_file_dir_cleanup")) + kwargs["cleanup"] = cfg.get_expanded_bool( + "job", + cfg.find_option("job", "htcondor_job_file_dir_cleanup", "job_file_dir_cleanup"), + ) - super(HTCondorJobFileFactory, self).__init__(**kwargs) + super().__init__(**kwargs) self.file_name = file_name self.command = command self.executable = executable self.arguments = arguments - self.input_files = DeprecatedInputFiles(input_files or {}) + self.input_files = input_files or {} self.output_files = output_files or [] self.log = log self.stdout = stdout @@ -381,7 +463,11 @@ def __init__(self, file_name="htcondor_job.jdl", command=None, executable=None, self.custom_content = custom_content self.absolute_paths = absolute_paths - def create(self, postfix=None, **kwargs): + def create( + self, + postfix: str | None = None, + **kwargs, + ) -> tuple[str, HTCondorJobFileFactory.Config]: # merge kwargs and instance attributes c = self.get_config(**kwargs) @@ -525,7 +611,7 @@ def prepare_input(f): os.chmod(path, os.stat(path).st_mode | stat.S_IXUSR | stat.S_IXGRP) # job file content - content = [] + content: list[str | tuple[str, Any]] = [] 
content.append(("universe", c.universe)) if c.command: cmd = quote_cmd(c.command) if isinstance(c.command, (list, tuple)) else c.command @@ -568,17 +654,16 @@ def prepare_input(f): with open(job_file, "w") as f: for obj in content: line = self.create_line(*make_list(obj)) - f.write(line + "\n") + f.write(f"{line}\n") - logger.debug("created htcondor job file at '{}'".format(job_file)) + logger.debug(f"created htcondor job file at '{job_file}'") return job_file, c @classmethod - def create_line(cls, key, value=None): + def create_line(cls, key: str, value: Any | None = None) -> str: if isinstance(value, (list, tuple)): value = ",".join(str(v) for v in value) if value is None: return str(key) - else: - return "{} = {}".format(key, value) + return f"{key} = {value}" diff --git a/law/contrib/htcondor/util.py b/law/contrib/htcondor/util.py index ec562391..af415e35 100644 --- a/law/contrib/htcondor/util.py +++ b/law/contrib/htcondor/util.py @@ -4,38 +4,47 @@ HTCondor utilities. """ -__all__ = ["get_htcondor_version"] +from __future__ import annotations +__all__ = ["get_htcondor_version"] import re import subprocess import threading -from law.util import no_value, interruptable_popen +from law.util import NoValue, no_value, interruptable_popen -_htcondor_version = no_value +_htcondor_version: tuple[int, int, int] | None | NoValue = no_value _htcondor_version_lock = threading.Lock() -def get_htcondor_version(): +def get_htcondor_version() -> tuple[int, int, int] | None: """ Returns the version of the HTCondor installation in a 3-tuple. The value is cached to accelerate - repeated function invocations. + repeated function invocations. When the ``condor_version`` executable is not available, *None* + is returned. """ global _htcondor_version if _htcondor_version == no_value: version = None with _htcondor_version_lock: - code, out, _ = interruptable_popen("condor_version", shell=True, executable="/bin/bash", - stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out: str + code, out, _ = interruptable_popen( # type: ignore[assignment] + "condor_version", + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + if code == 0: first_line = out.strip().split("\n")[0] m = re.match(r"^\$CondorVersion: (\d+)\.(\d+)\.(\d+) .+$", first_line.strip()) if m: version = tuple(map(int, m.groups())) - _htcondor_version = version + _htcondor_version = version # type: ignore[assignment] - return _htcondor_version + return _htcondor_version # type: ignore[return-value] diff --git a/law/contrib/htcondor/workflow.py b/law/contrib/htcondor/workflow.py index b1ea51b4..f5be2c18 100644 --- a/law/contrib/htcondor/workflow.py +++ b/law/contrib/htcondor/workflow.py @@ -4,23 +4,25 @@ HTCondor workflow implementation. See https://research.cs.wisc.edu/htcondor. 
""" -__all__ = ["HTCondorWorkflow"] +from __future__ import annotations +__all__ = ["HTCondorWorkflow"] import os -from abc import abstractmethod -from collections import OrderedDict +import abc +import pathlib -import luigi +import luigi # type: ignore[import-untyped] from law.workflow.remote import BaseRemoteWorkflow, BaseRemoteWorkflowProxy -from law.job.base import JobArguments, JobInputFile, DeprecatedInputFiles +from law.job.base import JobArguments, JobInputFile from law.task.proxy import ProxyCommand from law.target.file import get_path, get_scheme, FileSystemDirectoryTarget -from law.target.local import LocalDirectoryTarget +from law.target.local import LocalDirectoryTarget, LocalFileTarget from law.parameter import NO_STR -from law.util import law_src_path, merge_dicts, DotDict +from law.util import law_src_path, merge_dicts, DotDict, InsertableDict from law.logger import get_logger +from law._types import Type from law.contrib.htcondor.job import HTCondorJobManager, HTCondorJobFileFactory @@ -30,23 +32,27 @@ class HTCondorWorkflowProxy(BaseRemoteWorkflowProxy): - workflow_type = "htcondor" + workflow_type: str = "htcondor" - def create_job_manager(self, **kwargs): + def create_job_manager(self, **kwargs) -> HTCondorJobManager: return self.task.htcondor_create_job_manager(**kwargs) - def create_job_file_factory(self, **kwargs): + def create_job_file_factory(self, **kwargs) -> HTCondorJobFileFactory: return self.task.htcondor_create_job_file_factory(**kwargs) - def create_job_file(self, job_num, branches): + def create_job_file( + self, + job_num: int, + branches: list[int], + ) -> dict[str, str | pathlib.Path | HTCondorJobFileFactory.Config | None]: task = self.task # the file postfix is pythonic range made from branches, e.g. [0, 1, 2, 4] -> "_0To5" - postfix = "_{}To{}".format(branches[0], branches[-1] + 1) + postfix = f"_{branches[0]}To{branches[-1] + 1}" # create the config - c = self.job_file_factory.get_config() - c.input_files = DeprecatedInputFiles() + c = self.job_file_factory.get_config() # type: ignore[union-attr] + c.input_files = {} c.output_files = [] c.render_variables = {} c.custom_content = [] @@ -76,18 +82,23 @@ def create_job_file(self, job_num, branches): ) if task.htcondor_use_local_scheduler(): proxy_cmd.add_arg("--local-scheduler", "True", overwrite=True) - for key, value in OrderedDict(task.htcondor_cmdline_args()).items(): + for key, value in dict(task.htcondor_cmdline_args()).items(): proxy_cmd.add_arg(key, value, overwrite=True) # job script arguments + dashboard_data = None + if self.dashboard is not None: + dashboard_data = self.dashboard.remote_hook_data( + job_num, + self.job_data.attempts.get(job_num, 0), + ) job_args = JobArguments( task_cls=task.__class__, task_params=proxy_cmd.build(skip_run=True), branches=branches, workers=task.job_workers, auto_retry=False, - dashboard_data=self.dashboard.remote_hook_data( - job_num, self.job_data.attempts.get(job_num, 0)), + dashboard_data=dashboard_data, ) c.arguments = job_args.join() @@ -102,9 +113,10 @@ def create_job_file(self, job_num, branches): c.input_files["stageout_file"] = stageout_file # does the dashboard have a hook file? 
- dashboard_file = self.dashboard.remote_hook_file() - if dashboard_file: - c.input_files["dashboard_file"] = dashboard_file + if self.dashboard is not None: + dashboard_file = self.dashboard.remote_hook_file() + if dashboard_file: + c.input_files["dashboard_file"] = dashboard_file # logging # we do not use htcondor's logging mechanism since it might require that the submission @@ -135,7 +147,7 @@ def create_job_file(self, job_num, branches): del c.output_files[:] # build the job file and get the sanitized config - job_file, c = self.job_file_factory(postfix=postfix, **c.__dict__) + job_file, c = self.job_file_factory(postfix=postfix, **c.__dict__) # type: ignore[misc] # get the location of the custom local log file if any abs_log_file = None @@ -145,14 +157,14 @@ def create_job_file(self, job_num, branches): # return job and log files return {"job": job_file, "config": c, "log": abs_log_file} - def destination_info(self): - info = super(HTCondorWorkflowProxy, self).destination_info() + def destination_info(self) -> InsertableDict: + info = super().destination_info() if self.task.htcondor_pool and self.task.htcondor_pool != NO_STR: - info["pool"] = "pool: {}".format(self.task.htcondor_pool) + info["pool"] = f"pool: {self.task.htcondor_pool}" if self.task.htcondor_scheduler and self.task.htcondor_scheduler != NO_STR: - info["scheduler"] = "scheduler: {}".format(self.task.htcondor_scheduler) + info["scheduler"] = f"scheduler: {self.task.htcondor_scheduler}" info = self.task.htcondor_destination_info(info) @@ -178,68 +190,73 @@ class HTCondorWorkflow(BaseRemoteWorkflow): description="target htcondor scheduler; default: empty", ) - htcondor_job_kwargs = ["htcondor_pool", "htcondor_scheduler"] + htcondor_job_kwargs: list[str] = ["htcondor_pool", "htcondor_scheduler"] htcondor_job_kwargs_submit = None htcondor_job_kwargs_cancel = None htcondor_job_kwargs_query = None exclude_params_branch = {"htcondor_pool", "htcondor_scheduler"} - exclude_params_htcondor_workflow = set() + exclude_params_htcondor_workflow: set[str] = set() exclude_index = True - @abstractmethod - def htcondor_output_directory(self): - return None + @abc.abstractmethod + def htcondor_output_directory(self) -> FileSystemDirectoryTarget: + ... 
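Analogously, a minimal HTCondor workflow implementing the now strictly abstract hook could look like the following sketch, which is not part of the patch (class name, output location and resource values are assumptions):

import os

import law

law.contrib.load("htcondor")


class MyHTCondorWorkflow(law.htcondor.HTCondorWorkflow, law.LocalWorkflow):

    def create_branch_map(self):
        return {0: None}

    def htcondor_output_directory(self) -> law.LocalDirectoryTarget:
        # any FileSystemDirectoryTarget satisfies the new return annotation
        return law.LocalDirectoryTarget(os.path.join(os.getcwd(), "jobs", self.task_family))

    def htcondor_job_config(self, config, job_num, branches):
        # (key, value) pairs in custom_content end up as "key = value" lines in the jdl
        config.custom_content.append(("request_cpus", 1))
        config.custom_content.append(("request_memory", "2000"))
        return config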
- def htcondor_workflow_requires(self): + def htcondor_workflow_requires(self) -> DotDict: return DotDict() - def htcondor_bootstrap_file(self): + def htcondor_bootstrap_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile | None: return None - def htcondor_wrapper_file(self): + def htcondor_wrapper_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile | None: return None - def htcondor_job_file(self): + def htcondor_job_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile: return JobInputFile(law_src_path("job", "law_job.sh")) - def htcondor_stageout_file(self): + def htcondor_stageout_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile | None: return None - def htcondor_output_postfix(self): + def htcondor_output_postfix(self) -> str: return "" - def htcondor_job_manager_cls(self): + def htcondor_job_manager_cls(self) -> Type[HTCondorJobManager]: return HTCondorJobManager - def htcondor_create_job_manager(self, **kwargs): + def htcondor_create_job_manager(self, **kwargs) -> HTCondorJobManager: kwargs = merge_dicts(self.htcondor_job_manager_defaults, kwargs) return self.htcondor_job_manager_cls()(**kwargs) - def htcondor_job_file_factory_cls(self): + def htcondor_job_file_factory_cls(self) -> Type[HTCondorJobFileFactory]: return HTCondorJobFileFactory - def htcondor_create_job_file_factory(self, **kwargs): + def htcondor_create_job_file_factory(self, **kwargs) -> HTCondorJobFileFactory: # job file fectory config priority: kwargs > class defaults kwargs = merge_dicts({}, self.htcondor_job_file_factory_defaults, kwargs) return self.htcondor_job_file_factory_cls()(**kwargs) - def htcondor_job_config(self, config, job_num, branches): + def htcondor_job_config( + self, + config: HTCondorJobFileFactory.Config, + job_num: int, + branches: list[int], + ) -> HTCondorJobFileFactory.Config: return config - def htcondor_check_job_completeness(self): + def htcondor_check_job_completeness(self) -> bool: return False - def htcondor_check_job_completeness_delay(self): + def htcondor_check_job_completeness_delay(self) -> float | int: return 0.0 - def htcondor_use_local_scheduler(self): + def htcondor_use_local_scheduler(self) -> bool: return False - def htcondor_cmdline_args(self): + def htcondor_cmdline_args(self) -> dict[str, str]: return {} - def htcondor_destination_info(self, info): + def htcondor_destination_info(self, info: InsertableDict) -> InsertableDict: return info diff --git a/law/contrib/lsf/__init__.py b/law/contrib/lsf/__init__.py index 0a1059c2..20573f45 100644 --- a/law/contrib/lsf/__init__.py +++ b/law/contrib/lsf/__init__.py @@ -11,7 +11,6 @@ "LSFWorkflow", ] - # provisioning imports from law.contrib.lsf.util import get_lsf_version from law.contrib.lsf.job import LSFJobManager, LSFJobFileFactory diff --git a/law/contrib/lsf/job.py b/law/contrib/lsf/job.py index eceec96f..966f88de 100644 --- a/law/contrib/lsf/job.py +++ b/law/contrib/lsf/job.py @@ -4,22 +4,24 @@ LSF job manager. See https://www.ibm.com/support/knowledgecenter/en/SSETD4_9.1.3. 
""" -__all__ = ["LSFJobManager", "LSFJobFileFactory"] +from __future__ import annotations +__all__ = ["LSFJobManager", "LSFJobFileFactory"] import os import stat import time import re +import pathlib import subprocess -import six - from law.config import Config -from law.job.base import BaseJobManager, BaseJobFileFactory, JobInputFile, DeprecatedInputFiles +from law.job.base import BaseJobManager, BaseJobFileFactory, JobInputFile from law.target.file import get_path +from law.target.local import LocalDirectoryTarget from law.util import interruptable_popen, make_list, make_unique, quote_cmd from law.logger import get_logger +from law._types import Any, Sequence from law.contrib.lsf.util import get_lsf_version @@ -38,8 +40,8 @@ class LSFJobManager(BaseJobManager): submission_job_id_cre = re.compile(r"^Job <(\d+)> is submitted.+$") - def __init__(self, queue=None, emails=False, threads=1): - super(LSFJobManager, self).__init__() + def __init__(self, queue: str | None = None, emails: bool = False, threads: int = 1) -> None: + super().__init__() self.queue = queue self.emails = emails @@ -51,13 +53,21 @@ def __init__(self, queue=None, emails=False, threads=1): # flags for versions with some important changes self.lsf_v912 = self.lsf_version and self.lsf_version >= (9, 1, 2) - def cleanup(self, *args, **kwargs): + def cleanup(self, *args, **kwargs) -> None: # type: ignore[override] raise NotImplementedError("LSFJobManager.cleanup is not implemented") - def cleanup_batch(self, *args, **kwargs): + def cleanup_batch(self, *args, **kwargs) -> None: # type: ignore[override] raise NotImplementedError("LSFJobManager.cleanup_batch is not implemented") - def submit(self, job_file, queue=None, emails=None, retries=0, retry_delay=3, silent=False): + def submit( # type: ignore[override] + self, + job_file: str | pathlib.Path, + queue: str | None = None, + emails: bool | None = None, + retries: int = 0, + retry_delay: float | int = 3, + silent: bool = False, + ) -> str | None: # default arguments if queue is None: queue = self.queue @@ -68,17 +78,25 @@ def submit(self, job_file, queue=None, emails=None, retries=0, retry_delay=3, si job_file_dir, job_file_name = os.path.split(os.path.abspath(str(job_file))) # build the command - cmd = "LSB_JOB_REPORT_MAIL={} bsub".format("Y" if emails else "N") + cmd = f"LSB_JOB_REPORT_MAIL={'Y' if emails else 'N'} bsub" if queue: - cmd += " -q {}".format(queue) - cmd += " < {}".format(job_file_name) + cmd += f" -q {queue}" + cmd += f" < {job_file_name}" # define the actual submission in a loop to simplify retries while True: # run the command - logger.debug("submit lsf job with command '{}'".format(cmd)) - code, out, err = interruptable_popen(cmd, shell=True, executable="/bin/bash", - stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=job_file_dir) + logger.debug(f"submit lsf job with command '{cmd}'") + out: str + err: str + code, out, err = interruptable_popen( # type: ignore[assignment] + cmd, + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=job_file_dir, + ) # get the job id if code == 0: @@ -87,14 +105,13 @@ def submit(self, job_file, queue=None, emails=None, retries=0, retry_delay=3, si job_id = m.group(1) else: code = 1 - err = "cannot parse job id from output:\n{}".format(out) + err = f"cannot parse job id from output:\n{out}" # retry or done? 
if code == 0:
                 return job_id
 
-            logger.debug("submission of lsf job '{}' failed with code {}:\n{}".format(
-                job_file, code, err))
+            logger.debug(f"submission of lsf job '{job_file}' failed with code {code}:\n{err}")
 
             if retries > 0:
                 retries -= 1
@@ -104,9 +121,14 @@ def submit(self, job_file, queue=None, emails=None, retries=0, retry_delay=3, si
             if silent:
                 return None
 
-            raise Exception("submission of lsf job '{}' failed: \n{}".format(job_file, err))
+            raise Exception(f"submission of lsf job '{job_file}' failed: \n{err}")
 
-    def cancel(self, job_id, queue=None, silent=False):
+    def cancel(  # type: ignore[override]
+        self,
+        job_id: str | Sequence[str],
+        queue: str | None = None,
+        silent: bool = False,
+    ) -> dict[int, None] | None:
         # default arguments
         if queue is None:
             queue = self.queue
@@ -119,21 +141,30 @@ def cancel(self, job_id, queue=None, silent=False):
         if queue:
             cmd += ["-q", queue]
         cmd += job_ids
-        cmd = quote_cmd(cmd)
+        cmd_str = quote_cmd(cmd)
 
         # run it
-        logger.debug("cancel lsf job(s) with command '{}'".format(cmd))
-        code, out, err = interruptable_popen(cmd, shell=True, executable="/bin/bash",
-            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        logger.debug(f"cancel lsf job(s) with command '{cmd_str}'")
+        code, out, err = interruptable_popen(
+            cmd_str,
+            shell=True,
+            executable="/bin/bash",
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+        )
 
         # check success
         if code != 0 and not silent:
-            raise Exception("cancellation of lsf job(s) '{}' failed with code {}:\n{}".format(
-                job_id, code, err))
+            raise Exception(f"cancellation of lsf job(s) '{job_id}' failed with code {code}:\n{err}")
 
         return {job_id: None for job_id in job_ids} if chunking else None
 
-    def query(self, job_id, queue=None, silent=False):
+    def query(  # type: ignore[override]
+        self,
+        job_id: str | Sequence[str],
+        queue: str | None = None,
+        silent: bool = False,
+    ) -> dict[int, dict[str, Any]] | dict[str, Any] | None:
         # default arguments
         if queue is None:
             queue = self.queue
@@ -148,20 +179,25 @@ def query(self, job_id, queue=None, silent=False):
         if queue:
             cmd += ["-q", queue]
         cmd += job_ids
-        cmd = quote_cmd(cmd)
+        cmd_str = quote_cmd(cmd)
 
         # run it
-        logger.debug("query lsf job(s) with command '{}'".format(cmd))
-        code, out, err = interruptable_popen(cmd, shell=True, executable="/bin/bash",
-            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        logger.debug(f"query lsf job(s) with command '{cmd_str}'")
+        out: str
+        err: str
+        code, out, err = interruptable_popen(  # type: ignore[assignment]
+            cmd_str,
+            shell=True,
+            executable="/bin/bash",
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+        )
 
         # handle errors
         if code != 0:
             if silent:
                 return None
-            else:
-                raise Exception("status query of lsf job(s) '{}' failed with code {}:\n{}".format(
-                    job_id, code, err))
+            raise Exception(f"status query of lsf job(s) '{job_id}' failed with code {code}:\n{err}")
 
         # parse the output and extract the status per job
         query_data = self.parse_query_output(out)
@@ -172,17 +208,18 @@ def query(self, job_id, queue=None, silent=False):
             if not chunking:
                 if silent:
                     return None
-                else:
-                    raise Exception("lsf job(s) '{}' not found in query response".format(
-                        job_id))
+                raise Exception(f"lsf job(s) '{job_id}' not found in query response")
             else:
-                query_data[_job_id] = self.job_status_dict(job_id=_job_id, status=self.FAILED,
-                    error="job not found in query response")
+                query_data[_job_id] = self.job_status_dict(
+                    job_id=_job_id,
+                    status=self.FAILED,
+                    error="job not found in query response",
+                )
 
-        return query_data if chunking else 
query_data[job_id] + return query_data if chunking else query_data[job_id] # type: ignore[index] @classmethod - def parse_query_output(cls, out): + def parse_query_output(cls, out: str) -> dict[str, dict[str, Any]]: """ Example output to parse: 141914132 user_name DONE queue_name exec_host b63cee711a job_name Feb 8 14:54 @@ -206,19 +243,20 @@ def parse_query_output(cls, out): return query_data @classmethod - def map_status(cls, status_flag): + def map_status(cls, status_flag: str | None) -> str: # https://www.ibm.com/support/knowledgecenter/en/SSETD4_9.1.2/lsf_command_ref/bjobs.1.html if status_flag in ("PEND", "PROV", "PSUSP", "USUSP", "SSUSP", "WAIT"): return cls.PENDING - elif status_flag in ("RUN",): + if status_flag in ("RUN",): return cls.RUNNING - elif status_flag in ("DONE",): + if status_flag in ("DONE",): return cls.FINISHED - elif status_flag in ("EXIT", "UNKWN", "ZOMBI"): - return cls.FAILED - else: + if status_flag in ("EXIT", "UNKWN", "ZOMBI"): return cls.FAILED + logger.debug(f"unknown lsf job state '{status_flag}'") + return cls.FAILED + class LSFJobFileFactory(BaseJobFileFactory): @@ -228,24 +266,48 @@ class LSFJobFileFactory(BaseJobFileFactory): "stdout", "stderr", "shell", "emails", "custom_content", "absolute_paths", ] - def __init__(self, file_name="lsf_job.job", command=None, executable=None, arguments=None, - queue=None, cwd=None, input_files=None, output_files=None, postfix_output_files=True, - manual_stagein=False, manual_stageout=False, job_name=None, stdout="stdout.txt", - stderr="stderr.txt", shell="bash", emails=False, custom_content=None, - absolute_paths=False, **kwargs): + def __init__( + self, + *, + file_name: str = "lsf_job.job", + command: str | Sequence[str] | None = None, + executable: str | None = None, + arguments: str | Sequence[str] | None = None, + queue: str | None = None, + cwd: str | pathlib.Path | LocalDirectoryTarget | None = None, + input_files: dict[str, str | pathlib.Path | JobInputFile] | None = None, + output_files: list[str] | None = None, + postfix_output_files: bool = True, + manual_stagein: bool = False, + manual_stageout: bool = False, + job_name: str | None = None, + stdout: str = "stdout.txt", + stderr: str = "stderr.txt", + shell: str = "bash", + emails: bool = False, + custom_content: str | Sequence[str] | None = None, + absolute_paths: bool = False, + **kwargs, + ) -> None: # get some default kwargs from the config cfg = Config.instance() if kwargs.get("dir") is None: - kwargs["dir"] = cfg.get_expanded("job", cfg.find_option("job", - "lsf_job_file_dir", "job_file_dir")) + kwargs["dir"] = cfg.get_expanded( + "job", + cfg.find_option("job", "lsf_job_file_dir", "job_file_dir"), + ) if kwargs.get("mkdtemp") is None: - kwargs["mkdtemp"] = cfg.get_expanded_bool("job", cfg.find_option("job", - "lsf_job_file_dir_mkdtemp", "job_file_dir_mkdtemp")) + kwargs["mkdtemp"] = cfg.get_expanded_bool( + "job", + cfg.find_option("job", "lsf_job_file_dir_mkdtemp", "job_file_dir_mkdtemp"), + ) if kwargs.get("cleanup") is None: - kwargs["cleanup"] = cfg.get_expanded_bool("job", cfg.find_option("job", - "lsf_job_file_dir_cleanup", "job_file_dir_cleanup")) + kwargs["cleanup"] = cfg.get_expanded_bool( + "job", + cfg.find_option("job", "lsf_job_file_dir_cleanup", "job_file_dir_cleanup"), + ) - super(LSFJobFileFactory, self).__init__(**kwargs) + super().__init__(**kwargs) self.file_name = file_name self.command = command @@ -253,7 +315,7 @@ def __init__(self, file_name="lsf_job.job", command=None, executable=None, argum self.arguments = arguments 
self.queue = queue self.cwd = cwd - self.input_files = DeprecatedInputFiles(input_files or {}) + self.input_files = input_files or {} self.output_files = output_files or [] self.postfix_output_files = postfix_output_files self.manual_stagein = manual_stagein @@ -266,7 +328,11 @@ def __init__(self, file_name="lsf_job.job", command=None, executable=None, argum self.custom_content = custom_content self.absolute_paths = absolute_paths - def create(self, postfix=None, **kwargs): + def create( + self, + postfix: str | None = None, + **kwargs, + ) -> tuple[str, LSFJobFileFactory.Config]: # merge kwargs and instance attributes c = self.get_config(**kwargs) @@ -410,7 +476,7 @@ def prepare_input(f): os.chmod(path, os.stat(path).st_mode | stat.S_IXUSR | stat.S_IXGRP) # job file content - content = [] + content: list[str | tuple[str] | tuple[str, Any]] = [] content.append("#!/usr/bin/env {}".format(c.shell)) if c.job_name: @@ -418,7 +484,7 @@ def prepare_input(f): if c.queue: content.append(("-q", c.queue)) if c.cwd: - content.append(("-cwd", c.cwd)) + content.append(("-cwd", get_path(c.cwd))) if c.stdout: content.append(("-o", c.stdout)) if c.stderr: @@ -431,11 +497,11 @@ def prepare_input(f): if not c.manual_stagein: paths = [f.path_sub_rel for f in c.input_files.values() if f.path_sub_rel] for path in make_unique(paths): - content.append(("-f", "\"{} > {}\"".format(path, os.path.basename(path)))) + content.append(("-f", f"\"{path} > {os.path.basename(path)}\"")) if not c.manual_stageout: for path in make_unique(c.output_files): - content.append(("-f", "\"{} < {}\"".format(path, os.path.basename(path)))) + content.append(("-f", f"\"{path} < {os.path.basename(path)}\"")) if c.manual_stagein: tmpl = "cp " + ("{}" if c.absolute_paths else "$LS_EXECCWD/{}") + " $PWD/{}" @@ -449,7 +515,7 @@ def prepare_input(f): content.append("./" + c.executable) if c.arguments: args = quote_cmd(c.arguments) if isinstance(c.arguments, (list, tuple)) else c.arguments - content[-1] += " {}".format(args) + content[-1] += f" {args}" # type: ignore[operator] if c.manual_stageout: tmpl = "cp $PWD/{} $LS_EXECCWD/{}" @@ -459,17 +525,16 @@ def prepare_input(f): # write the job file with open(job_file, "w") as f: for line in content: - if not isinstance(line, six.string_types): + if not isinstance(line, str): line = self.create_line(*make_list(line)) - f.write(line + "\n") + f.write(f"{line}\n") - logger.debug("created lsf job file at '{}'".format(job_file)) + logger.debug(f"created lsf job file at '{job_file}'") return job_file, c @classmethod - def create_line(cls, key, value=None): + def create_line(cls, key: str, value: Any | None = None) -> str: if value is None: - return "#BSUB {}".format(key) - else: - return "#BSUB {} {}".format(key, value) + return f"#BSUB {key}" + return f"#BSUB {key} {value}" diff --git a/law/contrib/lsf/util.py b/law/contrib/lsf/util.py index 406a3870..099aa2f7 100644 --- a/law/contrib/lsf/util.py +++ b/law/contrib/lsf/util.py @@ -4,38 +4,46 @@ LSF utilities. """ -__all__ = ["get_lsf_version"] +from __future__ import annotations +__all__ = ["get_lsf_version"] import re import subprocess import threading -from law.util import no_value, interruptable_popen +from law.util import NoValue, no_value, interruptable_popen -_lsf_version = no_value +_lsf_version: tuple[int, int, int] | None | NoValue = no_value _lsf_version_lock = threading.Lock() -def get_lsf_version(): +def get_lsf_version() -> tuple[int, int, int] | None: """ Returns the version of the LSF installation in a 3-tuple. 
The value is cached to accelerate - repeated function invocations. + repeated function invocations. When the ``bjobs`` executable is not available, *None* is + returned. """ global _lsf_version if _lsf_version == no_value: version = None with _lsf_version_lock: - code, out, _ = interruptable_popen("bjobs -V", shell=True, executable="/bin/bash", - stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + out: str + code, out, _ = interruptable_popen( # type: ignore[assignment] + "bjobs -V", + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) if code == 0: first_line = out.strip().split("\n")[0].strip() m = re.match(r"^Platform LSF (\d+)\.(\d+)\.(\d+).+$", first_line) if m: version = tuple(map(int, m.groups())) - _lsf_version = version + _lsf_version = version # type: ignore[assignment] - return _lsf_version + return _lsf_version # type: ignore[return-value] diff --git a/law/contrib/lsf/workflow.py b/law/contrib/lsf/workflow.py index 58a15811..916ea5ce 100644 --- a/law/contrib/lsf/workflow.py +++ b/law/contrib/lsf/workflow.py @@ -4,23 +4,25 @@ LSF remote workflow implementation. See https://www.ibm.com/support/knowledgecenter/en/SSETD4_9.1.3. """ -__all__ = ["LSFWorkflow"] +from __future__ import annotations +__all__ = ["LSFWorkflow"] import os -from abc import abstractmethod -from collections import OrderedDict +import abc +import pathlib -import luigi +import luigi # type: ignore[import-untyped] from law.workflow.remote import BaseRemoteWorkflow, BaseRemoteWorkflowProxy -from law.job.base import JobArguments, JobInputFile, DeprecatedInputFiles +from law.job.base import JobArguments, JobInputFile from law.task.proxy import ProxyCommand from law.target.file import get_path, get_scheme, FileSystemDirectoryTarget -from law.target.local import LocalDirectoryTarget +from law.target.local import LocalDirectoryTarget, LocalFileTarget from law.parameter import NO_STR -from law.util import law_src_path, merge_dicts, DotDict +from law.util import law_src_path, merge_dicts, DotDict, InsertableDict from law.logger import get_logger +from law._types import Type from law.contrib.lsf.job import LSFJobManager, LSFJobFileFactory @@ -30,23 +32,27 @@ class LSFWorkflowProxy(BaseRemoteWorkflowProxy): - workflow_type = "lsf" + workflow_type: str = "lsf" - def create_job_manager(self, **kwargs): + def create_job_manager(self, **kwargs) -> LSFJobManager: return self.task.lsf_create_job_manager(**kwargs) - def create_job_file_factory(self, **kwargs): + def create_job_file_factory(self, **kwargs) -> LSFJobFileFactory: return self.task.lsf_create_job_file_factory(**kwargs) - def create_job_file(self, job_num, branches): + def create_job_file( + self, + job_num: int, + branches: list[int], + ) -> dict[str, str | pathlib.Path | LSFJobFileFactory.Config | None]: task = self.task # the file postfix is pythonic range made from branches, e.g. 
[0, 1, 2, 4] -> "_0To5" - postfix = "_{}To{}".format(branches[0], branches[-1] + 1) + postfix = f"_{branches[0]}To{branches[-1] + 1}" # create the config - c = self.job_file_factory.get_config() - c.input_files = DeprecatedInputFiles() + c = self.job_file_factory.get_config() # type: ignore[union-attr] + c.input_files = {} c.output_files = [] c.render_variables = {} c.custom_content = [] @@ -76,18 +82,23 @@ def create_job_file(self, job_num, branches): ) if task.lsf_use_local_scheduler(): proxy_cmd.add_arg("--local-scheduler", "True", overwrite=True) - for key, value in OrderedDict(task.lsf_cmdline_args()).items(): + for key, value in dict(task.lsf_cmdline_args()).items(): proxy_cmd.add_arg(key, value, overwrite=True) # job script arguments + dashboard_data = None + if self.dashboard is not None: + dashboard_data = self.dashboard.remote_hook_data( + job_num, + self.job_data.attempts.get(job_num, 0), + ) job_args = JobArguments( task_cls=task.__class__, task_params=proxy_cmd.build(skip_run=True), branches=branches, workers=task.job_workers, auto_retry=False, - dashboard_data=self.dashboard.remote_hook_data( - job_num, self.job_data.attempts.get(job_num, 0)), + dashboard_data=dashboard_data, ) c.arguments = job_args.join() @@ -102,9 +113,10 @@ def create_job_file(self, job_num, branches): c.input_files["stageout_file"] = stageout_file # does the dashboard have a hook file? - dashboard_file = self.dashboard.remote_hook_file() - if dashboard_file: - c.input_files["dashboard_file"] = dashboard_file + if self.dashboard is not None: + dashboard_file = self.dashboard.remote_hook_file() + if dashboard_file: + c.input_files["dashboard_file"] = dashboard_file # logging # we do not use lsf's logging mechanism since it might require that the submission @@ -127,7 +139,7 @@ def create_job_file(self, job_num, branches): c.cwd = output_dir.abspath # job name - c.job_name = "{}{}".format(task.live_task_id, postfix) + c.job_name = f"{task.live_task_id}{postfix}" # task hook c = task.lsf_job_config(c, job_num, branches) @@ -137,7 +149,7 @@ def create_job_file(self, job_num, branches): del c.output_files[:] # build the job file and get the sanitized config - job_file, c = self.job_file_factory(postfix=postfix, **c.__dict__) + job_file, c = self.job_file_factory(postfix=postfix, **c.__dict__) # type: ignore[misc] # get the location of the custom local log file if any abs_log_file = None @@ -147,11 +159,11 @@ def create_job_file(self, job_num, branches): # return job and log files return {"job": job_file, "config": c, "log": abs_log_file} - def destination_info(self): - info = super(LSFWorkflowProxy, self).destination_info() + def destination_info(self) -> InsertableDict: + info = super().destination_info() if self.task.lsf_queue != NO_STR: - info["queue"] = "queue: {}".format(self.task.lsf_queue) + info["queue"] = f"queue: {self.task.lsf_queue}" info = self.task.lsf_destination_info(info) @@ -172,68 +184,73 @@ class LSFWorkflow(BaseRemoteWorkflow): description="target lsf queue; default: empty", ) - lsf_job_kwargs = ["lsf_queue"] + lsf_job_kwargs: list[str] = ["lsf_queue"] lsf_job_kwargs_submit = None lsf_job_kwargs_cancel = None lsf_job_kwargs_query = None exclude_params_branch = {"lsf_queue"} - exclude_params_lsf_workflow = set() + exclude_params_lsf_workflow: set[str] = set() exclude_index = True - @abstractmethod - def lsf_output_directory(self): - return None + @abc.abstractmethod + def lsf_output_directory(self) -> FileSystemDirectoryTarget: + ... 
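The same pattern applies to LSF. A short sketch, not part of the patch (output path and the bsub resource request are assumptions), of a workflow providing the abstract output directory and extending the job file via custom content:

import os

import law

law.contrib.load("lsf")


class MyLSFWorkflow(law.lsf.LSFWorkflow, law.LocalWorkflow):

    def create_branch_map(self):
        return {0: None}

    def lsf_output_directory(self) -> law.LocalDirectoryTarget:
        return law.LocalDirectoryTarget(os.path.join(os.getcwd(), "jobs", self.task_family))

    def lsf_job_config(self, config, job_num, branches):
        # rendered by LSFJobFileFactory.create_line as "#BSUB -R rusage[mem=2000]"
        config.custom_content.append(("-R", "rusage[mem=2000]"))
        return config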
- def lsf_bootstrap_file(self): + def lsf_workflow_requires(self) -> DotDict: + return DotDict() + + def lsf_bootstrap_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile | None: return None - def lsf_wrapper_file(self): + def lsf_wrapper_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile | None: return None - def lsf_job_file(self): + def lsf_job_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile: return JobInputFile(law_src_path("job", "law_job.sh")) - def lsf_stageout_file(self): + def lsf_stageout_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile | None: return None - def lsf_workflow_requires(self): - return DotDict() - - def lsf_output_postfix(self): + def lsf_output_postfix(self) -> str: return "" - def lsf_job_manager_cls(self): + def lsf_job_manager_cls(self) -> Type[LSFJobManager]: return LSFJobManager - def lsf_create_job_manager(self, **kwargs): + def lsf_create_job_manager(self, **kwargs) -> LSFJobManager: kwargs = merge_dicts(self.lsf_job_manager_defaults, kwargs) return self.lsf_job_manager_cls()(**kwargs) - def lsf_job_file_factory_cls(self): + def lsf_job_file_factory_cls(self) -> Type[LSFJobFileFactory]: return LSFJobFileFactory - def lsf_create_job_file_factory(self, **kwargs): + def lsf_create_job_file_factory(self, **kwargs) -> LSFJobFileFactory: # job file fectory config priority: kwargs > class defaults kwargs = merge_dicts({}, self.lsf_job_file_factory_defaults, kwargs) return self.lsf_job_file_factory_cls()(**kwargs) - def lsf_job_config(self, config, job_num, branches): + def lsf_job_config( + self, + config: LSFJobFileFactory.Config, + job_num: int, + branches: list[int], + ) -> LSFJobFileFactory.Config: return config - def lsf_check_job_completeness(self): + def lsf_check_job_completeness(self) -> bool: return False - def lsf_check_job_completeness_delay(self): + def lsf_check_job_completeness_delay(self) -> float | int: return 0.0 - def lsf_use_local_scheduler(self): + def lsf_use_local_scheduler(self) -> bool: return True - def lsf_cmdline_args(self): + def lsf_cmdline_args(self) -> dict[str, str]: return {} - def lsf_destination_info(self, info): + def lsf_destination_info(self, info: InsertableDict) -> InsertableDict: return info diff --git a/law/contrib/slurm/__init__.py b/law/contrib/slurm/__init__.py index b086d31a..73a443b9 100644 --- a/law/contrib/slurm/__init__.py +++ b/law/contrib/slurm/__init__.py @@ -11,7 +11,6 @@ "SlurmWorkflow", ] - # provisioning imports from law.contrib.slurm.util import get_slurm_version from law.contrib.slurm.job import SlurmJobManager, SlurmJobFileFactory diff --git a/law/contrib/slurm/job.py b/law/contrib/slurm/job.py index 7ccb1743..d7314563 100644 --- a/law/contrib/slurm/job.py +++ b/law/contrib/slurm/job.py @@ -4,13 +4,15 @@ Slurm job manager. See https://slurm.schedmd.com/quickstart.html. 
""" -__all__ = ["SlurmJobManager", "SlurmJobFileFactory"] +from __future__ import annotations +__all__ = ["SlurmJobManager", "SlurmJobFileFactory"] import os import time import re import stat +import pathlib import subprocess from law.config import Config @@ -18,6 +20,7 @@ from law.target.file import get_path from law.util import interruptable_popen, make_list, quote_cmd from law.logger import get_logger +from law._types import Any, Sequence logger = get_logger(__name__) @@ -40,39 +43,54 @@ class SlurmJobManager(BaseJobManager): sacct_format = r"JobID,State,ExitCode,Reason" sacct_cre = re.compile(r"^\s*(\d+)\s+([^\s]+)\s+(-?\d+):-?\d+\s+(.+)$") - def __init__(self, partition=None, threads=1): - super(SlurmJobManager, self).__init__() + def __init__(self, partition: str | None = None, threads: int = 1) -> None: + super().__init__() self.partition = partition self.threads = threads - def cleanup(self, *args, **kwargs): + def cleanup(self, *args, **kwargs) -> None: # type: ignore[override] raise NotImplementedError("SlurmJobManager.cleanup is not implemented") - def cleanup_batch(self, *args, **kwargs): + def cleanup_batch(self, *args, **kwargs) -> None: # type: ignore[override] raise NotImplementedError("SlurmJobManager.cleanup_batch is not implemented") - def submit(self, job_file, partition=None, retries=0, retry_delay=3, silent=False): + def submit( # type: ignore[override] + self, + job_file: str | pathlib.Path, + partition: str | None = None, + retries: int = 0, + retry_delay: float | int = 3, + silent: bool = False, + ) -> int | None: # default arguments if partition is None: partition = self.partition # get the job file location as the submission command is run it the same directory - job_file_dir, job_file_name = os.path.split(os.path.abspath(str(job_file))) + job_file_dir, job_file_name = os.path.split(os.path.abspath(get_path(job_file))) # build the command cmd = ["sbatch"] if partition: cmd += ["--partition", partition] cmd += [job_file_name] - cmd = quote_cmd(cmd) + cmd_str = quote_cmd(cmd) # define the actual submission in a loop to simplify retries while True: # run the command - logger.debug("submit slurm job with command '{}'".format(cmd)) - code, out, err = interruptable_popen(cmd, shell=True, executable="/bin/bash", - stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=job_file_dir) + logger.debug(f"submit slurm job with command '{cmd_str}'") + out: str + err: str + code, out, err = interruptable_popen( # type: ignore[assignment] + cmd_str, + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=job_file_dir, + ) # get the job id(s) if code == 0: @@ -84,14 +102,13 @@ def submit(self, job_file, partition=None, retries=0, retry_delay=3, silent=Fals break else: code = 1 - err = "cannot parse slurm job id(s) from output:\n{}".format(out) + err = f"cannot parse slurm job id(s) from output:\n{out}" # retry or done? 
if code == 0:
                 return job_id
 
-            logger.debug("submission of slurm job '{}' failed with code {}:\n{}".format(
-                job_file, code, err))
+            logger.debug(f"submission of slurm job '{job_file}' failed with code {code}:\n{err}")
 
             if retries > 0:
                 retries -= 1
@@ -101,10 +118,14 @@ def submit(self, job_file, partition=None, retries=0, retry_delay=3, silent=Fals
             if silent:
                 return None
 
-            raise Exception("submission of slurm job '{}' failed:\n{}".format(
-                job_file, err))
+            raise Exception(f"submission of slurm job '{job_file}' failed:\n{err}")
 
-    def cancel(self, job_id, partition=None, silent=False):
+    def cancel(  # type: ignore[override]
+        self,
+        job_id: int | Sequence[int],
+        partition: str | None = None,
+        silent: bool = False,
+    ) -> dict[int, None] | None:
         # default arguments
         if partition is None:
             partition = self.partition
@@ -117,21 +138,34 @@ def cancel(self, job_id, partition=None, silent=False):
         if partition:
             cmd += ["--partition", partition]
         cmd += job_ids
-        cmd = quote_cmd(cmd)
+        cmd_str = quote_cmd(cmd)
 
         # run it
-        logger.debug("cancel slurm job(s) with command '{}'".format(cmd))
-        code, out, err = interruptable_popen(cmd, shell=True, executable="/bin/bash",
-            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        logger.debug(f"cancel slurm job(s) with command '{cmd_str}'")
+        out: str
+        err: str
+        code, out, err = interruptable_popen(  # type: ignore[assignment]
+            cmd_str,
+            shell=True,
+            executable="/bin/bash",
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+        )
 
         # check success
         if code != 0 and not silent:
-            raise Exception("cancellation of slurm job(s) '{}' failed with code {}:\n{}".format(
-                job_id, code, err))
+            raise Exception(
+                f"cancellation of slurm job(s) '{job_id}' failed with code {code}:\n{err}",
+            )
 
         return {job_id: None for job_id in job_ids} if chunking else None
 
-    def query(self, job_id, partition=None, silent=False):
+    def query(  # type: ignore[override]
+        self,
+        job_id: int | Sequence[int],
+        partition: str | None = None,
+        silent: bool = False,
+    ) -> dict[int, dict[str, Any]] | dict[str, Any] | None:
         # default arguments
         if partition is None:
             partition = self.partition
@@ -144,11 +178,18 @@ def query(self, job_id, partition=None, silent=False):
         if partition:
             cmd += ["--partition", partition]
         cmd += ["--jobs", ",".join(map(str, job_ids))]
-        cmd = quote_cmd(cmd)
-
-        logger.debug("query slurm job(s) with command '{}'".format(cmd))
-        code, out, err = interruptable_popen(cmd, shell=True, executable="/bin/bash",
-            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        cmd_str = quote_cmd(cmd)
+
+        logger.debug(f"query slurm job(s) with command '{cmd_str}'")
+        out: str
+        err: str
+        code, out, err = interruptable_popen(  # type: ignore[assignment]
+            cmd_str,
+            shell=True,
+            executable="/bin/bash",
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+        )
 
         # special case: when the id of a single yet expired job is queried, squeue responds with an
         # error (exit code != 0), so as a workaround, consider these cases as an empty result
@@ -161,9 +202,9 @@ def query(self, job_id, partition=None, silent=False):
         if code != 0:
             if silent:
                 return None
-            else:
-                raise Exception("queue query of slurm job(s) '{}' failed with code {}:"
-                    "\n{}".format(job_id, code, err))
+            raise Exception(
+                f"queue query of slurm job(s) '{job_id}' failed with code {code}:\n{err}",
+            )
 
         # parse the output and extract the status per job
         query_data = self.parse_squeue_output(out)
@@ -176,19 +217,24 @@ def query(self, job_id, partition=None, silent=False):
             if partition:
                 cmd += ["--partition", partition]
             cmd += ["--jobs", ",".join(map(str, missing_ids))]
-            cmd = quote_cmd(cmd)
-
-            logger.debug("query slurm accounting history with command '{}'".format(cmd))
-            code, out, err = interruptable_popen(cmd, shell=True, executable="/bin/bash",
-                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+            cmd_str = quote_cmd(cmd)
+
+            logger.debug(f"query slurm accounting history with command '{cmd_str}'")
+            code, out, err = interruptable_popen(  # type: ignore[assignment]
+                cmd_str,
+                shell=True,
+                executable="/bin/bash",
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+            )
 
             # handle errors
             if code != 0:
                 if silent:
                     return None
-                else:
-                    raise Exception("accounting query of slurm job(s) '{}' failed with code {}:"
-                        "\n{}".format(job_id, code, err))
+                raise Exception(
+                    f"accounting query of slurm job(s) '{job_id}' failed with code {code}:\n{err}",
+                )
 
             # parse the output and update query data
             query_data.update(self.parse_sacct_output(out))
@@ -199,17 +245,18 @@ def query(self, job_id, partition=None, silent=False):
             if not chunking:
                 if silent:
                     return None
-                else:
-                    raise Exception("slurm job(s) '{}' not found in query response".format(
-                        job_id))
+                raise Exception(f"slurm job(s) '{job_id}' not found in query response")
             else:
-                query_data[_job_id] = self.job_status_dict(job_id=_job_id, status=self.FAILED,
-                    error="job not found in query response")
+                query_data[_job_id] = self.job_status_dict(
+                    job_id=_job_id,
+                    status=self.FAILED,
+                    error="job not found in query response",
+                )
 
-        return query_data if chunking else query_data[job_id]
+        return query_data if chunking else query_data[job_id]  # type: ignore[index]
 
     @classmethod
-    def parse_squeue_output(cls, out):
+    def parse_squeue_output(cls, out: str) -> dict[int, dict[str, Any]]:
         # retrieve information per block mapped to the job id
         query_data = {}
         for line in out.strip().split("\n"):
@@ -229,7 +276,7 @@ def parse_squeue_output(cls, out):
         return query_data
 
     @classmethod
-    def parse_sacct_output(cls, out):
+    def parse_sacct_output(cls, out: str) -> dict[int, dict[str, Any]]:
         # retrieve information per block mapped to the job id
         query_data = {}
         for line in out.strip().split("\n"):
@@ -255,35 +302,40 @@ def parse_sacct_output(cls, out):
             if code != 0 and status != cls.FAILED:
                 status = cls.FAILED
                 if not error:
-                    error = "job status set to '{}' due to non-zero exit code {}".format(
-                        cls.FAILED, code)
+                    error = f"job status set to '{cls.FAILED}' due to non-zero exit code {code}"
 
             if not error and status == cls.FAILED:
                 error = m.group(2)
 
             # store it
-            query_data[job_id] = cls.job_status_dict(job_id=job_id, status=status, code=code,
-                error=error)
+            query_data[job_id] = cls.job_status_dict(
+                job_id=job_id,
+                status=status,
+                code=code,
+                error=error,
+            )
 
         return query_data
 
     @classmethod
-    def map_status(cls, status):
+    def map_status(cls, status: str | None) -> str:
         # see https://slurm.schedmd.com/squeue.html#lbAG
-        status = status.strip("+")
+        if isinstance(status, str):
+            status = status.strip("+")
         if status in ["CONFIGURING", "PENDING", "REQUEUED", "REQUEUE_HOLD", "REQUEUE_FED"]:
             return cls.PENDING
-        elif status in ["RUNNING", "COMPLETING", "STAGE_OUT"]:
+        if status in ["RUNNING", "COMPLETING", "STAGE_OUT"]:
            return cls.RUNNING
-        elif status in ["COMPLETED"]:
+        if status in ["COMPLETED"]:
            return cls.FINISHED
-        elif status in ["BOOT_FAIL", "CANCELLED", "DEADLINE", "FAILED", "NODE_FAIL",
-            "OUT_OF_MEMORY", "PREEMPTED", "REVOKED", "SPECIAL_EXIT", "STOPPED", "SUSPENDED",
-            "TIMEOUT"]:
-            return cls.FAILED
-        else:
-            logger.debug("unknown slurm job state '{}'".format(status))
+        if status in [
+            
"BOOT_FAIL", "CANCELLED", "DEADLINE", "FAILED", "NODE_FAIL", "OUT_OF_MEMORY", + "PREEMPTED", "REVOKED", "SPECIAL_EXIT", "STOPPED", "SUSPENDED", "TIMEOUT", + ]: return cls.FAILED + logger.debug(f"unknown slurm job state '{status}'") + return cls.FAILED + class SlurmJobFileFactory(BaseJobFileFactory): @@ -292,23 +344,43 @@ class SlurmJobFileFactory(BaseJobFileFactory): "partition", "stdout", "stderr", "postfix_output_files", "custom_content", "absolute_paths", ] - def __init__(self, file_name="slurm_job.sh", command=None, executable=None, arguments=None, - shell="bash", input_files=None, job_name=None, partition=None, stdout="stdout.txt", - stderr="stderr.txt", postfix_output_files=True, custom_content=None, - absolute_paths=False, **kwargs): + def __init__( + self, + *, + file_name: str = "slurm_job.sh", + command: str | Sequence[str] | None = None, + executable: str | None = None, + arguments: str | Sequence[str] | None = None, + shell: str = "bash", + input_files: dict[str, str | pathlib.Path | JobInputFile] | None = None, + job_name: str | None = None, + partition: str | None = None, + stdout: str = "stdout.txt", + stderr: str = "stderr.txt", + postfix_output_files: bool = True, + custom_content: str | Sequence[str] | None = None, + absolute_paths: bool = False, + **kwargs, + ) -> None: # get some default kwargs from the config cfg = Config.instance() if kwargs.get("dir") is None: - kwargs["dir"] = cfg.get_expanded("job", cfg.find_option("job", - "slurm_job_file_dir", "job_file_dir")) + kwargs["dir"] = cfg.get_expanded( + "job", + cfg.find_option("job", "slurm_job_file_dir", "job_file_dir"), + ) if kwargs.get("mkdtemp") is None: - kwargs["mkdtemp"] = cfg.get_expanded_bool("job", cfg.find_option("job", - "slurm_job_file_dir_mkdtemp", "job_file_dir_mkdtemp")) + kwargs["mkdtemp"] = cfg.get_expanded_bool( + "job", + cfg.find_option("job", "slurm_job_file_dir_mkdtemp", "job_file_dir_mkdtemp"), + ) if kwargs.get("cleanup") is None: - kwargs["cleanup"] = cfg.get_expanded_bool("job", cfg.find_option("job", - "slurm_job_file_dir_cleanup", "job_file_dir_cleanup")) + kwargs["cleanup"] = cfg.get_expanded_bool( + "job", + cfg.find_option("job", "slurm_job_file_dir_cleanup", "job_file_dir_cleanup"), + ) - super(SlurmJobFileFactory, self).__init__(**kwargs) + super().__init__(**kwargs) self.file_name = file_name self.command = command @@ -324,7 +396,11 @@ def __init__(self, file_name="slurm_job.sh", command=None, executable=None, argu self.custom_content = custom_content self.absolute_paths = absolute_paths - def create(self, postfix=None, **kwargs): + def create( + self, + postfix: str | None = None, + **kwargs, + ) -> tuple[str, SlurmJobFileFactory.Config]: # merge kwargs and instance attributes c = self.get_config(**kwargs) @@ -459,8 +535,8 @@ def prepare_input(f): os.chmod(path, os.stat(path).st_mode | stat.S_IXUSR | stat.S_IXGRP) # job file content - content = [] - content.append("#!/usr/bin/env {}".format(c.shell)) + content: list[str | tuple[str, Any]] = [] + content.append(f"#!/usr/bin/env {c.shell}") content.append("") if c.job_name: @@ -480,7 +556,7 @@ def prepare_input(f): with open(job_file, "w") as f: for obj in content: line = self.create_line(obj) - f.write(line + "\n") + f.write(f"{line}\n") # prepare arguments args = c.arguments or "" @@ -490,17 +566,17 @@ def prepare_input(f): # add the command if c.command: cmd = quote_cmd(c.command) if isinstance(c.command, (list, tuple)) else c.command - f.write("\n{}{}\n".format(cmd.strip(), args)) + f.write(f"\n{cmd.strip()}{args}\n") # add the 
executable if c.executable: cmd = c.executable - f.write("\n{}{}\n".format(cmd, args)) + f.write(f"\n{cmd}{args}\n") # make it executable os.chmod(job_file, os.stat(job_file).st_mode | stat.S_IXUSR | stat.S_IXGRP) - logger.debug("created slurm job file at '{}'".format(job_file)) + logger.debug(f"created slurm job file at '{job_file}'") return job_file, c @@ -509,9 +585,9 @@ def create_line(cls, args): _str = lambda s: str(s).strip() if not isinstance(args, (list, tuple)): return args.strip() - elif len(args) == 1: - return "#SBATCH --{}".format(*map(_str, args)) - elif len(args) == 2: - return "#SBATCH --{}={}".format(*map(_str, args)) - else: - raise Exception("cannot create job file line from '{}'".format(args)) + if len(args) == 1: + return f"#SBATCH --{_str(args[0])}" + if len(args) == 2: + return f"#SBATCH --{_str(args[0])}={_str(args[1])}" + + raise Exception(f"cannot create job file line from '{args}'") diff --git a/law/contrib/slurm/util.py b/law/contrib/slurm/util.py index 0528a1c4..e9f3026c 100644 --- a/law/contrib/slurm/util.py +++ b/law/contrib/slurm/util.py @@ -4,38 +4,46 @@ Slurm utilities. """ -__all__ = ["get_slurm_version"] +from __future__ import annotations +__all__ = ["get_slurm_version"] import re import subprocess import threading -from law.util import no_value, interruptable_popen +from law.util import NoValue, no_value, interruptable_popen -_slurm_version = no_value +_slurm_version: tuple[int, int, int] | None | NoValue = no_value _slurm_version_lock = threading.Lock() -def get_slurm_version(): +def get_slurm_version() -> tuple[int, int, int] | None: """ Returns the version of the Slurm installation in a 3-tuple. The value is cached to accelerate - repeated function invocations. + repeated function invocations. When the ``sbatch`` executable is not available, *None* is + returned. """ global _slurm_version if _slurm_version == no_value: version = None with _slurm_version_lock: - code, out, _ = interruptable_popen("sbatch --version", shell=True, - executable="/bin/bash", stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out: str + code, out, _ = interruptable_popen( # type: ignore[assignment] + "sbatch --version", + shell=True, + executable="/bin/bash", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) if code == 0: first_line = out.strip().split("\n")[0] m = re.match(r"^slurm (\d+)\.(\d+)\.(\d+).*$", first_line.strip()) if m: version = tuple(map(int, m.groups())) - _slurm_version = version + _slurm_version = version # type: ignore[assignment] - return _slurm_version + return _slurm_version # type: ignore[return-value] diff --git a/law/contrib/slurm/workflow.py b/law/contrib/slurm/workflow.py index 20dc4c8e..2ce9d260 100644 --- a/law/contrib/slurm/workflow.py +++ b/law/contrib/slurm/workflow.py @@ -4,23 +4,25 @@ Slurm workflow implementation. See https://slurm.schedmd.com. 
""" -__all__ = ["SlurmWorkflow"] +from __future__ import annotations +__all__ = ["SlurmWorkflow"] import os -from abc import abstractmethod -from collections import OrderedDict +import abc +import pathlib -import luigi +import luigi # type: ignore[import-untyped] from law.workflow.remote import BaseRemoteWorkflow, BaseRemoteWorkflowProxy from law.job.base import JobArguments, JobInputFile from law.task.proxy import ProxyCommand from law.target.file import get_path, get_scheme, FileSystemDirectoryTarget -from law.target.local import LocalDirectoryTarget +from law.target.local import LocalDirectoryTarget, LocalFileTarget from law.parameter import NO_STR -from law.util import law_src_path, merge_dicts, DotDict +from law.util import law_src_path, merge_dicts, DotDict, InsertableDict from law.logger import get_logger +from law._types import Type from law.contrib.slurm.job import SlurmJobManager, SlurmJobFileFactory @@ -30,22 +32,26 @@ class SlurmWorkflowProxy(BaseRemoteWorkflowProxy): - workflow_type = "slurm" + workflow_type: str = "slurm" - def create_job_manager(self, **kwargs): + def create_job_manager(self, **kwargs) -> SlurmJobManager: return self.task.slurm_create_job_manager(**kwargs) - def create_job_file_factory(self, **kwargs): + def create_job_file_factory(self, **kwargs) -> SlurmJobFileFactory: return self.task.slurm_create_job_file_factory(**kwargs) - def create_job_file(self, job_num, branches): + def create_job_file( + self, + job_num: int, + branches: list[int], + ) -> dict[str, str | pathlib.Path | SlurmJobFileFactory.Config | None]: task = self.task # the file postfix is pythonic range made from branches, e.g. [0, 1, 2, 4] -> "_0To5" - postfix = "_{}To{}".format(branches[0], branches[-1] + 1) + postfix = f"_{branches[0]}To{branches[-1] + 1}" # create the config - c = self.job_file_factory.get_config() + c = self.job_file_factory.get_config() # type: ignore[union-attr] c.input_files = {} c.render_variables = {} c.custom_content = [] @@ -75,18 +81,23 @@ def create_job_file(self, job_num, branches): ) if task.slurm_use_local_scheduler(): proxy_cmd.add_arg("--local-scheduler", "True", overwrite=True) - for key, value in OrderedDict(task.slurm_cmdline_args()).items(): + for key, value in dict(task.slurm_cmdline_args()).items(): proxy_cmd.add_arg(key, value, overwrite=True) # job script arguments + dashboard_data = None + if self.dashboard is not None: + dashboard_data = self.dashboard.remote_hook_data( + job_num, + self.job_data.attempts.get(job_num, 0), + ) job_args = JobArguments( task_cls=task.__class__, task_params=proxy_cmd.build(skip_run=True), branches=branches, workers=task.job_workers, auto_retry=False, - dashboard_data=self.dashboard.remote_hook_data( - job_num, self.job_data.attempts.get(job_num, 0)), + dashboard_data=dashboard_data, ) c.arguments = job_args.join() @@ -101,9 +112,10 @@ def create_job_file(self, job_num, branches): c.input_files["stageout_file"] = stageout_file # does the dashboard have a hook file? 
- dashboard_file = self.dashboard.remote_hook_file() - if dashboard_file: - c.input_files["dashboard_file"] = dashboard_file + if self.dashboard is not None: + dashboard_file = self.dashboard.remote_hook_file() + if dashboard_file: + c.input_files["dashboard_file"] = dashboard_file # logging # we do not use slurm's logging mechanism since it might require that the submission @@ -126,7 +138,7 @@ def create_job_file(self, job_num, branches): c.custom_content.append(("chdir", output_dir.abspath)) # job name - c.job_name = "{}{}".format(task.live_task_id, postfix) + c.job_name = f"{task.live_task_id}{postfix}" # task arguments if task.slurm_partition and task.slurm_partition != NO_STR: @@ -136,7 +148,7 @@ def create_job_file(self, job_num, branches): c = task.slurm_job_config(c, job_num, branches) # build the job file and get the sanitized config - job_file, c = self.job_file_factory(postfix=postfix, **c.__dict__) + job_file, c = self.job_file_factory(postfix=postfix, **c.__dict__) # type: ignore[misc] # get the location of the custom local log file if any abs_log_file = None @@ -146,8 +158,8 @@ def create_job_file(self, job_num, branches): # return job and log files return {"job": job_file, "config": c, "log": abs_log_file} - def destination_info(self): - info = super(SlurmWorkflowProxy, self).destination_info() + def destination_info(self) -> InsertableDict: + info = super().destination_info() info = self.task.slurm_destination_info(info) @@ -168,68 +180,73 @@ class SlurmWorkflow(BaseRemoteWorkflow): description="target queue partition; default: empty", ) - slurm_job_kwargs = ["slurm_partition"] + slurm_job_kwargs: list[str] = ["slurm_partition"] slurm_job_kwargs_submit = None slurm_job_kwargs_cancel = None slurm_job_kwargs_query = None exclude_params_branch = {"slurm_partition"} - exclude_params_slurm_workflow = set() + exclude_params_slurm_workflow: set[str] = set() exclude_index = True - @abstractmethod - def slurm_output_directory(self): - return None + @abc.abstractmethod + def slurm_output_directory(self) -> FileSystemDirectoryTarget: + ... 
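# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): how the typed hooks of
# SlurmWorkflow are meant to be overridden in user code. The task below is
# made up for this example (branch map, paths and values are placeholders),
# and the branch task body (output()/run()) is omitted; only the hook names
# and signatures come from the class in this diff.
# ----------------------------------------------------------------------------
from __future__ import annotations

import law
from law.target.local import LocalDirectoryTarget

law.contrib.load("slurm")

from law.contrib.slurm.workflow import SlurmWorkflow


class ExampleSlurmWorkflow(SlurmWorkflow):

    def create_branch_map(self) -> dict[int, int]:
        # three trivial branches, purely for illustration
        return {i: i for i in range(3)}

    def slurm_output_directory(self) -> LocalDirectoryTarget:
        # required hook: where submission and status data are stored
        return LocalDirectoryTarget("/tmp/law_slurm_example")

    def slurm_cmdline_args(self) -> dict[str, str]:
        # optional hook: extra task parameters added to the remote invocation
        return {"log-level": "info"}

    def slurm_job_config(self, config, job_num, branches):
        # optional hook: tweak the generated job file, e.g. extra #SBATCH lines
        config.custom_content.append(("time", "01:00:00"))
        return config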
- def slurm_workflow_requires(self): + def slurm_workflow_requires(self) -> DotDict: return DotDict() - def slurm_bootstrap_file(self): + def slurm_bootstrap_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile | None: return None - def slurm_wrapper_file(self): + def slurm_wrapper_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile | None: return None - def slurm_job_file(self): + def slurm_job_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile: return JobInputFile(law_src_path("job", "law_job.sh")) - def slurm_stageout_file(self): + def slurm_stageout_file(self) -> str | pathlib.Path | LocalFileTarget | JobInputFile | None: return None - def slurm_output_postfix(self): + def slurm_output_postfix(self) -> str: return "" - def slurm_job_manager_cls(self): + def slurm_job_manager_cls(self) -> Type[SlurmJobManager]: return SlurmJobManager - def slurm_create_job_manager(self, **kwargs): + def slurm_create_job_manager(self, **kwargs) -> SlurmJobManager: kwargs = merge_dicts(self.slurm_job_manager_defaults, kwargs) return self.slurm_job_manager_cls()(**kwargs) - def slurm_job_file_factory_cls(self): + def slurm_job_file_factory_cls(self) -> Type[SlurmJobFileFactory]: return SlurmJobFileFactory - def slurm_create_job_file_factory(self, **kwargs): + def slurm_create_job_file_factory(self, **kwargs) -> SlurmJobFileFactory: # job file fectory config priority: kwargs > class defaults kwargs = merge_dicts({}, self.slurm_job_file_factory_defaults, kwargs) return self.slurm_job_file_factory_cls()(**kwargs) - def slurm_job_config(self, config, job_num, branches): + def slurm_job_config( + self, + config: SlurmJobFileFactory.Config, + job_num: int, + branches: list[int], + ) -> SlurmJobFileFactory.Config: return config - def slurm_check_job_completeness(self): + def slurm_check_job_completeness(self) -> bool: return False - def slurm_check_job_completeness_delay(self): + def slurm_check_job_completeness_delay(self) -> float | int: return 0.0 - def slurm_use_local_scheduler(self): + def slurm_use_local_scheduler(self) -> bool: return False - def slurm_cmdline_args(self): + def slurm_cmdline_args(self) -> dict[str, str]: return {} - def slurm_destination_info(self, info): + def slurm_destination_info(self, info: InsertableDict) -> InsertableDict: return info diff --git a/law/contrib/wlcg/__init__.py b/law/contrib/wlcg/__init__.py index 018e3b25..d0eb47dc 100644 --- a/law/contrib/wlcg/__init__.py +++ b/law/contrib/wlcg/__init__.py @@ -14,6 +14,10 @@ "ensure_vomsproxy", ] +# dependencies to other contrib modules +import law +law.contrib.load("gfal") + # provisioning imports from law.contrib.wlcg.util import ( get_userkey, get_usercert, get_usercert_subject, diff --git a/law/contrib/wlcg/target.py b/law/contrib/wlcg/target.py index c10e2d4f..3acb66f2 100644 --- a/law/contrib/wlcg/target.py +++ b/law/contrib/wlcg/target.py @@ -19,8 +19,6 @@ logger = get_logger(__name__) -law.contrib.load("gfal") - class WLCGFileSystem(RemoteFileSystem): diff --git a/law/contrib/wlcg/util.py b/law/contrib/wlcg/util.py index 84124580..3e37d267 100644 --- a/law/contrib/wlcg/util.py +++ b/law/contrib/wlcg/util.py @@ -286,7 +286,7 @@ def delegate_vomsproxy_glite( stdout: int | TextIO | None = None, stderr: int | TextIO | None = None, cache: bool | str | pathlib.Path = True, -): +) -> str: """ Delegates the voms proxy via gLite to an *endpoint*, e.g. ``grid-ce.physik.rwth-aachen.de:8443``. 
When *proxy_file* is *None*, it defaults to the result diff --git a/law/job/base.py b/law/job/base.py index 667ad7f1..b09b9e72 100644 --- a/law/job/base.py +++ b/law/job/base.py @@ -23,15 +23,16 @@ from threading import Lock from abc import ABCMeta, abstractmethod -from law.task.base import Register +from law.task.base import Task from law.target.file import get_scheme, get_path +from law.target.local import LocalFileTarget from law.target.remote.base import RemoteTarget from law.config import Config from law.util import ( colored, make_list, make_tuple, iter_chunks, makedirs, create_hash, empty_context, ) from law.logger import get_logger -from law._types import Any, Callable, Hashable, Sequence, TracebackType +from law._types import Any, Callable, Hashable, Sequence, TracebackType, Type, T logger = get_logger(__name__) @@ -180,7 +181,7 @@ class BaseJobManager(object, metaclass=ABCMeta): @classmethod def job_status_dict( cls, - job_id: str | None = None, + job_id: Any | None = None, status: str | None = None, code: int | None = None, error: str | None = None, @@ -193,7 +194,7 @@ def job_status_dict( return dict(job_id=job_id, status=status, code=code, error=error, extra=extra) @classmethod - def cast_job_id(cls, job_id: Any) -> str: + def cast_job_id(cls, job_id: Any) -> Any: """ Hook for casting an input *job_id*, for instance, after loading serialized data from json. """ @@ -214,16 +215,16 @@ def __init__( self.last_counts = [0] * len(self.status_names) @abstractmethod - def submit(self) -> list[str]: + def submit(self) -> Any: """ Abstract atomic or group job submission. Can throw exceptions. - Should return a list of job ids. + Should return a single job id or a list of ids. """ ... @abstractmethod - def cancel(self) -> dict[str, Any]: + def cancel(self) -> dict[Any, Any]: """ Abstract atomic or group job cancellation. Can throw exceptions. @@ -232,7 +233,7 @@ def cancel(self) -> dict[str, Any]: ... @abstractmethod - def cleanup(self) -> dict[str, Any]: + def cleanup(self) -> dict[Any, Any]: """ Abstract atomic or group job cleanup. Can throw exceptions. @@ -241,7 +242,7 @@ def cleanup(self) -> dict[str, Any]: ... @abstractmethod - def query(self) -> dict[str, Any]: + def query(self) -> dict[Any, Any]: """ Abstract atomic or group job status query. Can throw exceptions. @@ -249,7 +250,7 @@ def query(self) -> dict[str, Any]: """ ... 
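# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): what the group_job_ids() hook
# declared just below is expected to do when a job manager supports grouping --
# map a flat list of job ids onto dict[Hashable, list] so that one batch-system
# call can act on each group. The "cluster.process" id scheme used here is an
# assumption made for the example, not something defined by this diff.
# ----------------------------------------------------------------------------
from __future__ import annotations

from collections import defaultdict
from typing import Hashable


def group_job_ids(job_ids: list[str]) -> dict[Hashable, list[str]]:
    groups: dict[Hashable, list[str]] = defaultdict(list)
    for job_id in job_ids:
        # group by the part in front of the dot, fall back to the full id
        cluster = job_id.split(".", 1)[0]
        groups[cluster].append(job_id)
    return dict(groups)


print(group_job_ids(["1234.0", "1234.1", "5678.0"]))
# -> {'1234': ['1234.0', '1234.1'], '5678': ['5678.0']}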
- def group_job_ids(self, job_ids: list[Any]) -> dict[Any, list[Any]]: + def group_job_ids(self, job_ids: list[Any]) -> dict[Hashable, list[Any]]: """ Hook that needs to be implemented if the job mananger supports grouping of jobs, i.e., when :py:attr:`job_grouping` is *True*, and potentially used during status queries, job @@ -332,6 +333,7 @@ def wrapper(data): def submit_batch( self, job_files: list[Any], + *, threads: int | None = None, chunk_size: int | None = None, callback: Callable[[int, Any], Any] | None = None, @@ -366,6 +368,7 @@ def submit_batch( def cancel_batch( self, job_ids: list[Hashable], + *, threads: int | None = None, chunk_size: int | None = None, callback: Callable[[int, Any], Any] | None = None, @@ -401,6 +404,7 @@ def cancel_batch( def cleanup_batch( self, job_ids: list[Hashable], + *, threads: int | None = None, chunk_size: int | None = None, callback: Callable[[int, Any], Any] | None = None, @@ -437,6 +441,7 @@ def cleanup_batch( def query_batch( self, job_ids: list[Hashable], + *, threads: int | None = None, chunk_size: int | None = None, callback: Callable[[int, Any], Any] | None = None, @@ -469,18 +474,18 @@ def query_batch( def _apply_group( self, func: Callable, - result_type: type, - group_func, + result_type: Type[T], + group_func: Callable[[list[Any]], dict[Hashable, list[Any]]], job_objs: list[Any], threads: int | None = None, callback: Callable[[int, Any], Any] | None = None, **kwargs, - ) -> Any: + ) -> T: # default arguments threads = max(threads or self.threads or 1, 1) # group objects - job_obj_groups: dict[Any, list[Any]] = group_func(make_list(job_objs)) + job_obj_groups: dict[Hashable, list[Any]] = group_func(make_list(job_objs)) # factory to call the passed callback for each job file even when chunking def cb_factory(i: int) -> Callable | None: @@ -512,13 +517,14 @@ def wrapper(result_data: Any) -> None: if isinstance(result_data, list): result_data.append(data if isinstance(data, Exception) else data[i]) else: - result_data[job_obj] = data if isinstance(data, Exception) else data[job_obj] + result_data[job_obj] = data if isinstance(data, Exception) else data[job_obj] # type: ignore[index] # noqa return result_data def submit_group( self, job_files: list[Any], + *, threads: int | None = None, callback: Callable[[int, Any], Any] | None = None, **kwargs, @@ -538,7 +544,7 @@ def submit_group( added to the returned list. 
""" # in order to use the generic grouping mechanism in _apply_group create a trivial group_func - def group_func(job_files: list[Any]) -> dict[Any, list[Any]]: + def group_func(job_files: list[Any]) -> dict[Hashable, list[Any]]: groups = defaultdict(list) for job_file in job_files: groups[job_file].append(job_file) @@ -557,6 +563,7 @@ def group_func(job_files: list[Any]) -> dict[Any, list[Any]]: def cancel_group( self, job_ids: list[Hashable], + *, threads: int | None = None, callback: Callable[[int, Any], Any] | None = None, **kwargs, @@ -580,7 +587,7 @@ def cancel_group( job_objs=job_ids, threads=threads, callback=callback, - **kwargs # noqa + **kwargs, ) # return only errors @@ -589,6 +596,7 @@ def cancel_group( def cleanup_group( self, job_ids: list[Hashable], + *, threads: int | None = None, callback: Callable[[int, Any], Any] | None = None, **kwargs, @@ -621,6 +629,7 @@ def cleanup_group( def query_group( self, job_ids: list[Hashable], + *, threads: int | None = None, callback: Callable[[int, Any], Any] | None = None, **kwargs, @@ -651,6 +660,7 @@ def status_line( self, counts: Sequence[int], last_counts: Sequence[int] | bool | None = None, + *, sum_counts: int | None = None, timestamp: bool = True, align: bool | int = False, @@ -812,6 +822,7 @@ def __contains__(self, attr: str) -> bool: def __init__( self, + *, dir: str | pathlib.Path | None = None, render_variables: dict[str, Any] | None = None, custom_log_file: str | pathlib.Path | None = None, @@ -870,6 +881,7 @@ def postfix_file( cls, path: str | pathlib.Path, postfix: str | dict[str, str] | None = None, + *, add_hash: bool = False, ) -> str: """ @@ -993,6 +1005,7 @@ def render_file( src: str | pathlib.Path, dst: str | pathlib.Path, render_variables: dict[str, Any], + *, postfix: str | dict[str, str] | None = None, silent: bool = True, ) -> None: @@ -1043,6 +1056,7 @@ def postfix_fn(m: re.Match) -> str: def provide_input( self, src: str | pathlib.Path, + *, postfix: str | dict[str, str] | None = None, dir: str | pathlib.Path | None = None, render_variables: dict[str, Any] | None = None, @@ -1177,13 +1191,14 @@ class JobArguments(object): def __init__( self, - task_cls: Register, + *, + task_cls: Type[Task], task_params: str, branches: list[int], workers: int = 1, auto_retry: bool = False, - dashboard_data: list[str] | None = None, - ): + dashboard_data: dict[str, Any] | None = None, + ) -> None: super().__init__() self.task_cls = task_cls @@ -1191,7 +1206,7 @@ def __init__( self.branches = branches self.workers = max(workers, 1) self.auto_retry = auto_retry - self.dashboard_data: list[str] = dashboard_data or [] + self.dashboard_data: dict[str, Any] = dashboard_data or {} @classmethod def encode_bool(cls, b: bool) -> str: @@ -1209,13 +1224,27 @@ def encode_string(cls, s: str) -> str: return encoded.decode("utf-8") @classmethod - def encode_list(cls, l: list) -> str: + def encode_list(cls, l: list[Any]) -> str: """ Encodes a list *l* into a string via base64 encoding. """ - encoded = base64.b64encode((" ".join(map(str, l)) or "-").encode("utf-8")) + # none of the elements in l must have a space in their string representation + l_str = list(map(str, l)) + for s in l_str: + if " " in s: + raise ValueError(f"cannot encode list element containing spaces: {l_str}") + + encoded = base64.b64encode((" ".join(l_str) or "-").encode("utf-8")) return encoded.decode("utf-8") + @classmethod + def encode_dict(cls, d: dict) -> str: + """ + Encodes a dict *d* into a string representation "key1=value1 key2=value2" via base64 + encoding. 
+ """ + return cls.encode_list([f"{k}={v}" for k, v in d.items()]) + def get_args(self) -> list[str]: """ Returns the list of encoded job arguments. The order of this list corresponds to the @@ -1228,7 +1257,7 @@ def get_args(self) -> list[str]: self.encode_list(self.branches), str(self.workers), self.encode_bool(self.auto_retry), - self.encode_list(self.dashboard_data), + self.encode_dict(self.dashboard_data), ] def join(self) -> str: @@ -1330,7 +1359,8 @@ class JobInputFile(object): def __init__( self, - path: str | pathlib.Path | JobInputFile, + path: str | pathlib.Path | JobInputFile | LocalFileTarget, + *, copy: bool | None = None, share: bool | None = None, forward: bool | None = None, diff --git a/law/job/dashboard.py b/law/job/dashboard.py index 28cb4ac7..d37cf14e 100644 --- a/law/job/dashboard.py +++ b/law/job/dashboard.py @@ -17,15 +17,15 @@ def cache_by_status( - func: Callable[[Any, dict, str, int], Any], -) -> Callable[[dict, str, int], Any]: + func: Callable[[Any, JobData, str, int], Any], +) -> Callable[[JobData, str, int], Any]: """ Decorator for :py:meth:`BaseJobDashboard.publish` (and inheriting classes) that caches the last published status to decide if the a new publication is necessary or not. When the status did not change since the last call, the actual publish method is not invoked and *None* is returned. """ @functools.wraps(func) - def wrapper(self, job_data: dict, event: str, job_num: int, *args, **kwargs) -> None | Any: + def wrapper(self, job_data: JobData, event: str, job_num: int, *args, **kwargs) -> Any | None: job_id = job_data["job_id"] dashboard_status = self.map_status(job_data.get("status"), event) @@ -165,7 +165,7 @@ def map_status(self, job_status: str, event: str) -> str | None: ... @abstractmethod - def publish(self, job_data: dict, event: str, job_num: int, *args, **kwargs) -> None: + def publish(self, job_data: JobData, event: str, job_num: int) -> None: """ Publishes the status of a job to the implemented job dashboard. *job_data* is a dictionary that contains a *job_id* and a *status* string (see @@ -182,9 +182,17 @@ class NoJobDashboard(BaseJobDashboard): """ def map_status(self, *args, **kwargs) -> str | None: - """""" + """ + Returns *None*. + """ return None def publish(self, *args, **kwargs) -> None: - """""" - return + """ + Returns *None*. + """ + return None + + +# trailing imports +from law.workflow.remote import JobData diff --git a/law/util.py b/law/util.py index ac75bbda..5cf28ea5 100644 --- a/law/util.py +++ b/law/util.py @@ -2088,7 +2088,7 @@ class DotDict(dict): """ def __class_getitem__(cls, types: tuple[type, type]) -> GenericAlias: - # python < 3.9 + # python <3.9 if GenericAlias is str: key_type, value_type = types return f"{cls.__name__}[{key_type.__name__}, {value_type.__name__}]" # type: ignore[return-value] # noqa diff --git a/law/workflow/remote.py b/law/workflow/remote.py index 7c7081d4..a797052d 100644 --- a/law/workflow/remote.py +++ b/law/workflow/remote.py @@ -293,24 +293,33 @@ def create_job_file_factory(self, **kwargs) -> BaseJobFileFactory: """ ... - @abstractmethod def create_job_file( self, - *args, - **kwargs, - ) -> dict[str, str | pathlib.Path | BaseJobFileFactory.Config]: + job_num: int, + branches: list[int], + ) -> dict[str, str | pathlib.Path | BaseJobFileFactory.Config | None]: """ - Creates a job file using the :py:attr:`job_file_factory`. The expected arguments depend on - wether the job manager supports job grouping (:py:attr:`BaseJobManager.job_grouping`). 
If it - does, two arguments containing the job number (*job_num*) and the list of branch numbers - (*branches*) covered by the job. If job grouping is supported, a single dictionary mapping - job numbers to covered branch values must be passed. In any case, the path(s) of job files - are returned. + Creates a job file using the :py:attr:`job_file_factory`. The path(s) of job files are + returned. This method must be implemented by inheriting classes. """ # TODO: add TypedDict or similar as return type - ... + raise NotImplementedError() + + def create_job_file_group( + self, + submit_jobs: dict[int, list[int]], + ) -> dict[str, str | pathlib.Path | BaseJobFileFactory.Config | None]: + """ + Creates a job file using the :py:attr:`job_file_factory` based on a group of *submit_jobs*. + This method should be implemented in case the corresponding job manager supports job + grouping (:py:attr:`BaseJobManager.job_grouping`). The path(s) of job files are returned. + + This method must be implemented by inheriting classes. + """ + # TODO: add TypedDict or similar as return type + raise NotImplementedError() def destination_info(self) -> InsertableDict: """ @@ -956,7 +965,7 @@ def _submit_group( task = self.task # create the single multi submission file, passing the job_num -> branches dict - job_file = self.create_job_file(submit_jobs) + job_file = self.create_job_file_group(submit_jobs) # setup the job manager job_man_kwargs = self._setup_job_manager() @@ -1138,7 +1147,7 @@ def poll(self) -> None: # get settings from the task for triggering post-finished status checks check_completeness = self._get_task_attribute("check_job_completeness")() check_completeness_delay = self._get_task_attribute("check_job_completeness_delay")() - if check_completeness_delay: + if check_completeness and check_completeness_delay > 0: time.sleep(check_completeness_delay) # store jobs per status and take further actions depending on the status diff --git a/tests/typecheck.sh b/tests/typecheck.sh index a5a7a9a9..b8a187a5 100755 --- a/tests/typecheck.sh +++ b/tests/typecheck.sh @@ -11,9 +11,7 @@ action() { local repo_dir="$( dirname "${this_dir}" )" # default test command - # local cmd="${1:-mypy law tests}" - # temporary change: use the list of already polished parts - local cmd="${1:-mypy law/*.py law/cli law/job law/sandbox law/task law/workflow law/target law/contrib/awkward law/contrib/coffea law/contrib/git law/contrib/hdf5 law/contrib/ipython law/contrib/keras law/contrib/matplotlib law/contrib/mercurial law/contrib/numpy law/contrib/profiling law/contrib/pyarrow law/contrib/rich law/contrib/root law/contrib/slack law/contrib/telegram law/contrib/tensorflow law/contrib/wlcg law/contrib/gfal law/contrib/singularity law/contrib/docker law/contrib/dropbox law/contrib/tasks tests}" + local cmd="${1:-mypy law tests}" # execute it echo -e "command: \x1b[1;49;39m${cmd}\x1b[0m"
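# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the two creation paths that the
# remote workflow proxy distinguishes above -- create_job_file() receives a
# single (job_num, branches) pair, while create_job_file_group() receives the
# full {job_num: branches} mapping so that one submission file can cover all
# jobs. The helper names below are assumptions made for this example.
# ----------------------------------------------------------------------------
from __future__ import annotations

import time


def covered_branches(submit_jobs: dict[int, list[int]]) -> list[int]:
    # all branch numbers covered by a single grouped submission, in job order
    return [b for job_num in sorted(submit_jobs) for b in submit_jobs[job_num]]


def wait_before_completeness_check(check_completeness: bool, delay: float) -> None:
    # mirrors the poll() fix above: only sleep when the post-finished
    # completeness check is enabled and a positive delay is configured
    if check_completeness and delay > 0:
        time.sleep(delay)


print(covered_branches({1: [0, 1], 2: [2, 3], 3: [4]}))  # -> [0, 1, 2, 3, 4]
wait_before_completeness_check(check_completeness=False, delay=5.0)  # returns immediately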