Skip to content

Commit

Permalink
Fix type issues raised by pyright
Browse files Browse the repository at this point in the history
  • Loading branch information
PeterC-DLS committed Aug 31, 2023
1 parent 99ff6a6 commit 986a317
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 19 deletions.
4 changes: 2 additions & 2 deletions ParProcCo/job_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
self.working_directory = (
Path(working_directory)
if working_directory
else (self.cluster_output_dir if cluster_output_dir else Path.home())
else (self.cluster_output_dir if self.cluster_output_dir else Path.home())
)
self.resources: Dict[str, str] = {}
if cluster_resources:
Expand Down Expand Up @@ -104,7 +104,7 @@ def run(
self,
scheduler_mode: SchedulerModeInterface,
jobscript: Path,
job_env: Dict[str, str],
job_env: Optional[Dict[str, str]],
memory: str = "4G",
cores: int = 6,
jobscript_args: Optional[List] = None,
Expand Down
23 changes: 12 additions & 11 deletions ParProcCo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_filepath_on_path(filename: Optional[str]) -> Optional[Path]:
)
try:
filepath = next(path_gen)
return filepath
return Path(filepath)
except StopIteration:
raise FileNotFoundError(f"{filename} not found on PATH {paths}")

Expand All @@ -49,7 +49,7 @@ def check_location(location: Union[Path, str]) -> Path:
)


def format_timestamp(t: datetime.datetime) -> str:
def format_timestamp(t: datetime) -> str:
return t.strftime("%Y%m%d_%H%M")


Expand All @@ -58,15 +58,16 @@ def decode_to_string(any_string: Union[bytes, str]) -> str:
return output


def get_absolute_path(filename: Union[Path, str]) -> str:
p = Path(filename).resolve()
if p.is_file():
return str(p)
from shutil import which
def get_absolute_path(filename: Path | str | None) -> str:
if filename is not None:
p = Path(filename).resolve()
if p.is_file():
return str(p)
from shutil import which

f = which(filename)
if f:
return f
f = which(filename)
if f:
return str(f)
raise ValueError(f"{filename} not found")


Expand Down Expand Up @@ -199,7 +200,7 @@ def find_cfg_file(name: str) -> Path:
""" """
cp = os.getenv("PPC_CONFIG")
if cp:
return cp
return Path(cp)

cp = Path.home() / ("." + name)
if cp.is_file():
Expand Down
4 changes: 2 additions & 2 deletions example/simple_aggregation_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@


class SimpleAggregationMode(SchedulerModeInterface):
def __init__(self, program: str) -> None:
self.program_name = program
def __init__(self, program: Path) -> None:
self.program_name = program.name
self.cores = 1

def set_parameters(self, sliced_results: List) -> None:
Expand Down
2 changes: 1 addition & 1 deletion example/simple_processing_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class SimpleProcessingMode(SchedulerModeInterface):
def __init__(self, program: Optional[Path] = None) -> None:
self.program_name: Optional[str] = program
self.program_name: Optional[str] = None if program is None else program.name
self.cores = 1
self.allowed_modules = ("python",)

Expand Down
9 changes: 6 additions & 3 deletions tests/test_job_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_all_jobs_fail(self) -> None:
[str(runner_script.parent), self.starting_path]
)

wrapper = SimpleWrapper(runner_script.name, aggregation_script.name)
wrapper = SimpleWrapper(runner_script, aggregation_script)
wrapper.set_cores(6)
jc = JobController(
wrapper,
Expand Down Expand Up @@ -93,7 +93,7 @@ def test_end_to_end(self) -> None:
input_path = setup_data_file(working_directory)
runner_script_args = [jobscript.name, "--input-path", str(input_path)]

wrapper = SimpleWrapper(runner_script.name, aggregation_script.name)
wrapper = SimpleWrapper(runner_script, aggregation_script)
wrapper.set_cores(6)
jc = JobController(
wrapper,
Expand All @@ -104,12 +104,14 @@ def test_end_to_end(self) -> None:
)
jc.run(4, jobscript_args=runner_script_args)

assert jc.aggregated_result
with open(jc.aggregated_result, "r") as af:
agg_data = af.readlines()

self.assertEqual(
agg_data, ["0\n", "8\n", "2\n", "10\n", "4\n", "12\n", "6\n", "14\n"]
)
assert jc.sliced_results
for result in jc.sliced_results:
self.assertFalse(result.is_file())

Expand All @@ -134,7 +136,7 @@ def test_single_job_does_not_aggregate(self) -> None:
Path(working_directory) / cluster_output_name / "aggregated_results.txt"
)

wrapper = SimpleWrapper(runner_script.name, aggregation_script.name)
wrapper = SimpleWrapper(runner_script, aggregation_script)
wrapper.set_cores(6)
jc = JobController(
wrapper,
Expand All @@ -145,6 +147,7 @@ def test_single_job_does_not_aggregate(self) -> None:
)
jc.run(1, jobscript_args=runner_script_args)

assert jc.sliced_results
self.assertEqual(len(jc.sliced_results), 1)
self.assertFalse(aggregated_file.is_file())
self.assertTrue(jc.sliced_results[0].is_file())
Expand Down

0 comments on commit 986a317

Please sign in to comment.