Skip to content

Commit

Permalink
Polish remaining contrib packages.
Browse files Browse the repository at this point in the history
  • Loading branch information
riga committed Jan 30, 2024
1 parent 3223d42 commit 1fafd55
Show file tree
Hide file tree
Showing 38 changed files with 1,868 additions and 1,080 deletions.
1 change: 0 additions & 1 deletion law/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

if sys.version_info[:2] >= (3, 9):
from types import GenericAlias # noqa

else:
GenericAlias = str

Expand Down
10 changes: 5 additions & 5 deletions law/cli/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
import luigi # type: ignore[import-untyped]

from law.config import Config
from law.task.base import Register, Task, ExternalTask
from law.task.base import Task, ExternalTask
from law.util import multi_match, colored, abort, makedirs, brace_expand
from law.logger import get_logger
from law._types import Sequence
from law._types import Sequence, Type


logger = get_logger(__name__)
Expand Down Expand Up @@ -151,9 +151,9 @@ def execute(args: argparse.Namespace) -> int:
# determine tasks to write into the index file
seen_families = []
task_classes = []
lookup: list[Register] = [Task]
lookup: list[Type[Task]] = [Task]
while lookup:
cls: Register = lookup.pop(0) # type: ignore
cls: Type[Task] = lookup.pop(0) # type: ignore
lookup.extend(cls.__subclasses__())

# skip tasks in __main__ module in interactive sessions
Expand Down Expand Up @@ -214,7 +214,7 @@ def execute(args: argparse.Namespace) -> int:

task_classes.append(cls)

def get_task_params(cls: Register) -> list[str]:
def get_task_params(cls) -> list[str]:
params = []
for attr in dir(cls):
member = getattr(cls, attr)
Expand Down
3 changes: 3 additions & 0 deletions law/contrib/arc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
"ensure_arcproxy",
]

# dependencies to other contrib modules
import law
law.contrib.load("wlcg")

# provisioning imports
from law.contrib.arc.util import (
Expand Down
22 changes: 16 additions & 6 deletions law/contrib/arc/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,41 @@
Decorators for task methods for convenient working with ARC.
"""

__all__ = ["ensure_arcproxy"]
from __future__ import annotations

__all__ = ["ensure_arcproxy"]

from law.task.base import Task
from law.decorator import factory
from law._types import Any, Callable

from law.contrib.arc import check_arcproxy_validity


@factory(accept_generator=True)
def ensure_arcproxy(fn, opts, task, *args, **kwargs):
def ensure_arcproxy(
fn: Callable,
opts: dict[str, Any],
task: Task,
*args,
**kwargs,
) -> tuple[Callable, Callable, Callable]:
"""
Decorator for law task methods that checks the validity of the arc proxy and throws an
exception in case it is invalid. This can prevent late errors on remote worker notes that except
arc proxies to be present. Accepts generator functions.
"""
def before_call():
def before_call() -> None:
# check the proxy validity
if not check_arcproxy_validity():
raise Exception("arc proxy not valid")

return None

def call(state):
def call(state: None) -> Any:
return fn(task, *args, **kwargs)

def after_call(state):
return
def after_call(state: None) -> None:
return None

return before_call, call, after_call
Loading

0 comments on commit 1fafd55

Please sign in to comment.