Skip to content

Commit

Permalink
Fix most egregious type errors from pyright
Browse files Browse the repository at this point in the history
Pass memory as an int throughout, set explicit defaults in some JobProperties fields, and reformat with black
  • Loading branch information
PeterC-DLS committed Sep 1, 2023
1 parent c7cc8fc commit 0779051
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 36 deletions.
6 changes: 3 additions & 3 deletions ParProcCo/job_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def run(
self,
number_jobs: int,
jobscript_args: Optional[List] = None,
memory: str = "4G",
memory: int = 4000,
job_name: str = "ParProcCo",
) -> None:
self.cluster_runner = check_location(
Expand Down Expand Up @@ -105,7 +105,7 @@ def _run_sliced_jobs(
self,
slice_params: List[Optional[slice]],
jobscript_args: Optional[List],
memory: str,
memory: int,
job_name: str,
) -> bool:
if jobscript_args is None:
Expand Down Expand Up @@ -141,7 +141,7 @@ def _run_sliced_jobs(
)
return sliced_jobs_success

def _run_aggregation_job(self, memory: str) -> None:
def _run_aggregation_job(self, memory: int) -> None:
aggregator_path = self.program_wrapper.get_aggregate_script()
aggregating_mode = self.program_wrapper.aggregating_mode
if aggregating_mode is None or self.sliced_results is None:
Expand Down
96 changes: 73 additions & 23 deletions ParProcCo/job_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
from .scheduler_mode_interface import SchedulerModeInterface
from .utils import check_jobscript_is_readable
from models.slurmdb_rest import DbJob, DbJobInfo
from models.slurm_rest import JobProperties, JobsResponse, JobResponseProperties, JobSubmission, JobSubmissionResponse
from models.slurm_rest import (
JobProperties,
JobsResponse,
JobResponseProperties,
JobSubmission,
JobSubmissionResponse,
)

# WIP: Migrating from drmaa2 to slurm as in https://github.com/DiamondLightSource/python-zocalo

Expand Down Expand Up @@ -118,7 +124,7 @@ def __init__(
self.jobscript_args: List
self.output_paths: List[Path] = []
self.start_time = datetime.now()
self.status_infos: Dict[str:StatusInfo]
self.status_infos: Dict[int, StatusInfo]
self.timeout = timeout
self.working_directory = (
Path(working_directory)
Expand All @@ -132,28 +138,43 @@ def __init__(
self._url = url
self._version = version
self._session = requests.Session()
self._session.headers["X-SLURM-USER-NAME"] = user_name if user_name else os.environ["USER"]
self._session.headers["X-SLURM-USER-NAME"] = (
user_name if user_name else os.environ["USER"]
)
self.user = user_name if user_name else os.environ["USER"]
self.token = user_token if user_token else os.environ["SLURM_JWT"]
self._session.headers["X-SLURM-USER-TOKEN"] = user_token if user_token else os.environ["SLURM_JWT"]
self._session.headers["X-SLURM-USER-TOKEN"] = (
user_token if user_token else os.environ["SLURM_JWT"]
)

def get(self, endpoint: str, params: dict[str, Any] = None, timeout: float | None = None) -> requests.Response:
response = self._session.get(f"{self._url}/{endpoint}", params=params, timeout=timeout)
def get(
    self,
    endpoint: str,
    params: dict[str, Any] | None = None,
    timeout: float | None = None,
) -> requests.Response:
    """Issue a GET to ``<url>/<endpoint>`` and return the response.

    Raises ``requests.HTTPError`` if the server replies with an error status.
    """
    resp = self._session.get(f"{self._url}/{endpoint}", params=params, timeout=timeout)
    resp.raise_for_status()
    return resp

def put(
    self,
    endpoint: str,
    params: dict[str, Any] | None = None,
    json: dict[str, Any] | None = None,
    timeout: float | None = None,
) -> requests.Response:
    """Issue a PUT to ``<url>/<endpoint>`` with optional query params and JSON body.

    Raises ``requests.HTTPError`` if the server replies with an error status.
    """
    # NOTE: the captured diff interleaved the pre-change parameter lines
    # (`params: dict[str, Any] = None` etc.) with the post-change ones; this is
    # the reconstructed post-change method with `| None` on optional params.
    response = self._session.put(
        f"{self._url}/{endpoint}", params=params, json=json, timeout=timeout
    )
    response.raise_for_status()
    return response

def _prepare_request(self, data: BaseModel) -> tuple[str, dict[str, str]] | tuple[None, None]:
def _prepare_request(
self, data: BaseModel
) -> tuple[str, dict[str, str]] | tuple[None, None]:
if data is None:
return None, None
return data.model_dump_json(exclude_defaults=True), {
Expand All @@ -168,8 +189,15 @@ def _post(self, data: BaseModel, endpoint):
resp = requests.post(url, data=jdata, headers=headers)
return resp

def delete(self, endpoint: str, params: dict[str, Any] = None, timeout: float | None = None) -> requests.Response:
response = self._session.delete(f"{self._url}/{endpoint}", params=params, timeout=timeout)
def delete(
    self,
    endpoint: str,
    params: dict[str, Any] | None = None,
    timeout: float | None = None,
) -> requests.Response:
    """Issue a DELETE to ``<url>/<endpoint>`` and return the response.

    Raises ``requests.HTTPError`` if the server replies with an error status.
    """
    resp = self._session.delete(f"{self._url}/{endpoint}", params=params, timeout=timeout)
    resp.raise_for_status()
    return resp

Expand Down Expand Up @@ -206,15 +234,19 @@ def get_job_state(self, job: DbJob | JobResponseProperties) -> str:
elif isinstance(job, DbJob):
return job.state.current

def update_status_infos(self, job_id, job_info: DbJob | JobResponseProperties, state: str):
def update_status_infos(
self, job_id: int, job_info: DbJob | JobResponseProperties, state: str
):
if isinstance(job_info, JobResponseProperties):
try:
slots = int((job_info.tres_alloc_str.split(",")[1]).split("=")[1])
dispatch_time = job_info.start_time
time_to_dispatch = dispatch_time - job_info.submit_time
wall_time = job_info.end_time - dispatch_time
except Exception:
logging.error(f"Failed to get job submission time statistics for job {job_info}")
logging.error(
f"Failed to get job submission time statistics for job {job_info}"
)
raise
else:
try:
Expand All @@ -224,7 +256,9 @@ def update_status_infos(self, job_id, job_info: DbJob | JobResponseProperties, s
time_to_dispatch = dispatch_time - job_time.submission
wall_time = job_time.end - dispatch_time
except Exception:
logging.error(f"Failed to get job submission time statistics for job {job_info}")
logging.error(
f"Failed to get job submission time statistics for job {job_info}"
)
raise

self.status_infos[job_id].slots = slots
Expand Down Expand Up @@ -288,12 +322,16 @@ def _run_and_monitor(self, job_indices: List[int]) -> bool:
return self.get_success()

def _run_jobs(self, job_indices: List[int]) -> None:
logging.debug(f"Running jobs on cluster for jobscript {self.jobscript_path} and args {self.jobscript_args}")
logging.debug(
f"Running jobs on cluster for jobscript {self.jobscript_path} and args {self.jobscript_args}"
)
try:
self.status_infos = {}
for i in job_indices:
template = self.make_job_submission(i)
resp = self.submit_job(template)
if resp.job_id is None:
raise ValueError("Job submission failed", resp.errors)
self.status_infos[resp.job_id] = StatusInfo(
Path(template.job.standard_output),
i,
Expand Down Expand Up @@ -329,8 +367,12 @@ def make_job_submission(self, i: int, job=None, jobs=None) -> JobSubmission:
)
if output_fp and output_fp not in self.output_paths:
self.output_paths.append(Path(output_fp))
args = self.scheduler_mode.generate_args(i, self.memory, self.cores, self.jobscript_args, output_fp)
self.jobscript_command = " ".join([f"#!/bin/bash\n{self.jobscript_path}", *args])
args = self.scheduler_mode.generate_args(
i, self.memory, self.cores, self.jobscript_args, output_fp
)
self.jobscript_command = " ".join(
[f"#!/bin/bash\n{self.jobscript_path}", *args]
)
logging.info(f"creating template with command: {self.jobscript_command}")
job = JobProperties(
name=self.job_name,
Expand Down Expand Up @@ -419,7 +461,9 @@ def _wait_for_jobs(
# Termination takes some time, wait a max of 2 mins
self.wait_all_jobs_terminated(jobs_remaining, 120)
total_time += 120
logging.info(f"Jobs terminated = {len(jobs_remaining)} after {total_time}s")
logging.info(
f"Jobs terminated = {len(jobs_remaining)} after {total_time}s"
)
except Exception:
logging.error("Unknown error occurred running slurm job", exc_info=True)

Expand Down Expand Up @@ -482,9 +526,11 @@ def resubmit_jobs(self, job_indices: List[int]) -> bool:
logging.info(f"Resubmitting jobs with job_indices: {job_indices}")
return self._run_and_monitor(job_indices)

def filter_killed_jobs(self, jobs: Dict[str:StatusInfo]) -> Dict[str:StatusInfo]:
def filter_killed_jobs(self, jobs: "Dict[int, StatusInfo]") -> "Dict[int, StatusInfo]":
    """Return the subset of *jobs* whose current state is ``CANCELLED``.

    Keys are annotated as ``int`` (not ``str`` as before): ``self.status_infos``
    is declared ``Dict[int, StatusInfo]`` and is keyed by the integer
    ``resp.job_id`` at submission time, so mappings derived from it share that
    key type. Annotations are quoted as forward references so they need not be
    resolvable at definition time.
    """
    return {
        job_id: status_info
        for job_id, status_info in jobs.items()
        if status_info.current_state == "CANCELLED"
    }

Expand All @@ -496,11 +542,15 @@ def rerun_killed_jobs(self, allow_all_failed: bool = False):
return True
elif allow_all_failed or any(self.job_completion_status.values()):
failed_jobs = {
job_id: job_info for job_id, job_info in job_history[0].items() if job_info.final_state != "COMPLETED"
job_id: job_info
for job_id, job_info in job_history[0].items()
if job_info.final_state != "COMPLETED"
}
killed_jobs = self.filter_killed_jobs(failed_jobs)
killed_jobs_indices = [job.i for job in killed_jobs.values()]
logging.info(f"Total failed_jobs: {len(failed_jobs)}. Total killed_jobs: {len(killed_jobs)}")
logging.info(
f"Total failed_jobs: {len(failed_jobs)}. Total killed_jobs: {len(killed_jobs)}"
)
if killed_jobs_indices:
return self.resubmit_jobs(killed_jobs_indices)
return True
Expand Down
2 changes: 1 addition & 1 deletion ParProcCo/nxdata_aggregation_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def generate_output_paths(
def generate_args(
self,
i: int,
_memory: str,
_memory: int,
_cores: int,
jobscript_args: List[str],
output_fp: str,
Expand Down
4 changes: 2 additions & 2 deletions ParProcCo/passthru_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def generate_output_paths(
return str(output_dir) if output_dir else "", stdout_fp, stderr_fp

def generate_args(
self, i: int, memory: str, cores: int, jobscript_args: List[str], output_fp: str
self, i: int, memory: int, cores: int, jobscript_args: List[str], output_fp: str
) -> Tuple[str, ...]:
"""Overrides SchedulerModeInterface.generate_args"""
assert i < self.number_jobs
Expand All @@ -43,7 +43,7 @@ def generate_args(
check_location(get_absolute_path(jobscript_args[0]))
)
)
args = [jobscript, "--memory", memory, "--cores", str(cores)]
args = [jobscript, "--memory", str(memory), "--cores", str(cores)]
if output_fp:
args += ("--output", output_fp)
args += jobscript_args[1:]
Expand Down
2 changes: 1 addition & 1 deletion ParProcCo/scheduler_mode_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def generate_output_paths(
def generate_args(
self,
job_number: int,
memory: str,
memory: int,
cores: int,
jobscript_args: List[str],
output_fp: str,
Expand Down
4 changes: 2 additions & 2 deletions models/slurm_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ class JobProperties(BaseModel):
"""
Do not automatically terminate a job if one of the nodes it has been allocated fails.
"""
nodes: list[int] | None = Field(None, max_length=2, min_length=1)
nodes: list[int] | None = Field(default=None, max_length=2, min_length=1)
"""
Request that a minimum of minnodes nodes and a maximum node count.
"""
Expand Down Expand Up @@ -340,7 +340,7 @@ class JobProperties(BaseModel):
"""
Allocate resources for the job from the named reservation.
"""
signal: str | None = Field(None, pattern="(B:|)sig_num(@sig_time|)")
signal: str | None = Field(default=None, pattern="(B:|)sig_num(@sig_time|)")
"""
When a job is within sig_time seconds of its end time, send it the signal sig_num.
"""
Expand Down
16 changes: 12 additions & 4 deletions tests/test_job_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def tearDown(self):
os.rmdir(self.base_dir)

def test_create_job_scheduler(self) -> None:
with TemporaryDirectory(prefix="test_dir_", dir=self.base_dir) as working_directory:
with TemporaryDirectory(
prefix="test_dir_", dir=self.base_dir
) as working_directory:
cluster_output_dir = Path(working_directory) / "cluster_output_dir"
js = create_js(working_directory, cluster_output_dir)
self.assertTrue(
Expand All @@ -49,7 +51,9 @@ def test_create_job_scheduler(self) -> None:
)

def test_create_job_submission(self) -> None:
with TemporaryDirectory(prefix="test_dir_", dir=self.base_dir) as working_directory:
with TemporaryDirectory(
prefix="test_dir_", dir=self.base_dir
) as working_directory:
input_path = Path("path/to/file.extension")
cluster_output_dir = Path(working_directory) / "cluster_output_dir"
scheduler = create_js(working_directory, cluster_output_dir)
Expand Down Expand Up @@ -247,7 +251,9 @@ def test_job_times_out(self) -> None:
)
self.assertEqual(len(context.output), 8)
for warn_msg in context.output[:4]:
self.assertTrue(warn_msg.endswith(" timed out. Terminating job now."))
self.assertTrue(
warn_msg.endswith(" timed out. Terminating job now.")
)
for err_msg in context.output[4:]:
self.assertTrue("has not created output file" in err_msg)

Expand Down Expand Up @@ -369,7 +375,9 @@ def test_timestamp_ok_true(self, name, run_scheduler_last) -> None:
self.assertEqual(js.timestamp_ok(filepath), run_scheduler_last)

def test_get_jobs(self) -> None:
with TemporaryDirectory(prefix="test_dir_", dir=self.base_dir) as working_directory:
with TemporaryDirectory(
prefix="test_dir_", dir=self.base_dir
) as working_directory:
cluster_output_dir = Path(working_directory) / "cluster_output_dir"
js = create_js(working_directory, cluster_output_dir)
jobs = js.get_jobs()
Expand Down

0 comments on commit 0779051

Please sign in to comment.