Skip to content

Commit

Permalink
Add dask scheduler switch (#430)
Browse files Browse the repository at this point in the history
* Added initial scheduling update

* docstring

* Added scheduler option to pipeline.py

* Fix CPU count handling

* Initial find_sources testing

* PEP8

* Updated existing testing and added new testing for invalid scheduler failure

* Fixed test name + docstring

* Updated changelog

* Added pipeline update

* Added base_folder to test init

* Updated find_sources?
  • Loading branch information
ddobie authored May 18, 2023
1 parent 46f3ba6 commit 9b9bbda
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 20 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased](https://github.com/askap-vast/vast-tools/compare/v2.0.0...HEAD)

#### Added
- Added option to specify which dask scheduler to use in Query, PipeRun, and PipeAnalysis objects, and an argument to find_sources.py [#430](https://github.com/askap-vast/vast-tools/pull/430)
- Added access to epoch 32 [#429](https://github.com/askap-vast/vast-tools/pull/429)
- Added access to epoch 31 [#427](https://github.com/askap-vast/vast-tools/pull/427)
- Added access to epoch 30 [#419](https://github.com/askap-vast/vast-tools/pull/419)
Expand Down Expand Up @@ -79,6 +80,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

#### List of PRs

- [#430](https://github.com/askap-vast/vast-tools/pull/430): feat: Allow users to specify which dask scheduler to use for multi-processing
- [#421](https://github.com/askap-vast/vast-tools/pull/421): feat, fix, docs: Updated Query._get_epochs to exit nicely when no epochs available & to allow lists and ints to be passed.
- [#429](https://github.com/askap-vast/vast-tools/pull/429): feat: Added access to epoch 32
- [#427](https://github.com/askap-vast/vast-tools/pull/427): feat: Added access to epoch 31
Expand Down
42 changes: 40 additions & 2 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,41 @@ def test_init_failure_no_sources_in_footprint(
' are found in the VAST Pilot survey footprint!'
)

def test_init_failure_invalid_scheduler(
    self, mocker: MockerFixture
) -> None:
    """
    Tests that Query initialisation fails when an unrecognised dask
    scheduler option is requested.

    Args:
        mocker: The pytest-mock mocker object.

    Returns:
        None
    """
    # Pretend the base folder exists and data is available so the
    # constructor reaches the scheduler validation step.
    mocker.patch(
        'vasttools.query.os.path.isdir',
        return_value=True
    )
    mocker.patch(
        'vasttools.query.Query._check_data_availability',
        return_value=True
    )

    with pytest.raises(vtq.QueryInitError) as excinfo:
        vtq.Query(
            planets=['Mars'],
            scheduler='bad-option',
            base_folder='/testing/folder'
        )

    assert str(excinfo.value) == (
        "bad-option is not a suitable scheduler option. Please "
        "select from ['processes', 'single-threaded']"
    )

def test_init_settings(self, mocker: MockerFixture) -> None:
"""
Tests the initialisation of a Query object.
Expand Down Expand Up @@ -721,6 +756,7 @@ def test_init_settings(self, mocker: MockerFixture) -> None:
forced_cluster_threshold = 7.5
output_dir = '/output/here'
incl_observed = False
scheduler = 'processes'

expected_settings = {
'epochs': ["1", "2", "3x"],
Expand All @@ -737,7 +773,8 @@ def test_init_settings(self, mocker: MockerFixture) -> None:
'output_dir': output_dir,
'search_around': False,
'tiles': use_tiles,
'incl_observed': False
'incl_observed': False,
'scheduler': 'processes'
}

query = vtq.Query(
Expand All @@ -756,7 +793,8 @@ def test_init_settings(self, mocker: MockerFixture) -> None:
forced_allow_nan=forced_allow_nan,
forced_cluster_threshold=forced_cluster_threshold,
output_dir=output_dir,
incl_observed=incl_observed
incl_observed=incl_observed,
scheduler=scheduler
)

assert query.settings == expected_settings
Expand Down
10 changes: 9 additions & 1 deletion vasttools/bin/find_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,13 @@ def parse_args() -> argparse.Namespace:
'--clobber',
action="store_true",
help=("Overwrite the output directory if it already exists."))
parser.add_argument(
'--scheduler',
default='processes',
choices=['processes', 'single-threaded'],
help=("Dask scheduling option to use. Options are 'processes' "
"(parallel processing) or 'single-threaded'.")
)
parser.add_argument(
'--sort-output',
action="store_true",
Expand Down Expand Up @@ -508,7 +515,8 @@ def main() -> None:
forced_cluster_threshold=args.forced_cluster_threshold,
forced_allow_nan=args.forced_allow_nan,
incl_observed=args.find_fields,
corrected_data=not args.uncorrected_data
corrected_data=not args.uncorrected_data,
scheduler=args.scheduler,
)

if args.find_fields:
Expand Down
26 changes: 18 additions & 8 deletions vasttools/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def __init__(
measurements: Union[pd.DataFrame, vaex.dataframe.DataFrame],
measurement_pairs_file: List[str],
vaex_meas: bool = False,
n_workers: int = cpu_count() - 1
n_workers: int = HOST_NCPU - 1,
scheduler: str = 'processes'
) -> None:
"""
Constructor method.
Expand Down Expand Up @@ -135,6 +136,9 @@ def __init__(
loaded into a pandas DataFrame.
n_workers: Number of workers (cpus) available. Default is
determined by running `cpu_count()`.
scheduler: Dask scheduling option to use. Options are "processes"
    (parallel processing) or "single-threaded". Defaults to
    "processes".
Returns:
None
Expand All @@ -153,6 +157,7 @@ def __init__(
self.n_workers = n_workers
self._vaex_meas = vaex_meas
self._loaded_two_epoch_metrics = False
self.scheduler = scheduler

self.logger = logging.getLogger('vasttools.pipeline.PipeRun')
self.logger.debug('Created PipeRun instance')
Expand Down Expand Up @@ -652,7 +657,7 @@ def check_for_planets(self) -> pd.DataFrame:
match_planet_to_field,
meta=meta
).compute(
scheduler='processes',
scheduler=self.scheduler,
n_workers=self.n_workers
)
)
Expand Down Expand Up @@ -821,7 +826,8 @@ def __init__(
measurements: Union[pd.DataFrame, vaex.dataframe.DataFrame],
measurement_pairs_file: str,
vaex_meas: bool = False,
n_workers: int = cpu_count() - 1
n_workers: int = HOST_NCPU - 1,
scheduler: str = 'processes',
) -> None:
"""
Constructor method.
Expand Down Expand Up @@ -856,13 +862,17 @@ def __init__(
vaex from an arrow file. `False` means the measurements are
loaded into a pandas DataFrame.
n_workers: Number of workers (cpus) available.
scheduler: Dask scheduling option to use. Options are "processes"
    (parallel processing) or "single-threaded". Defaults to
    "processes".
Returns:
None
"""
super().__init__(
name, images, skyregions, relations, sources, associations,
bands, measurements, measurement_pairs_file, vaex_meas, n_workers
bands, measurements, measurement_pairs_file, vaex_meas, n_workers,
scheduler
)

def _filter_meas_pairs_df(
Expand Down Expand Up @@ -1127,13 +1137,13 @@ def recalc_sources_df(
}

sources_df_fluxes = (
dd.from_pandas(measurements_df_temp, HOST_NCPU)
dd.from_pandas(measurements_df_temp, self.n_workers)
.groupby('source')
.apply(
pipeline_get_variable_metrics,
meta=col_dtype
)
.compute(num_workers=HOST_NCPU - 1, scheduler='processes')
.compute(num_workers=self.n_workers, scheduler=self.scheduler)
)

# Switch to pandas at this point to perform join
Expand Down Expand Up @@ -2461,7 +2471,7 @@ def list_images(self) -> List[str]:

def load_runs(
self, run_names: List[str], name: Optional[str] = None,
n_workers: int = cpu_count() - 1
n_workers: int = HOST_NCPU - 1
) -> PipeAnalysis:
"""
Wrapper to load multiple runs in one command.
Expand Down Expand Up @@ -2493,7 +2503,7 @@ def load_runs(
return piperun

def load_run(
self, run_name: str, n_workers: int = cpu_count() - 1
self, run_name: str, n_workers: int = HOST_NCPU - 1
) -> PipeAnalysis:
"""
Process and load a pipeline run.
Expand Down
46 changes: 37 additions & 9 deletions vasttools/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def __init__(
forced_cluster_threshold: float = 1.5,
forced_allow_nan: bool = False,
incl_observed: bool = False,
corrected_data: bool = True
corrected_data: bool = True,
scheduler: str = 'processes',
) -> None:
"""
Constructor method.
Expand Down Expand Up @@ -177,6 +178,9 @@ def __init__(
fields, not querying data. Defaults to False.
corrected_data: Access the corrected data. Only relevant if
`tiles` is `True`. Defaults to `True`.
scheduler: Dask scheduling option to use. Options are "processes"
    (parallel processing) or "single-threaded". Defaults to
    "processes".
Returns:
None
Expand All @@ -195,6 +199,7 @@ def __init__(
QueryInitError: Base folder cannot be found.
QueryInitError: Base folder cannot be found.
QueryInitError: Problems found in query settings.
QueryInitError: Invalid scheduler option requested.
"""
self.logger = logging.getLogger('vasttools.find_sources.Query')

Expand Down Expand Up @@ -357,6 +362,14 @@ def __init__(

self.settings['output_dir'] = output_dir

scheduler_options = ['processes', 'single-threaded']
if scheduler not in scheduler_options:
raise QueryInitError(
f"{scheduler} is not a suitable scheduler option. Please "
f"select from {scheduler_options}"
)
self.settings['scheduler'] = scheduler

# Going to need this so load it now
self._epoch_fields = get_fields_per_epoch_info()

Expand Down Expand Up @@ -560,7 +573,9 @@ def _get_all_cutout_data(self, imsize: Angle) -> pd.DataFrame:
self._grouped_fetch_cutouts,
imsize=imsize,
meta=meta,
).compute(num_workers=self.ncpu, scheduler='processes')
).compute(num_workers=self.ncpu,
scheduler=self.settings['scheduler']
)
)

if not cutouts.empty:
Expand Down Expand Up @@ -1185,7 +1200,9 @@ def find_sources(self) -> None:
),
allow_nan=self.settings['forced_allow_nan'],
meta=meta,
).compute(num_workers=self.ncpu, scheduler='processes')
).compute(num_workers=self.ncpu,
scheduler=self.settings['scheduler']
)
)

if not f_results.empty:
Expand All @@ -1203,7 +1220,9 @@ def find_sources(self) -> None:
.apply(
self._get_components,
meta=self._get_selavy_meta(),
).compute(num_workers=self.ncpu, scheduler='processes')
).compute(num_workers=self.ncpu,
scheduler=self.settings['scheduler']
)
)

self.logger.debug("Selavy components succesfully added.")
Expand Down Expand Up @@ -1255,7 +1274,9 @@ def find_sources(self) -> None:
.apply(
self._init_sources,
meta=meta,
).compute(num_workers=npart, scheduler='processes')
).compute(num_workers=npart,
scheduler=self.settings['scheduler']
)
)
self.results = self.results.dropna()

Expand Down Expand Up @@ -1291,7 +1312,9 @@ def save_search_around_results(self, sort_output: bool = False) -> None:
self._write_search_around_results,
sort_output=sort_output,
meta=meta,
).compute(num_workers=self.ncpu, scheduler='processes')
).compute(num_workers=self.ncpu,
scheduler=self.settings['scheduler']
)
)

def _write_search_around_results(
Expand Down Expand Up @@ -2084,7 +2107,9 @@ def find_fields(self) -> None:
meta=meta,
axis=1,
result_type='expand'
).compute(num_workers=self.ncpu, scheduler='processes')
).compute(num_workers=self.ncpu,
scheduler=self.settings['scheduler']
)
)

self.logger.debug("Finished field matching.")
Expand Down Expand Up @@ -2377,7 +2402,9 @@ def _search_planets(self) -> pd.DataFrame:
.apply(
match_planet_to_field,
meta=meta,
).compute(num_workers=self.ncpu, scheduler='processes')
).compute(num_workers=self.ncpu,
scheduler=self.settings['scheduler']
)
)

results = results.reset_index(drop=True).drop(
Expand Down Expand Up @@ -2444,7 +2471,8 @@ def _build_catalog(self) -> pd.DataFrame:
)
if mask.any():
self.logger.warning(
f"Removing {sum(mask)} sources outside the requested survey footprint"
f"Removing {sum(mask)} sources outside the requested "
f"survey footprint."
)
self.coords = self.coords[~mask]
self.source_names = self.source_names[~mask]
Expand Down

0 comments on commit 9b9bbda

Please sign in to comment.