Skip to content

Commit

Permalink
Remove Slurm version argument
Browse files Browse the repository at this point in the history
Use an internal constant, make job scheduler's REST methods internal, consistent and log errors, and strip any newlines from token
  • Loading branch information
PeterC-DLS committed Oct 4, 2023
1 parent 47132e1 commit fe7630c
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 65 deletions.
4 changes: 0 additions & 4 deletions ParProcCo/job_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def __init__(
output_dir_or_file: Path,
partition: str,
extra_properties: Optional[dict[str,str]] = None,
version: str = "v0.0.38",
user_name: Optional[str] = None,
user_token: Optional[str] = None,
timeout: timedelta = timedelta(hours=2),
Expand Down Expand Up @@ -57,7 +56,6 @@ def __init__(
self.working_directory = self.cluster_output_dir
logging.debug("JC working dir: %s", self.working_directory)
self.data_slicer: SlicerInterface
self.version = version
self.user_name = user_name if user_name else get_user()
self.user_token = user_token if user_token else get_slurm_token()
self.timeout = timeout
Expand Down Expand Up @@ -125,7 +123,6 @@ def _submit_sliced_jobs(
self.partition,
self.extra_properties,
self.timeout,
self.version,
self.user_name,
self.user_token,
)
Expand Down Expand Up @@ -167,7 +164,6 @@ def _submit_aggregation_job(self, memory: int) -> None:
self.partition,
self.extra_properties,
self.timeout,
self.version,
self.user_name,
self.user_token,
)
Expand Down
69 changes: 33 additions & 36 deletions ParProcCo/job_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
JobSubmissionResponse,
)

# Migrating from drmaa2 to slurm as in https://github.com/DiamondLightSource/python-zocalo

_SLURM_VERSION = "v0.0.38"


class SLURMSTATE(Enum):
Expand Down Expand Up @@ -139,9 +140,8 @@ def __init__(
working_directory: Optional[Union[Path, str]],
cluster_output_dir: Optional[Union[Path, str]],
partition: str,
extra_properties: Optional[dict[str,str]] = None,
extra_properties: Optional[dict[str, str]] = None,
timeout: timedelta = timedelta(hours=2),
version: str = "v0.0.38",
user_name: Optional[str] = None,
user_token: Optional[str] = None,
):
Expand Down Expand Up @@ -171,9 +171,7 @@ def __init__(
self.memory: int
self.cores: int
self.job_name: str
self._url = url
self._version = version
self._slurm_endpoint_prefix = f"slurm/{self._version}"
self._slurm_endpoint_url = f"{url}/slurm/{_SLURM_VERSION}"
self._session = requests.Session()

self.user = user_name if user_name else get_user()
Expand All @@ -182,46 +180,44 @@ def __init__(
self._session.headers["X-SLURM-USER-TOKEN"] = self.token
self._session.headers["Content-Type"] = "application/json"

def get(
def _get(
self,
endpoint: str,
params: dict[str, Any] | None = None,
timeout: float | None = None,
) -> requests.Response:
response = self._session.get(
f"{self._url}/{endpoint}", params=params, timeout=timeout
return self._session.get(
f"{self._slurm_endpoint_url}/{endpoint}", params=params, timeout=timeout
)
response.raise_for_status()
return response

def _post(self, data: BaseModel, endpoint) -> requests.Response:
url = f"{self._url}/{endpoint}"
resp = self._session.post(
url=url,
data=data.model_dump_json(exclude_defaults=True),
def _post(self, endpoint: str, data: BaseModel) -> requests.Response:
return self._session.post(
f"{self._slurm_endpoint_url}/{endpoint}",
data.model_dump_json(exclude_defaults=True),
)
return resp

def delete(
def _delete(
self,
endpoint: str,
params: dict[str, Any] | None = None,
timeout: float | None = None,
) -> requests.Response:
response = self._session.delete(
f"{self._url}/{endpoint}", params=params, timeout=timeout
return self._session.delete(
f"{self._slurm_endpoint_url}/{endpoint}", params=params, timeout=timeout
)

def _get_response_json(self, response: requests.Response) -> dict:
response.raise_for_status()
return response
try:
return response.json()
except:
logging.error("Response not json: %s", response.content, exc_info=True)
raise

def get_jobs_response(self, job_id: int | None = None) -> JobsResponse:
endpoint = (
f"{self._slurm_endpoint_prefix}/job/{job_id}"
if job_id is not None
else f"{self._slurm_endpoint_prefix}/jobs"
)
response = self.get(endpoint)
return JobsResponse.model_validate(response.json())
endpoint = f"job/{job_id}" if job_id is not None else "jobs"
response = self._get(endpoint)
return JobsResponse.model_validate(self._get_response_json(response))

def get_job(self, job_id: int) -> JobResponseProperties:
ji = self.get_jobs_response(job_id)
Expand Down Expand Up @@ -270,14 +266,12 @@ def update_status_infos(self, job_info: JobResponseProperties) -> None:
logging.info(f"Updating current state of {job_id} to {state}")

def submit_job(self, job_submission: JobSubmission) -> JobSubmissionResponse:
endpoint = f"{self._slurm_endpoint_prefix}/job/submit"
response = self._post(data=job_submission, endpoint=endpoint)
return JobSubmissionResponse.model_validate(response.json())
response = self._post("job/submit", job_submission)
return JobSubmissionResponse.model_validate(self._get_response_json(response))

def cancel_job(self, job_id: int) -> JobsResponse:
endpoint = f"{self._slurm_endpoint_prefix}/job/{job_id}"
response = self.delete(endpoint)
return JobsResponse.model_validate(response.json())
response = self._delete(f"job/{job_id}")
return JobsResponse.model_validate(self._get_response_json(response))

def get_output_paths(self) -> List[Path]:
return self.output_paths
Expand Down Expand Up @@ -396,7 +390,7 @@ def make_job_submission(self, i: int, job=None, jobs=None) -> JobSubmission:
get_user_environment="10L",
)
if self.extra_properties:
for k,v in self.extra_properties.items():
for k, v in self.extra_properties.items():
setattr(job, k, v)

return JobSubmission(script=self.jobscript_command, job=job)
Expand Down Expand Up @@ -424,7 +418,10 @@ def wait_all_jobs(
for job_id in list(remaining_jobs):
self.fetch_and_update_state(job_id)
# check if current_state is in STATEGROUP.ENDED or not in STATEGROUP.STARTING
if (self.status_infos[job_id].current_state in STATEGROUP[required_state]) == ended:
if (
self.status_infos[job_id].current_state
in STATEGROUP[required_state]
) == ended:
remaining_jobs.remove(job_id)
if len(remaining_jobs) > 0:
time.sleep(sleep_time)
Expand Down
45 changes: 23 additions & 22 deletions ParProcCo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,21 @@

import logging
import os
import yaml

from dataclasses import dataclass
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Union
from yaml import YAMLObject, SafeLoader

import yaml
from dataclasses import dataclass
from yaml import SafeLoader, YAMLObject

if sys.version_info < (3, 10):
from backports.entry_points_selectable import (
entry_points, # @UnresolvedImport @UnusedImport
)
else:
from importlib.metadata import entry_points # @UnresolvedImport @Reimport


def check_jobscript_is_readable(jobscript: Path) -> Path:
Expand Down Expand Up @@ -51,7 +59,7 @@ def get_filepath_on_path(filename: Optional[str]) -> Optional[Path]:


def get_slurm_token() -> str:
return os.environ["SLURM_JWT"]
return os.environ["SLURM_JWT"].strip()


def get_user() -> str:
Expand Down Expand Up @@ -110,7 +118,9 @@ class PPCConfig(YAMLObject):

allowed_programs: Dict[str, str] # program name, python package with wrapper module
url: str # slurm rest url
extra_property_envs: Optional[Dict[str, str]] = None # mapping of extra properties to environment variables to pass to slurm's JobProperties
extra_property_envs: Optional[
Dict[str, str]
] = None # mapping of extra properties to environment variables to pass to slurm's JobProperties


PPC_YAML = "par_proc_co.yaml"
Expand All @@ -136,22 +146,22 @@ def load_cfg() -> PPCConfig:


def get_token(filepath: str | None) -> str:
token = ""
if filepath is None:
try:
token = get_slurm_token()
return get_slurm_token()
except KeyError:
raise ValueError(
"No slurm token found. No slurm token filepath provided and no environment variable 'SLURM_JWT'"
)
else:
if os.path.isfile(filepath):
with open(filepath) as f:
token = f.read()

token = ""
if os.path.isfile(filepath):
with open(filepath) as f:
token = f.read().strip()

if token != "":
return token
raise FileNotFoundError(f"Slurm token file f{filepath} not found")
raise FileNotFoundError(f"Slurm token file f{filepath} not found or is empty")


def set_up_wrapper(cfg: PPCConfig, program: str):
Expand All @@ -160,15 +170,6 @@ def set_up_wrapper(cfg: PPCConfig, program: str):
logging.info(f"{program} on allowed list in {cfg}")
package = allowed[program]
else:
import sys

if sys.version_info < (3, 10):
from backports.entry_points_selectable import (
entry_points, # @UnresolvedImport
)
else:
from importlib.metadata import entry_points # @UnresolvedImport

logging.info(f"Checking entry points for {program}")
eps = entry_points(group=PPC_ENTRY_POINT)
try:
Expand Down
3 changes: 0 additions & 3 deletions scripts/ppc_cluster_submit
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ from ParProcCo.job_controller import JobController
from ParProcCo.passthru_wrapper import PassThruWrapper
from ParProcCo.utils import get_token, set_up_wrapper

SLURM_VERSION = "0.0.38"


def create_parser() -> argparse.ArgumentParser:
"""
Expand Down Expand Up @@ -124,7 +122,6 @@ def run_ppc(args: argparse.Namespace, script_args: List) -> None:
output,
args.partition,
extra_properties,
SLURM_VERSION,
user,
token,
timeout,
Expand Down

0 comments on commit fe7630c

Please sign in to comment.