Fix type issues raised by pyright #68

Merged
merged 2 commits on Sep 1, 2023
18 changes: 11 additions & 7 deletions ParProcCo/job_scheduler.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import logging

import os
from dataclasses import dataclass
from datetime import datetime, timedelta
@@ -55,7 +54,7 @@ def __init__(
self.working_directory = (
Path(working_directory)
if working_directory
else (self.cluster_output_dir if cluster_output_dir else Path.home())
else (self.cluster_output_dir if self.cluster_output_dir else Path.home())
)
self.resources: Dict[str, str] = {}
if cluster_resources:
@@ -104,7 +103,7 @@ def run(
self,
scheduler_mode: SchedulerModeInterface,
jobscript: Path,
job_env: Dict[str, str],
job_env: Optional[Dict[str, str]],
memory: str = "4G",
cores: int = 6,
jobscript_args: Optional[List] = None,
@@ -340,7 +339,7 @@ def _report_job_info(self) -> None:
f" Dispatch time: {time_to_dispatch}; Wall time: {wall_time}."
)

self.job_history[self.batch_number][status_info.job.id] = status_info
self.job_history[self.batch_number][status_info.i] = status_info

def resubmit_jobs(self, job_indices: List[int]) -> bool:
self.batch_number += 1
@@ -349,8 +348,12 @@ def resubmit_jobs(self, job_indices: List[int]) -> bool:
logging.info(f"Resubmitting jobs with job_indices: {job_indices}")
return self._run_and_monitor(job_indices)

def filter_killed_jobs(self, jobs: List[drmaa2.Job]) -> List[drmaa2.Job]:
killed_jobs = [job for job in jobs if job.info.terminating_signal == "SIGKILL"]
def filter_killed_jobs(self, jobs: List[StatusInfo]) -> List[StatusInfo]:
killed_jobs = [
job
for job in jobs
if job.info is not None and job.info.terminating_signal == "SIGKILL"
]
return killed_jobs

def rerun_killed_jobs(self, allow_all_failed: bool = False):
@@ -368,7 +371,8 @@ def rerun_killed_jobs(self, allow_all_failed: bool = False):
killed_jobs = self.filter_killed_jobs(failed_jobs)
killed_jobs_indices = [job.i for job in killed_jobs]
logging.info(
f"Total failed_jobs: {len(failed_jobs)}. Total killed_jobs: {len(killed_jobs)}"
f"Total failed_jobs: {len(failed_jobs)}."
f" Total killed_jobs: {len(killed_jobs)}"
)
if killed_jobs_indices:
return self.resubmit_jobs(killed_jobs_indices)
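The `filter_killed_jobs` change above works because pyright narrows an `Optional` attribute inside an `is not None` guard. Below is a minimal, self-contained sketch of that pattern; `FakeJobInfo` and `FakeStatusInfo` are hypothetical stand-ins, not the real drmaa2 or ParProcCo classes.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeJobInfo:
    """Hypothetical stand-in for a drmaa2 job-info object."""

    terminating_signal: Optional[str] = None


@dataclass
class FakeStatusInfo:
    """Hypothetical stand-in for ParProcCo's StatusInfo."""

    i: int
    info: Optional[FakeJobInfo] = None


def filter_killed(jobs: List[FakeStatusInfo]) -> List[FakeStatusInfo]:
    # The `job.info is not None` guard narrows Optional[FakeJobInfo] to
    # FakeJobInfo, so the attribute access that follows type-checks cleanly.
    return [
        job
        for job in jobs
        if job.info is not None and job.info.terminating_signal == "SIGKILL"
    ]


if __name__ == "__main__":
    jobs = [
        FakeStatusInfo(0, FakeJobInfo("SIGKILL")),
        FakeStatusInfo(1, None),
        FakeStatusInfo(2, FakeJobInfo()),
    ]
    print([job.i for job in filter_killed(jobs)])  # prints [0]
```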
64 changes: 47 additions & 17 deletions ParProcCo/nxdata_aggregator.py
@@ -58,6 +58,17 @@ def _renormalise(self, data_files: List[Path]) -> None:
f"Aggregation completed in {aggregation_time.total_seconds():.3f}s. Sliced file paths: {data_files}."
)

@staticmethod
def _require_dataset(
group: h5py.File | h5py.Group | h5py.AttributeManager, name: str
) -> h5py.Dataset | np.ndarray:
if name in group:
data = group[name]
if isinstance(data, (h5py.Dataset, np.ndarray)):
return data
raise ValueError(f"{name} in {group} must be a dataset (is a {type(data)})")
raise ValueError(f"{name} is not found in {group}")

def _initialise_arrays(self) -> None:
self._get_all_axes()

@@ -138,17 +149,20 @@ def _get_nxdata(self):
if not self.is_binoculars:
self.is_binoculars = "binoculars" in root
self.nxentry_name = self._get_default_nxgroup(root, "NXentry")
nxentry = root[self.nxentry_name]
nxentry = root.require_group(self.nxentry_name)
self.nxdata_name = self._get_default_nxgroup(nxentry, "NXdata")
self.nxdata_path_name = "/".join([self.nxentry_name, self.nxdata_name])
nxdata = root[self.nxdata_path_name]
nxdata = root.require_group(self.nxdata_path_name)
self._get_default_signals_and_axes(nxdata)

signal_shape = nxdata[self.signal_name].shape
signal = nxdata[self.signal_name]
assert isinstance(signal, h5py.Dataset)
signal_shape = signal.shape
self.data_dimensions = len(signal_shape)

if self.renormalisation:
weights = nxdata["weight"]
assert isinstance(weights, h5py.Dataset)
assert (
len(weights.shape) == self.data_dimensions
), "signal and weight dimensions must match"
@@ -161,6 +175,7 @@ def _get_default_nxgroup(
) -> str:
if "default" in f.attrs:
group_name = f.attrs["default"]
assert isinstance(group_name, (str, bytes)) # XXX
group_name = decode_to_string(group_name)
class_type = f[group_name].attrs.get("NX_class", "")
class_type = decode_to_string(class_type)
@@ -196,7 +211,10 @@ def _get_default_signals_and_axes(self, nxdata: h5py.Group) -> None:

if "auxiliary_signals" in nxdata.attrs:
self.aux_signal_names = [
decode_to_string(name) for name in nxdata.attrs["auxiliary_signals"]
decode_to_string(name)
for name in NXdataAggregator._require_dataset(
nxdata.attrs, "auxiliary_signals"
)
]
self.non_weight_aux_signal_names = [
name for name in self.aux_signal_names if name != "weight"
@@ -209,14 +227,16 @@ def _get_default_signals_and_axes(self, nxdata: h5py.Group) -> None:

if "signal" in nxdata.attrs:
signal_name = nxdata.attrs["signal"]
assert isinstance(signal_name, (str, bytes)) # XXX
self.signal_name = decode_to_string(signal_name)
elif "data" in nxdata.keys():
self.signal_name = "data"

if hasattr(self, "signal_name"):
if "axes" in nxdata.attrs:
self.axes_names = [
decode_to_string(name) for name in nxdata.attrs["axes"]
decode_to_string(name)
for name in NXdataAggregator._require_dataset(nxdata.attrs, "axes")
]
else:
self._generate_axes_names(nxdata)
@@ -225,7 +245,7 @@ def _get_default_signals_and_axes(self, nxdata: h5py.Group) -> None:

def _generate_axes_names(self, nxdata: h5py.Group) -> None:
self.use_default_axes = True
signal_shape = nxdata[self.signal_name].shape
signal_shape = NXdataAggregator._require_dataset(nxdata, self.signal_name).shape
self.axes_names = [
f"{letter}-axis" for letter in string.ascii_lowercase[: len(signal_shape)]
]
@@ -235,17 +255,20 @@ def _get_all_axes(self) -> None:
self.all_axes = []
for data_file in self.data_files:
with h5py.File(data_file, "r") as f:
signal_shape = f[self.nxdata_path_name][self.signal_name].shape
nxdata = f.require_group(self.nxdata_path_name)
signal = nxdata[self.signal_name]
assert isinstance(signal, h5py.Dataset)
signal_shape = signal.shape
logging.info(
f"Signal '{'/'.join([self.nxdata_path_name, self.signal_name])}' read from {data_file}. Shape: {signal_shape}"
)
assert len(signal_shape) == self.data_dimensions
self.signal_shapes.append(signal_shape)
if self.aux_signal_names:
for aux_signal_name in self.aux_signal_names:
aux_signal_shape = f[self.nxdata_path_name][
aux_signal_name
].shape
aux_signal_shape = NXdataAggregator._require_dataset(
nxdata, aux_signal_name
).shape
logging.debug(
f"Auxiliary signal '{'/'.join([self.nxdata_path_name, aux_signal_name])}' read from {data_file}. Shape: {aux_signal_shape}"
)
@@ -256,7 +279,7 @@ def _get_all_axes(self) -> None:
axes = [np.arange(length) for length in signal_shape]
else:
axes = [
f[self.nxdata_path_name][axis_name][...]
NXdataAggregator._require_dataset(nxdata, axis_name)[...]
for axis_name in self.axes_names
]
self.all_axes.append(axes)
@@ -271,21 +294,27 @@ def _accumulate_volumes(self) -> None:
f"Accumulating volume with shape {self.accumulator_volume.shape} and axes {self.axes_names}"
)
for data_file, slices in zip(self.data_files, self.all_slices):
weights = None
with h5py.File(data_file, "r") as f:
aux_signals = []
volume = f[self.nxdata_path_name][self.signal_name][...]
nxdata = f.require_group(self.nxdata_path_name)
volume = NXdataAggregator._require_dataset(nxdata, self.signal_name)[
...
]
logging.debug(
f"Reading volume from {'/'.join([self.nxdata_path_name, self.signal_name])} in {data_file}. Shape is {volume.shape}"
)
if self.renormalisation:
weights = f[self.nxdata_path_name]["weight"][...]
weights = NXdataAggregator._require_dataset(nxdata, "weight")[...]
for name in self.non_weight_aux_signal_names:
aux_signals.append(f[self.nxdata_path_name][name][...])
aux_signals.append(
NXdataAggregator._require_dataset(nxdata, name)[...]
)
logging.debug(
f"Reading auxiliary signal from {'/'.join([self.nxdata_path_name, name])} in {data_file}"
)

if self.renormalisation:
if self.renormalisation and weights is not None:
volume = np.multiply(volume, weights)
self.accumulator_weights[slices] += weights
aux_signals = [
@@ -351,9 +380,10 @@ def _write_aggregation_file(self, aggregation_output: Path) -> Path:

f.attrs["default"] = self.nxentry_name

old_processed = None
for i, filepath in enumerate(self.data_files):
with h5py.File(filepath, "r") as df:
data_nxentry_group = df[self.nxentry_name]
data_nxentry_group = df.require_group(self.nxentry_name)
group_name = self._get_group_name(data_nxentry_group, "NXprocess")
for j, name in enumerate(group_name):
if "old_processed" not in f:
@@ -366,7 +396,7 @@ def _write_aggregation_file(self, aggregation_output: Path) -> Path:
name, old_processed, name=f"process{i}.{j}"
)
logging.info(
f"Copied '{'/'.join([data_nxentry_group.name, name])}' group in {filepath} to"
f"Copied '{'/'.join([data_nxentry_group.name, name])}' group in {filepath} to" # pyright: ignore[reportGeneralTypeIssues]
f" '{'/'.join(['old_processed', f'process{i}.{j}'])}' group in {aggregation_output}"
)

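Most of the nxdata_aggregator.py changes funnel h5py lookups through the new `_require_dataset` helper, which narrows the union type returned by subscripting a group before `.shape` or `[...]` is used. The sketch below shows the same isinstance-narrowing idea in isolation; the helper name, file path, and group layout (`entry/data`, `signal`) are illustrative assumptions, not taken from the PR.

```python
from __future__ import annotations

import tempfile
from pathlib import Path

import h5py
import numpy as np


def require_dataset(group: h5py.File | h5py.Group, name: str) -> h5py.Dataset:
    """Return group[name] only if it is a dataset; raise otherwise.

    Subscripting an h5py group is typed as returning a union of node types,
    so an isinstance check narrows the result before .shape or [...] is used.
    """
    if name not in group:
        raise ValueError(f"{name} is not found in {group}")
    data = group[name]
    if not isinstance(data, h5py.Dataset):
        raise ValueError(f"{name} in {group} must be a dataset (is a {type(data)})")
    return data


if __name__ == "__main__":
    # Illustrative file layout only; not the structure used by NXdataAggregator.
    path = Path(tempfile.mkdtemp()) / "example.h5"
    with h5py.File(path, "w") as f:
        nxdata = f.create_group("entry/data")
        nxdata.create_dataset("signal", data=np.arange(12).reshape(3, 4))
    with h5py.File(path, "r") as f:
        signal = require_dataset(f.require_group("entry/data"), "signal")
        print(signal.shape)  # (3, 4)
```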
4 changes: 2 additions & 2 deletions ParProcCo/simple_data_slicer.py
@@ -13,10 +13,10 @@ def slice(
self, number_jobs: int, stop: Optional[int] = None
) -> List[Optional[slice]]:
"""Overrides SlicerInterface.slice"""
if type(number_jobs) is not int:
if not isinstance(number_jobs, int):
raise TypeError(f"number_jobs is {type(number_jobs)}, should be int\n")

if (stop is not None) and (type(stop) is not int):
if (stop is not None) and not isinstance(stop, int):
raise TypeError(f"stop is {type(stop)}, should be int or None\n")

if stop:
23 changes: 12 additions & 11 deletions ParProcCo/utils.py
@@ -34,7 +34,7 @@ def get_filepath_on_path(filename: Optional[str]) -> Optional[Path]:
)
try:
filepath = next(path_gen)
return filepath
return Path(filepath)
except StopIteration:
raise FileNotFoundError(f"{filename} not found on PATH {paths}")

@@ -49,7 +49,7 @@ def check_location(location: Union[Path, str]) -> Path:
)


def format_timestamp(t: datetime.datetime) -> str:
def format_timestamp(t: datetime) -> str:
return t.strftime("%Y%m%d_%H%M")


@@ -58,15 +58,16 @@ def decode_to_string(any_string: Union[bytes, str]) -> str:
return output


def get_absolute_path(filename: Union[Path, str]) -> str:
p = Path(filename).resolve()
if p.is_file():
return str(p)
from shutil import which
def get_absolute_path(filename: Path | str | None) -> str:
if filename is not None:
p = Path(filename).resolve()
if p.is_file():
return str(p)
from shutil import which

f = which(filename)
if f:
return f
f = which(filename)
if f:
return str(f)
raise ValueError(f"{filename} not found")


@@ -199,7 +200,7 @@ def find_cfg_file(name: str) -> Path:
""" """
cp = os.getenv("PPC_CONFIG")
if cp:
return cp
return Path(cp)

cp = Path.home() / ("." + name)
if cp.is_file():
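The utils.py edits mostly take values whose inferred type is `Optional[str]` (for example the result of `shutil.which` or `os.getenv`) and either wrap them in `Path(...)`/`str(...)` or handle the `None` case explicitly before returning. A small sketch of that pattern follows, using a hypothetical `find_executable` helper rather than the module's own functions.

```python
from __future__ import annotations

from pathlib import Path
from shutil import which


def find_executable(name: str) -> Path:
    """Hypothetical helper showing why the PR handles the None case.

    shutil.which returns Optional[str], so pyright flags returning it from a
    function annotated to return Path unless None is ruled out first.
    """
    found = which(name)
    if found is None:
        raise ValueError(f"{name} not found on PATH")
    return Path(found)


if __name__ == "__main__":
    print(find_executable("python3"))
```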
4 changes: 2 additions & 2 deletions example/simple_aggregation_mode.py
@@ -14,8 +14,8 @@


class SimpleAggregationMode(SchedulerModeInterface):
def __init__(self, program: str) -> None:
self.program_name = program
def __init__(self, program: Path) -> None:
self.program_name = program.name
self.cores = 1

def set_parameters(self, sliced_results: List) -> None:
2 changes: 1 addition & 1 deletion example/simple_processing_mode.py
@@ -16,7 +16,7 @@

class SimpleProcessingMode(SchedulerModeInterface):
def __init__(self, program: Optional[Path] = None) -> None:
self.program_name: Optional[str] = program
self.program_name: Optional[str] = None if program is None else program.name
self.cores = 1
self.allowed_modules = ("python",)

9 changes: 6 additions & 3 deletions tests/test_job_controller.py
@@ -61,7 +61,7 @@ def test_all_jobs_fail(self) -> None:
[str(runner_script.parent), self.starting_path]
)

wrapper = SimpleWrapper(runner_script.name, aggregation_script.name)
wrapper = SimpleWrapper(runner_script, aggregation_script)
wrapper.set_cores(6)
jc = JobController(
wrapper,
@@ -93,7 +93,7 @@ def test_end_to_end(self) -> None:
input_path = setup_data_file(working_directory)
runner_script_args = [jobscript.name, "--input-path", str(input_path)]

wrapper = SimpleWrapper(runner_script.name, aggregation_script.name)
wrapper = SimpleWrapper(runner_script, aggregation_script)
wrapper.set_cores(6)
jc = JobController(
wrapper,
@@ -104,12 +104,14 @@ def test_end_to_end(self) -> None:
)
jc.run(4, jobscript_args=runner_script_args)

assert jc.aggregated_result
with open(jc.aggregated_result, "r") as af:
agg_data = af.readlines()

self.assertEqual(
agg_data, ["0\n", "8\n", "2\n", "10\n", "4\n", "12\n", "6\n", "14\n"]
)
assert jc.sliced_results
for result in jc.sliced_results:
self.assertFalse(result.is_file())

@@ -134,7 +136,7 @@ def test_single_job_does_not_aggregate(self) -> None:
Path(working_directory) / cluster_output_name / "aggregated_results.txt"
)

wrapper = SimpleWrapper(runner_script.name, aggregation_script.name)
wrapper = SimpleWrapper(runner_script, aggregation_script)
wrapper.set_cores(6)
jc = JobController(
wrapper,
@@ -145,6 +147,7 @@ def test_single_job_does_not_aggregate(self) -> None:
)
jc.run(1, jobscript_args=runner_script_args)

assert jc.sliced_results
self.assertEqual(len(jc.sliced_results), 1)
self.assertFalse(aggregated_file.is_file())
self.assertTrue(jc.sliced_results[0].is_file())
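The test changes add bare `assert` statements so that pyright treats `jc.aggregated_result` and `jc.sliced_results` as non-`None` in the code that opens or indexes them. A minimal sketch of that narrowing, with a hypothetical stand-in class rather than the real JobController:

```python
from __future__ import annotations

from typing import List, Optional


class FakeController:
    """Hypothetical stand-in with optional result attributes."""

    def __init__(self) -> None:
        self.aggregated_result: Optional[str] = None
        self.sliced_results: Optional[List[str]] = None

    def run(self) -> None:
        self.aggregated_result = "aggregated_results.txt"
        self.sliced_results = ["slice_0.txt"]


if __name__ == "__main__":
    jc = FakeController()
    jc.run()
    # Without these asserts, pyright reports that the attributes may still be
    # None when they are used below.
    assert jc.aggregated_result
    assert jc.sliced_results
    print(jc.aggregated_result, jc.sliced_results[0])
```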