Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type issues raised by pyright #68

Merged
merged 2 commits into from
Sep 1, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ParProcCo/job_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
self.working_directory = (
Path(working_directory)
if working_directory
else (self.cluster_output_dir if cluster_output_dir else Path.home())
else (self.cluster_output_dir if self.cluster_output_dir else Path.home())
)
self.resources: Dict[str, str] = {}
if cluster_resources:
Expand Down Expand Up @@ -104,7 +104,7 @@ def run(
self,
scheduler_mode: SchedulerModeInterface,
jobscript: Path,
job_env: Dict[str, str],
job_env: Optional[Dict[str, str]],
memory: str = "4G",
cores: int = 6,
jobscript_args: Optional[List] = None,
Expand Down
23 changes: 12 additions & 11 deletions ParProcCo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_filepath_on_path(filename: Optional[str]) -> Optional[Path]:
)
try:
filepath = next(path_gen)
return filepath
return Path(filepath)
except StopIteration:
raise FileNotFoundError(f"{filename} not found on PATH {paths}")

Expand All @@ -49,7 +49,7 @@ def check_location(location: Union[Path, str]) -> Path:
)


def format_timestamp(t: datetime.datetime) -> str:
def format_timestamp(t: datetime) -> str:
return t.strftime("%Y%m%d_%H%M")


Expand All @@ -58,15 +58,16 @@ def decode_to_string(any_string: Union[bytes, str]) -> str:
return output


def get_absolute_path(filename: Union[Path, str]) -> str:
p = Path(filename).resolve()
if p.is_file():
return str(p)
from shutil import which
def get_absolute_path(filename: Path | str | None) -> str:
if filename is not None:
p = Path(filename).resolve()
if p.is_file():
return str(p)
from shutil import which

f = which(filename)
if f:
return f
f = which(filename)
if f:
return str(f)
raise ValueError(f"{filename} not found")


Expand Down Expand Up @@ -199,7 +200,7 @@ def find_cfg_file(name: str) -> Path:
""" """
cp = os.getenv("PPC_CONFIG")
if cp:
return cp
return Path(cp)

cp = Path.home() / ("." + name)
if cp.is_file():
Expand Down
4 changes: 2 additions & 2 deletions example/simple_aggregation_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@


class SimpleAggregationMode(SchedulerModeInterface):
def __init__(self, program: str) -> None:
self.program_name = program
def __init__(self, program: Path) -> None:
self.program_name = program.name
self.cores = 1

def set_parameters(self, sliced_results: List) -> None:
Expand Down
2 changes: 1 addition & 1 deletion example/simple_processing_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class SimpleProcessingMode(SchedulerModeInterface):
def __init__(self, program: Optional[Path] = None) -> None:
self.program_name: Optional[str] = program
self.program_name: Optional[str] = None if program is None else program.name
self.cores = 1
self.allowed_modules = ("python",)

Expand Down
9 changes: 6 additions & 3 deletions tests/test_job_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_all_jobs_fail(self) -> None:
[str(runner_script.parent), self.starting_path]
)

wrapper = SimpleWrapper(runner_script.name, aggregation_script.name)
wrapper = SimpleWrapper(runner_script, aggregation_script)
wrapper.set_cores(6)
jc = JobController(
wrapper,
Expand Down Expand Up @@ -93,7 +93,7 @@ def test_end_to_end(self) -> None:
input_path = setup_data_file(working_directory)
runner_script_args = [jobscript.name, "--input-path", str(input_path)]

wrapper = SimpleWrapper(runner_script.name, aggregation_script.name)
wrapper = SimpleWrapper(runner_script, aggregation_script)
wrapper.set_cores(6)
jc = JobController(
wrapper,
Expand All @@ -104,12 +104,14 @@ def test_end_to_end(self) -> None:
)
jc.run(4, jobscript_args=runner_script_args)

assert jc.aggregated_result
with open(jc.aggregated_result, "r") as af:
agg_data = af.readlines()

self.assertEqual(
agg_data, ["0\n", "8\n", "2\n", "10\n", "4\n", "12\n", "6\n", "14\n"]
)
assert jc.sliced_results
for result in jc.sliced_results:
self.assertFalse(result.is_file())

Expand All @@ -134,7 +136,7 @@ def test_single_job_does_not_aggregate(self) -> None:
Path(working_directory) / cluster_output_name / "aggregated_results.txt"
)

wrapper = SimpleWrapper(runner_script.name, aggregation_script.name)
wrapper = SimpleWrapper(runner_script, aggregation_script)
wrapper.set_cores(6)
jc = JobController(
wrapper,
Expand All @@ -145,6 +147,7 @@ def test_single_job_does_not_aggregate(self) -> None:
)
jc.run(1, jobscript_args=runner_script_args)

assert jc.sliced_results
self.assertEqual(len(jc.sliced_results), 1)
self.assertFalse(aggregated_file.is_file())
self.assertTrue(jc.sliced_results[0].is_file())
Expand Down
Loading