Fix some bugs and type annotations from mypy
PeterC-DLS committed Jan 26, 2024
1 parent a31d6b5 commit 7a1af60
Showing 14 changed files with 122 additions and 182 deletions.
32 changes: 19 additions & 13 deletions ParProcCo/job_controller.py
@@ -61,18 +61,20 @@ def __init__(
self.user_name = user_name
self.user_token = user_token
self.timeout = timeout
self.sliced_results: list[Path] | None = None
self.sliced_results: tuple[Path, ...] | None = None
self.aggregated_result: Path | None = None

def run(
self,
number_jobs: int,
jobscript_args: list | None = None,
jobscript_args: list[str] | None = None,
job_name: str = "ParProcCo",
processing_job_resources: JobResources | None = None,
aggregation_job_resources: JobResources | None = None,
) -> None:
self.cluster_runner = self.program_wrapper.get_process_script()
if self.cluster_runner is None:
raise ValueError("Processing script must be defined")
if processing_job_resources is None:
processing_job_resources = JobResources()
if aggregation_job_resources is None:
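
The annotation change above, from `list[Path] | None` to `tuple[Path, ...] | None`, is a common mypy fix: `tuple[Path]` would mean a tuple of exactly one `Path`, while the `...` form means any length. A minimal standalone sketch of the distinction (illustrative, not repo code):

```python
from pathlib import Path

# A bare tuple[Path] annotation promises exactly one element;
# mypy rejects assignments of any other length.
single: tuple[Path] = (Path("slice_0.h5"),)

# tuple[Path, ...] is the homogeneous any-length form, matching a
# result set whose size depends on how many slices were run.
results: tuple[Path, ...] = (Path("slice_0.h5"), Path("slice_1.h5"))
```

A tuple also signals that the collected results are not meant to be mutated after the jobs finish.
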
@@ -85,11 +87,12 @@ def run(
logging.debug("Cluster environment is %s", self.cluster_env)

timestamp = datetime.now()
jobscript_args[0] = str(
check_jobscript_is_readable(
check_location(get_absolute_path(jobscript_args[0]))
if jobscript_args:
jobscript_args[0] = str(
check_jobscript_is_readable(
check_location(get_absolute_path(jobscript_args[0]))
)
)
)
sliced_jobs_success = self._submit_sliced_jobs(
number_jobs,
jobscript_args,
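
The new `if jobscript_args:` guard fixes two problems at once: mypy's complaint that `jobscript_args` may still be `None` here, and a latent `IndexError` on an empty list. A stripped-down sketch of the pattern, where `resolve_first_arg` is a hypothetical stand-in for the readability/location checks:

```python
from pathlib import Path

def resolve_first_arg(args: list[str] | None) -> None:
    # Truthiness narrows away both None and the empty list, so the
    # args[0] access below is safe and type-checks.
    if args:
        args[0] = str(Path(args[0]).expanduser().resolve())

script_args = ["jobscript.sh", "--flag"]
resolve_first_arg(script_args)  # script_args[0] is now absolute
resolve_first_arg(None)         # no-op instead of a TypeError
```
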
@@ -100,9 +103,10 @@

if sliced_jobs_success and self.sliced_results:
logging.info("Sliced jobs ran successfully.")
out_file: Path | None = None
if number_jobs == 1:
out_file = (
self.sliced_results[0] if len(self.sliced_results) > 0 else None
self.sliced_results[0] if self.sliced_results else None
)
else:
self._submit_aggregation_job(aggregation_job_resources, timestamp)
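
Pre-declaring `out_file: Path | None = None` before the branch gives the name one explicit type; otherwise mypy infers it from the first assignment it sees and can flag the branches as inconsistent. A small illustration of the same shape (hypothetical function, not repo code):

```python
from pathlib import Path

def pick_result(results: tuple[Path, ...], single_job: bool) -> Path | None:
    # One up-front declaration covers every branch, so mypy never
    # sees out_file as possibly undefined or as a narrower type.
    out_file: Path | None = None
    if single_job:
        out_file = results[0] if results else None
    return out_file

print(pick_result((Path("out.h5"),), True))  # out.h5
print(pick_result((), True))                 # None
```
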
@@ -128,27 +132,28 @@ def run(
def _submit_sliced_jobs(
self,
number_of_jobs: int,
jobscript_args: list | None,
jobscript_args: list[str] | None,
job_resources: JobResources,
job_name: str,
timestamp: datetime,
) -> bool:
if jobscript_args is None:
jobscript_args = []

assert self.cluster_runner
jsi = JobSchedulingInformation(
job_name=job_name,
job_script_path=self.cluster_runner,
job_resources=job_resources,
timeout=self.timeout,
job_script_arguments=jobscript_args,
job_env=self.cluster_env,
job_script_arguments=tuple(jobscript_args),
working_directory=self.working_directory,
output_dir=self.output_file.parent if self.output_file else None,
output_filename=self.output_file.name if self.output_file else None,
log_directory=self.cluster_output_dir,
timestamp=timestamp,
)
jsi.set_job_env(self.cluster_env)

job_scheduler = JobScheduler(
url=self.url,
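
The bare `assert self.cluster_runner` is there for the type checker: the attribute is `Path | None`, and `run()` has already raised if it was `None`, but mypy cannot see that guarantee across methods, so the assert narrows the type locally. A sketch of the idiom, with a hypothetical class mirroring the shape of `JobController` and assuming a `None` here would be a programming error:

```python
from pathlib import Path

class Runner:  # hypothetical stand-in for JobController
    def __init__(self) -> None:
        self.cluster_runner: Path | None = None

    def submit(self) -> Path:
        # Narrows Path | None to Path for the rest of the method;
        # the caller has already validated this, so it never fires.
        assert self.cluster_runner
        return self.cluster_runner

r = Runner()
r.cluster_runner = Path("/usr/bin/env")
print(r.submit())
```
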
@@ -186,19 +191,20 @@ def _submit_aggregation_job(
if aggregator_path is not None:
aggregation_args.append(str(aggregator_path))

assert self.sliced_results is not None and self.cluster_runner
jsi = JobSchedulingInformation(
job_name=aggregating_slicer.__class__.__name__,
job_script_path=self.cluster_runner,
job_resources=job_resources,
job_script_arguments=aggregation_args,
job_env=self.cluster_env,
job_script_arguments=tuple(aggregation_args),
working_directory=self.working_directory,
timeout=timedelta(seconds=AGGREGATION_TIME * len(self.sliced_results)),
output_dir=self.output_file.parent if self.output_file else None,
output_filename=self.output_file.name if self.output_file else None,
log_directory=self.cluster_output_dir,
timestamp=timestamp,
)
jsi.set_job_env(self.cluster_env)

aggregation_scheduler = JobScheduler(
url=self.url,
@@ -209,7 +215,7 @@
)

aggregation_jobs = aggregating_slicer.create_slice_jobs(
jsi, self.sliced_results
jsi, list(self.sliced_results)
)

aggregation_success = aggregation_scheduler.run(aggregation_jobs)
69 changes: 38 additions & 31 deletions ParProcCo/job_scheduler.py
@@ -3,7 +3,7 @@
import logging
import re
import time
from collections.abc import Sequence
from collections.abc import Sequence, ValuesView
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime, timedelta
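
`ValuesView` is imported because several signatures below are widened to accept it: `dict.values()` returns a `ValuesView`, which is a `Collection` but not a `Sequence`, so parameters annotated `Sequence[...]` reject it under mypy. A small demonstration with illustrative functions:

```python
from collections.abc import Sequence, ValuesView

def count_strict(jobs: Sequence[int]) -> int:
    return len(jobs)

def count(jobs: Sequence[int] | ValuesView[int]) -> int:
    return len(jobs)

running = {101: 1, 102: 2}
count(running.values())    # accepted
count([1, 2])              # accepted
# count_strict(running.values())  # mypy: ValuesView is not a Sequence
```

Annotating with `Iterable` or `Collection` would be a broader alternative; the union keeps the original `Sequence` contract visible in the signature.
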
@@ -80,38 +80,38 @@ class SLURMSTATE(Enum):
"Custom state. Output file has not been updated since job started."


class STATEGROUP(set, Enum):
OUTOFTIME = {SLURMSTATE.TIMEOUT, SLURMSTATE.DEADLINE}
FINISHED = {
class STATEGROUP(tuple[SLURMSTATE], Enum):
OUTOFTIME = (SLURMSTATE.TIMEOUT, SLURMSTATE.DEADLINE)
FINISHED = (
SLURMSTATE.COMPLETED,
SLURMSTATE.FAILED,
SLURMSTATE.TIMEOUT,
SLURMSTATE.DEADLINE,
}
COMPUTEISSUE = {
)
COMPUTEISSUE = (
SLURMSTATE.BOOT_FAIL,
SLURMSTATE.NODE_FAIL,
SLURMSTATE.OUT_OF_MEMORY,
}
ENDED = {
)
ENDED = (
SLURMSTATE.COMPLETED,
SLURMSTATE.FAILED,
SLURMSTATE.TIMEOUT,
SLURMSTATE.DEADLINE,
}
REQUEUEABLE = {
)
REQUEUEABLE = (
SLURMSTATE.CONFIGURING,
SLURMSTATE.RUNNING,
SLURMSTATE.STOPPED,
SLURMSTATE.SUSPENDED,
}
STARTING = {
)
STARTING = (
SLURMSTATE.PENDING,
SLURMSTATE.REQUEUED,
SLURMSTATE.RESIZING,
SLURMSTATE.SUSPENDED,
SLURMSTATE.CONFIGURING,
}
)
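
Swapping the `set` mix-in for `tuple[SLURMSTATE]` keeps `in` membership tests working while giving the enum an immutable, hashable value type that mypy can parameterize; `class STATEGROUP(set, Enum)` relied on a mutable base that type checkers handle poorly. A cut-down sketch of the same construction, with toy states rather than the full SLURM list:

```python
from enum import Enum

class STATE(Enum):
    COMPLETED = 0
    FAILED = 1
    TIMEOUT = 2

class GROUP(tuple, Enum):
    # Each member's value is a tuple of STATEs; enum special-cases a
    # tuple mix-in so the whole tuple becomes the member's value and
    # membership tests work unchanged.
    ENDED = (STATE.COMPLETED, STATE.FAILED, STATE.TIMEOUT)
    OUTOFTIME = (STATE.TIMEOUT,)

print(STATE.FAILED in GROUP.ENDED)         # True
print(STATE.COMPLETED in GROUP.OUTOFTIME)  # False
```
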


@dataclass
@@ -197,7 +197,7 @@ def fetch_and_update_state(
wall_time = None

status_info = job_scheduling_info.status_info

assert status_info
if submit_time:
# Don't overwrite unless a more specific value is given by the scheduler
status_info.submit_time = datetime.fromtimestamp(submit_time)
Expand All @@ -212,20 +212,21 @@ def fetch_and_update_state(
return slurm_state

def get_output_paths(
self, job_scheduling_info_list: list[JobSchedulingInformation]
) -> tuple[Path]:
return tuple(
self, job_scheduling_info_list: list[JobSchedulingInformation] | ValuesView[JobSchedulingInformation]
) -> tuple[Path, ...]:
return tuple(p for p in (
jsi.get_output_path()
for jsi in job_scheduling_info_list
if jsi.get_output_path() is not None
)
) if p)

def get_success(
self, job_scheduling_info_list: list[JobSchedulingInformation]
) -> bool:
return all((info.completion_status for info in job_scheduling_info_list))

def timestamp_ok(self, output: Path, start_time: datetime) -> bool:
def timestamp_ok(self, output: Path, start_time: datetime | None) -> bool:
if start_time is None:
return False
mod_time = datetime.fromtimestamp(output.stat().st_mtime)
return mod_time > start_time

@@ -268,6 +269,7 @@ def _submit_jobs(
f" and args {job_scheduling_info.job_script_arguments}"
)
submission = self.make_job_submission(job_scheduling_info)
assert submission.job is not None
resp = self.client.submit_job(submission)
if resp.job_id is None:
resp = self.client.submit_job(submission)
@@ -301,16 +303,19 @@ def make_job_submission(

error_dir = self.cluster_output_dir / "cluster_logs"
else:
assert job_scheduling_info.working_directory
error_dir = job_scheduling_info.working_directory / "cluster_logs"
job_scheduling_info.log_directory = error_dir

if not job_scheduling_info.log_directory.is_dir():
logging.debug(f"Making directory {job_scheduling_info.log_directory}")
job_scheduling_info.log_directory.mkdir(exist_ok=True, parents=True)
else:
assert job_scheduling_info.log_directory
logging.debug(
f"Directory {job_scheduling_info.log_directory} already exists"
)
assert job_scheduling_info.job_script_path
job_script_path = check_jobscript_is_readable(
job_scheduling_info.job_script_path
)
@@ -341,7 +346,7 @@ def make_job_submission(

def wait_all_jobs(
self,
job_scheduling_info_list: list[JobSchedulingInformation],
job_scheduling_info_list: Sequence[JobSchedulingInformation] | ValuesView[JobSchedulingInformation],
state_group: STATEGROUP,
deadline: datetime,
sleep_time: int,
@@ -384,7 +389,7 @@ def get_deadline(
)

def handle_not_started(
job_scheduling_info_list: list[JobSchedulingInformation],
job_scheduling_info_list: Sequence[JobSchedulingInformation] | ValuesView[JobSchedulingInformation],
check_time: timedelta,
) -> list[JobSchedulingInformation]:
# Wait for jobs to start (timeout shouldn't include queue time)
@@ -405,7 +410,7 @@ def handle_not_started(
return starting_jobs

def wait_for_ended(
job_scheduling_info_list: Sequence[JobSchedulingInformation],
job_scheduling_info_list: Sequence[JobSchedulingInformation] | ValuesView[JobSchedulingInformation],
deadline: datetime,
check_time: timedelta,
) -> list[JobSchedulingInformation]:
@@ -425,18 +430,19 @@ def wait_for_ended(
return ended_jobs

def handle_ended_jobs(
job_scheduling_info_list: Sequence[JobSchedulingInformation],
job_scheduling_info_list: Sequence[JobSchedulingInformation] | ValuesView[JobSchedulingInformation],
) -> list[JobSchedulingInformation]:
ended_jobs = []
for job_scheduling_info in job_scheduling_info_list:
self.fetch_and_update_state(job_scheduling_info)
assert job_scheduling_info.status_info
if job_scheduling_info.status_info.current_state in STATEGROUP.ENDED:
logging.debug("Removing ended %d", job_scheduling_info.job_id)
ended_jobs.append(job_scheduling_info)
return ended_jobs

def handle_timeouts(
job_scheduling_info_list: Sequence[JobSchedulingInformation],
job_scheduling_info_list: Sequence[JobSchedulingInformation] | ValuesView[JobSchedulingInformation],
) -> list[JobSchedulingInformation]:
deadlines = (
(jsi, get_deadline(jsi, allow_from_submission=False))
@@ -497,7 +503,7 @@ def handle_timeouts(
for deadline in (
get_deadline(jsi, allow_from_submission=True)
for jsi in running_jobs.values()
)
) if deadline is not None
]
)
check_time = min(
@@ -524,7 +530,7 @@
logging.debug("_wait_for_jobs loop ending, starting clear-up")

if terminate_after_wait:
for jsi in list(running_jobs):
for jsi in running_jobs.values():
try:
logging.info(
"Waiting for jobs timed out. Terminating job %d now.",
@@ -556,6 +562,7 @@ def _report_job_info(
for job_scheduling_info in job_scheduling_info_list:
job_id = job_scheduling_info.job_id
status_info = job_scheduling_info.status_info
assert status_info
stdout_path = job_scheduling_info.get_stdout_path()
logging.debug(f"Retrieving info for job {job_id}")

@@ -581,7 +588,7 @@

elif not self.timestamp_ok(
stdout_path,
start_time=job_scheduling_info.status_info.start_time,
start_time=status_info.start_time,
):
status_info.final_state = SLURMSTATE.OLD_OUTPUT_FILE
logging.error(
@@ -626,7 +633,7 @@ def resubmit_jobs(
new_job_scheduling_info = deepcopy(old_job_scheduling_info)
new_job_scheduling_info.set_completion_status(False)
new_job_scheduling_info.status_info = None
new_job_scheduling_info.job_id = None
new_job_scheduling_info.job_id = -1
new_job_scheduling_info_list.append(new_job_scheduling_info)
logging.info(f"Resubmitting jobs from batch {batch} with job_ids: {job_ids}")
return self._submit_and_monitor(new_job_scheduling_info_list)
@@ -637,7 +644,7 @@ def filter_killed_jobs(
return [
jsi
for jsi in job_scheduling_information_list
if jsi.status_info.current_state == SLURMSTATE.CANCELLED
if jsi.status_info and jsi.status_info.current_state == SLURMSTATE.CANCELLED
]
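
`status_info` is `StatusInfo | None`, so both this comprehension and the one in `resubmit_killed_jobs` below now test it before dereferencing; `and` short-circuits, which mypy understands as narrowing. The same guard in miniature, using toy dataclasses:

```python
from dataclasses import dataclass

@dataclass
class Status:
    state: str

@dataclass
class Job:
    status: Status | None = None

jobs = [Job(Status("CANCELLED")), Job(), Job(Status("COMPLETED"))]

# `j.status and ...` skips the None entry instead of raising
# AttributeError, and narrows the type for the comparison.
killed = [j for j in jobs if j.status and j.status.state == "CANCELLED"]
print(len(killed))  # 1
```
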

def resubmit_killed_jobs(
@@ -655,7 +662,7 @@ def resubmit_killed_jobs(
failed_jobs = [
jsi
for jsi in job_scheduling_info_dict.values()
if jsi.status_info.final_state != SLURMSTATE.COMPLETED
if jsi.status_info and jsi.status_info.final_state != SLURMSTATE.COMPLETED
]
killed_jobs = self.filter_killed_jobs(failed_jobs)
logging.info(
8 changes: 4 additions & 4 deletions ParProcCo/job_scheduling_information.py
@@ -25,10 +25,10 @@ class JobResources:
@dataclass
class JobSchedulingInformation:
job_name: str
job_script_path: Path | None
job_script_path: Path
job_resources: JobResources
timeout: timedelta = timedelta(hours=2)
job_script_arguments: tuple[str] = field(default_factory=tuple)
job_script_arguments: tuple[str, ...] = field(default_factory=tuple)
job_env: dict[str, str] = field(default_factory=dict)
log_directory: Path | None = None
stderr_filename: str | None = None
Expand All @@ -38,11 +38,11 @@ class JobSchedulingInformation:
output_filename: str | None = None
timestamp: datetime | None = None

def __post_init__(self):
def __post_init__(self) -> None:
self.set_job_script_path(self.job_script_path) # For validation
self.set_job_env(self.job_env) # For validation
# To be updated when submitted, not on creation
self.job_id: int | None = None
self.job_id: int = -1
self.status_info: StatusInfo | None = None
self.completion_status: bool = False

4 changes: 2 additions & 2 deletions ParProcCo/job_slicer_interface.py
@@ -14,13 +14,13 @@


class JobSlicerInterface:
def __init__(self, job_script: Path | None) -> None:
def __init__(self, job_script: Path | None = None) -> None:
if job_script is not None:
self.job_script = check_jobscript_is_readable(
check_location(get_absolute_path(job_script))
)
else:
self.job_script = "n/a"
self.job_script = Path("n/a")
self.allowed_modules: tuple[str, ...] | None = None
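
Falling back to `Path("n/a")` instead of the string `"n/a"` keeps `self.job_script` a `Path` in both branches; the old mix made the attribute `Path | str`, which breaks as soon as a caller touches `Path`-only API. In miniature, with a hypothetical function:

```python
from pathlib import Path

def job_script_or_placeholder(script: Path | None = None) -> Path:
    # Both branches yield Path, so callers can use .name, .parent,
    # etc. without isinstance checks.
    return script if script is not None else Path("n/a")

print(job_script_or_placeholder(Path("run.sh")).name)  # run.sh
print(job_script_or_placeholder().name)                # n/a
```
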

def create_slice_jobs(