Skip to content

Commit

Permalink
Remove cluster config
Browse files Browse the repository at this point in the history
Add extra property environments that allow Slurm's fields in job properties to be set from environment variables
  • Loading branch information
PeterC-DLS committed Oct 2, 2023
1 parent ced34c9 commit 8db37bd
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 47 deletions.
3 changes: 3 additions & 0 deletions ParProcCo/job_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
program_wrapper: ProgramWrapper,
output_dir_or_file: Path,
partition: str,
extra_properties: Optional[dict[str,str]] = None,
version: str = "v0.0.38",
user_name: Optional[str] = None,
user_token: Optional[str] = None,
Expand All @@ -28,6 +29,7 @@ def __init__(
self.url = url
self.program_wrapper = program_wrapper
self.partition = partition
self.extra_properties = extra_properties
self.output_file: Optional[Path] = None
self.cluster_output_dir: Optional[Path] = None

Expand Down Expand Up @@ -121,6 +123,7 @@ def _submit_sliced_jobs(
self.working_directory,
self.cluster_output_dir,
self.partition,
self.extra_properties,
self.timeout,
self.version,
self.user_name,
Expand Down
5 changes: 5 additions & 0 deletions ParProcCo/job_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def __init__(
working_directory: Optional[Union[Path, str]],
cluster_output_dir: Optional[Union[Path, str]],
partition: str,
extra_properties: Optional[dict[str,str]] = None,
timeout: timedelta = timedelta(hours=2),
version: str = "v0.0.38",
user_name: Optional[str] = None,
Expand All @@ -165,6 +166,7 @@ def __init__(
else (self.cluster_output_dir if self.cluster_output_dir else Path.home())
)
self.partition = partition
self.extra_properties = extra_properties
self.scheduler_mode: SchedulerModeInterface
self.memory: int
self.cores: int
Expand Down Expand Up @@ -406,6 +408,9 @@ def make_job_submission(self, i: int, job=None, jobs=None) -> JobSubmission:
standard_error=stderr_fp,
get_user_environment="10L",
)
if self.extra_properties:
for k,v in self.extra_properties.items():
setattr(job, k, v)

return JobSubmission(script=self.jobscript_command, job=job)

Expand Down
38 changes: 7 additions & 31 deletions ParProcCo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Literal, Optional, Set, Union
from typing import Dict, Optional, Union
from yaml import YAMLObject, SafeLoader


Expand Down Expand Up @@ -103,26 +103,13 @@ def slice_to_string(s: Optional[slice]) -> str:
return f"{start}:{stop}:{step}"


@dataclass
class PPCCluster(YAMLObject):
yaml_tag = "!PPCCluster"
yaml_loader = SafeLoader

module: str # module loaded to submit jobs
default_queue: str # default cluster queue
user_queues: Optional[
Dict[str, List[str]]
] = None # specific queues with allowed users
resources: Optional[Dict[str, str]] = None # job resources


@dataclass
class PPCConfig(YAMLObject):
yaml_tag = "!PPCConfig"
yaml_loader = SafeLoader

allowed_programs: Dict[str, str] # program name, python package with wrapper module
project_env_var: str # name of environment variable holding project passed to qsub
extra_property_envs: Optional[Dict[str, str]] # mapping of extra properties to environment variables to pass to slurm's JobProperties
url: str # slurm rest url


Expand All @@ -142,17 +129,6 @@ def load_cfg() -> PPCConfig:
with open(cfg, "r") as cff:
ppc_config = yaml.safe_load(cff)

for ccfg in ppc_config.clusters.values():
if ccfg.user_queues:
users: Set[str] = set() # check for overlaps
for us in ccfg.user_queues.values():
common = users.intersection(set(us))
if common:
raise ValueError(
"Users %s cannot be assigned to more than one queue",
", ".join(common),
)
users.update(us)
return ppc_config


Expand Down Expand Up @@ -188,8 +164,8 @@ def set_up_wrapper(cfg: PPCConfig, program: str):

if sys.version_info < (3, 10):
from backports.entry_points_selectable import (
entry_points,
) # @UnresolvedImport
entry_points, # @UnresolvedImport
)
else:
from importlib.metadata import entry_points # @UnresolvedImport

Expand Down Expand Up @@ -220,9 +196,9 @@ def set_up_wrapper(cfg: PPCConfig, program: str):

def find_cfg_file(name: str) -> Path:
""" """
cp = os.getenv("PPC_CONFIG")
if cp:
return Path(cp)
c = os.getenv("PPC_CONFIG")
if c:
return Path(c)

cp = Path.home() / ("." + name)
if cp.is_file():
Expand Down
14 changes: 2 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,8 @@ allowed_programs:
rs_map: msmapper_utils
blah1: whatever_package1
blah2: whatever_package2
project_env_var: CLUSTER_PROJECT
cluster_help_msg: Please module load blah
clusters:
cluster_one: !PPCCluster
default_queue: basic.q
user_queues:
better.q: middle_user1
best.q: power_user1, power_user2
cluster_two: !PPCCluster
default_queue: only.q
resources:
cpu_model: arm64
extra_property_envs: # optional dictionary for slurm job properties and environment variables
account: MY_ACCOUNT # env var that holds account
```

An entry point called `ParProcCo.allowed_programs` can be added to other packages' `setup.py`:
Expand Down
20 changes: 16 additions & 4 deletions scripts/ppc_cluster_submit
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ def create_parser() -> argparse.ArgumentParser:
default="2h",
)
parser.add_argument(
"--memory", help="str: memory to use per cluster job", required=True
"--memory",
help="int: memory to use per cluster job (MB)",
type=int,
required=True
)
parser.add_argument(
"--cores",
Expand Down Expand Up @@ -87,9 +90,17 @@ def run_ppc(args: argparse.Namespace, script_args: List) -> None:
cfg = load_cfg()
url = cfg.url
token = get_token(args.token)
partition = args.partition
user = getuser()
partition = args.partition

extra_properties = None
if cfg.extra_property_envs:
extra_properties = {}
logging.debug("Extra job properties:")
for k,e in cfg.extra_property_envs:
v = os.getenv(e)
logging.debug("\t%s: %s", k, v)
if v:
extra_properties[k] = v

timeout = parse_timeout(args.timeout)
logging.info("Running for with timeout %s", timeout)
Expand All @@ -111,7 +122,8 @@ def run_ppc(args: argparse.Namespace, script_args: List) -> None:
url,
wrapper,
output,
partition,
args.partition,
extra_properties,
SLURM_VERSION,
user,
token,
Expand Down

0 comments on commit 8db37bd

Please sign in to comment.