Skip to content

Commit

Permalink
Refactor job scheduler and job controller
Browse files: browse the repository at this point in the history
  • Loading branch information
VictoriaBeilsten-Edmands committed Sep 6, 2023
1 parent 0779051 commit 4556b5b
Show file tree
Hide file tree
Showing 7 changed files with 257 additions and 241 deletions.
12 changes: 8 additions & 4 deletions ParProcCo/job_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .job_scheduler import JobScheduler
from .slicer_interface import SlicerInterface
from .utils import check_location, get_absolute_path
from .utils import check_location, get_absolute_path, get_slurm_token, get_user
from .program_wrapper import ProgramWrapper


Expand All @@ -18,14 +18,16 @@ def __init__(
url: str,
program_wrapper: ProgramWrapper,
output_dir_or_file: Path,
partition: str,
version: str = "v0.0.38",
user_name: Optional[str] = None,
user_token: Optional[str] = None,
timeout: timedelta = timedelta(hours=2),
):
) -> None:
"""JobController is used to coordinate cluster job submissions with JobScheduler"""
self.url = url
self.program_wrapper = program_wrapper
self.partition = partition
self.output_file: Optional[Path] = None
self.cluster_output_dir: Optional[Path] = None

Expand Down Expand Up @@ -54,8 +56,8 @@ def __init__(
logging.debug("JC working dir: %s", self.working_directory)
self.data_slicer: SlicerInterface
self.version = version
self.user_name = user_name if user_name else os.environ["USER"]
self.user_token = user_token if user_token else os.environ["SLURM_JWT"]
self.user_name = user_name if user_name else get_user()
self.user_token = user_token if user_token else get_slurm_token()
self.timeout = timeout
self.sliced_results: Optional[List[Path]] = None
self.aggregated_result: Optional[Path] = None
Expand Down Expand Up @@ -118,6 +120,7 @@ def _run_sliced_jobs(
self.url,
self.working_directory,
self.cluster_output_dir,
self.partition,
self.timeout,
self.version,
self.user_name,
Expand Down Expand Up @@ -158,6 +161,7 @@ def _run_aggregation_job(self, memory: int) -> None:
self.url,
self.working_directory,
self.cluster_output_dir,
self.partition,
self.timeout,
self.version,
self.user_name,
Expand Down
Loading

0 comments on commit 4556b5b

Please sign in to comment.