diff --git a/ctapipe/tools/train_disp_reconstructor.py b/ctapipe/tools/train_disp_reconstructor.py index 80352a0bb29..a188b4260d0 100644 --- a/ctapipe/tools/train_disp_reconstructor.py +++ b/ctapipe/tools/train_disp_reconstructor.py @@ -4,13 +4,13 @@ import astropy.units as u import numpy as np -from ctapipe.containers import CoordinateFrameType from ctapipe.core import Tool from ctapipe.core.traits import Bool, Int, IntTelescopeParameter, Path -from ctapipe.exceptions import TooFewEvents from ctapipe.io import TableLoader from ctapipe.reco import CrossValidator, DispReconstructor -from ctapipe.reco.preprocessing import check_valid_rows, horizontal_to_telescope +from ctapipe.reco.preprocessing import horizontal_to_telescope + +from .utils import read_training_events __all__ = [ "TrainDispReconstructor", @@ -56,6 +56,12 @@ class TrainDispReconstructor(Tool): ), ).tag(config=True) + chunk_size = Int( + default_value=100000, + allow_none=True, + help="How many subarray events to load at once before training on n_events.", + ).tag(config=True) + random_seed = Int( default_value=0, help="Random seed for sampling and cross validation" ).tag(config=True) @@ -111,7 +117,28 @@ def start(self): self.log.info("Training models for %d types", len(types)) for tel_type in types: self.log.info("Loading events for %s", tel_type) - table = self._read_table(tel_type) + feature_names = self.models.features + [ + "true_energy", + "subarray_pointing_lat", + "subarray_pointing_lon", + "true_alt", + "true_az", + "hillas_fov_lat", + "hillas_fov_lon", + "hillas_psi", + ] + table = read_training_events( + loader=self.loader, + chunk_size=self.chunk_size, + telescope_type=tel_type, + reconstructor=self.models, + feature_names=feature_names, + rng=self.rng, + log=self.log, + n_events=self.n_events.tel[tel_type], + ) + table[self.models.target] = self._get_true_disp(table) + table = table[self.models.features + [self.models.target, "true_energy"]] self.log.info("Train models on %s events", 
len(table)) self.cross_validate(tel_type, table) @@ -120,58 +147,6 @@ def start(self): self.models.fit(tel_type, table) self.log.info("done") - def _read_table(self, telescope_type): - table = self.loader.read_telescope_events([telescope_type]) - self.log.info("Events read from input: %d", len(table)) - if len(table) == 0: - raise TooFewEvents( - f"Input file does not contain any events for telescope type {telescope_type}" - ) - - mask = self.models.quality_query.get_table_mask(table) - table = table[mask] - self.log.info("Events after applying quality query: %d", len(table)) - if len(table) == 0: - raise TooFewEvents( - f"No events after quality query for telescope type {telescope_type}" - ) - - if not np.all( - table["subarray_pointing_frame"] == CoordinateFrameType.ALTAZ.value - ): - raise ValueError( - "Pointing information for training data has to be provided in horizontal coordinates" - ) - - table = self.models.feature_generator(table, subarray=self.loader.subarray) - - table[self.models.target] = self._get_true_disp(table) - - # Add true energy for energy-dependent performance plots - columns = self.models.features + [self.models.target, "true_energy"] - table = table[columns] - - valid = check_valid_rows(table) - if np.any(~valid): - self.log.warning("Dropping non-predicable events.") - table = table[valid] - - n_events = self.n_events.tel[telescope_type] - if n_events is not None: - if n_events > len(table): - self.log.warning( - "Number of events in table (%d) is less than requested number of events %d", - len(table), - n_events, - ) - else: - self.log.info("Sampling %d events", n_events) - idx = self.rng.choice(len(table), n_events, replace=False) - idx.sort() - table = table[idx] - - return table - def _get_true_disp(self, table): fov_lon, fov_lat = horizontal_to_telescope( alt=table["true_alt"], diff --git a/ctapipe/tools/train_energy_regressor.py b/ctapipe/tools/train_energy_regressor.py index 89647f6c39b..2c9d411f594 100644 --- 
a/ctapipe/tools/train_energy_regressor.py +++ b/ctapipe/tools/train_energy_regressor.py @@ -5,10 +5,10 @@ from ctapipe.core import Tool from ctapipe.core.traits import Int, IntTelescopeParameter, Path -from ctapipe.exceptions import TooFewEvents from ctapipe.io import TableLoader from ctapipe.reco import CrossValidator, EnergyRegressor -from ctapipe.reco.preprocessing import check_valid_rows + +from .utils import read_training_events __all__ = [ "TrainEnergyRegressor", @@ -53,6 +53,12 @@ class TrainEnergyRegressor(Tool): ), ).tag(config=True) + chunk_size = Int( + default_value=100000, + allow_none=True, + help="How many subarray events to load at once before training on n_events.", + ).tag(config=True) + random_seed = Int( default_value=0, help="Random seed for sampling and cross validation" ).tag(config=True) @@ -61,6 +67,7 @@ class TrainEnergyRegressor(Tool): ("i", "input"): "TableLoader.input_url", ("o", "output"): "TrainEnergyRegressor.output_path", "n-events": "TrainEnergyRegressor.n_events", + "chunk-size": "TrainEnergyRegressor.chunk_size", "cv-output": "CrossValidator.output_path", } @@ -103,7 +110,17 @@ def start(self): self.log.info("Training models for %d types", len(types)) for tel_type in types: self.log.info("Loading events for %s", tel_type) - table = self._read_table(tel_type) + feature_names = self.regressor.features + [self.regressor.target] + table = read_training_events( + loader=self.loader, + chunk_size=self.chunk_size, + telescope_type=tel_type, + reconstructor=self.regressor, + feature_names=feature_names, + rng=self.rng, + log=self.log, + n_events=self.n_events.tel[tel_type], + ) self.log.info("Train on %s events", len(table)) self.cross_validate(tel_type, table) @@ -112,48 +129,6 @@ def start(self): self.regressor.fit(tel_type, table) self.log.info("done") - def _read_table(self, telescope_type): - table = self.loader.read_telescope_events([telescope_type]) - self.log.info("Events read from input: %d", len(table)) - if len(table) == 0: - 
raise TooFewEvents( - f"Input file does not contain any events for telescope type {telescope_type}" - ) - - mask = self.regressor.quality_query.get_table_mask(table) - table = table[mask] - self.log.info("Events after applying quality query: %d", len(table)) - if len(table) == 0: - raise TooFewEvents( - f"No events after quality query for telescope type {telescope_type}" - ) - - table = self.regressor.feature_generator(table, subarray=self.loader.subarray) - - feature_names = self.regressor.features + [self.regressor.target] - table = table[feature_names] - - valid = check_valid_rows(table) - if np.any(~valid): - self.log.warning("Dropping non-predictable events.") - table = table[valid] - - n_events = self.n_events.tel[telescope_type] - if n_events is not None: - if n_events > len(table): - self.log.warning( - "Number of events in table (%d) is less than requested number of events %d", - len(table), - n_events, - ) - else: - self.log.info("Sampling %d events", n_events) - idx = self.rng.choice(len(table), n_events, replace=False) - idx.sort() - table = table[idx] - - return table - def finish(self): """ Write-out trained models and cross-validation results. 
diff --git a/ctapipe/tools/train_particle_classifier.py b/ctapipe/tools/train_particle_classifier.py index b8511c6fd1e..70c53b63419 100644 --- a/ctapipe/tools/train_particle_classifier.py +++ b/ctapipe/tools/train_particle_classifier.py @@ -6,10 +6,10 @@ from ctapipe.core.tool import Tool from ctapipe.core.traits import Int, IntTelescopeParameter, Path -from ctapipe.exceptions import TooFewEvents from ctapipe.io import TableLoader from ctapipe.reco import CrossValidator, ParticleClassifier -from ctapipe.reco.preprocessing import check_valid_rows + +from .utils import read_training_events __all__ = [ "TrainParticleClassifier", @@ -78,6 +78,15 @@ class TrainParticleClassifier(Tool): ), ).tag(config=True) + chunk_size = Int( + default_value=100000, + allow_none=True, + help=( + "How many subarray events to load at once before training on" + " n_signal and n_background events." + ), + ).tag(config=True) + random_seed = Int( default_value=0, help="Random number seed for sampling and the cross validation splitting", @@ -161,54 +170,30 @@ def start(self): self.classifier.fit(tel_type, table) self.log.info("done") - def _read_table(self, telescope_type, loader, n_events=None): - table = loader.read_telescope_events([telescope_type]) - self.log.info("Events read from input: %d", len(table)) - if len(table) == 0: - raise TooFewEvents( - f"Input file does not contain any events for telescope type {telescope_type}" - ) - - mask = self.classifier.quality_query.get_table_mask(table) - table = table[mask] - self.log.info("Events after applying quality query: %d", len(table)) - if len(table) == 0: - raise TooFewEvents( - f"No events after quality query for telescope type {telescope_type}" - ) - - table = self.classifier.feature_generator(table, subarray=self.subarray) - - # Add true energy for energy-dependent performance plots - columns = self.classifier.features + [self.classifier.target, "true_energy"] - table = table[columns] - - valid = check_valid_rows(table) - if 
np.any(~valid): - self.log.warning("Dropping non-predictable events.") - table = table[valid] - - if n_events is not None: - if n_events > len(table): - self.log.warning( - "Number of events in table (%d) is less than requested number of events %d", - len(table), - n_events, - ) - else: - self.log.info("Sampling %d events", n_events) - idx = self.rng.choice(len(table), n_events, replace=False) - idx.sort() - table = table[idx] - - return table - def _read_input_data(self, tel_type): - signal = self._read_table( - tel_type, self.signal_loader, self.n_signal.tel[tel_type] + feature_names = self.classifier.features + [ + self.classifier.target, + "true_energy", + ] + signal = read_training_events( + loader=self.signal_loader, + chunk_size=self.chunk_size, + telescope_type=tel_type, + reconstructor=self.classifier, + feature_names=feature_names, + rng=self.rng, + log=self.log, + n_events=self.n_signal.tel[tel_type], ) - background = self._read_table( - tel_type, self.background_loader, self.n_background.tel[tel_type] + background = read_training_events( + loader=self.background_loader, + chunk_size=self.chunk_size, + telescope_type=tel_type, + reconstructor=self.classifier, + feature_names=feature_names, + rng=self.rng, + log=self.log, + n_events=self.n_background.tel[tel_type], ) table = vstack([signal, background]) self.log.info( diff --git a/ctapipe/tools/utils.py b/ctapipe/tools/utils.py index 131046fbc1c..9e7c76eb837 100644 --- a/ctapipe/tools/utils.py +++ b/ctapipe/tools/utils.py @@ -2,8 +2,23 @@ """Utils to create scripts and command-line tools""" import argparse import importlib +import logging import sys from collections import OrderedDict +from typing import Type + +import numpy as np +from astropy.table import vstack + +from ..containers import CoordinateFrameType +from ..core.traits import Int +from ..exceptions import TooFewEvents +from ..instrument.telescope import TelescopeDescription +from ..io import TableLoader +from ..reco.preprocessing import 
def read_training_events(
    loader: TableLoader,
    chunk_size: int,
    telescope_type: TelescopeDescription,
    reconstructor: SKLearnReconstructor,
    feature_names: list,
    rng: np.random.Generator,
    log=LOG,
    n_events=None,
):
    """
    Load the training events of one telescope type chunk by chunk.

    Every chunk is reduced right after it has been read: the quality query
    of ``reconstructor`` is applied, its feature generator is run, only the
    columns in ``feature_names`` are kept and rows rejected by
    ``check_valid_rows`` are dropped. The reduced chunks are stacked into a
    single table at the end, so the full set of columns is only ever held
    for one chunk at a time.

    Parameters
    ----------
    loader : TableLoader
        Loader whose ``read_telescope_events_chunked`` provides the chunks
        and whose ``subarray`` is handed to the feature generator.
    chunk_size : int
        Chunk size forwarded unchanged to
        ``loader.read_telescope_events_chunked``.
        NOTE(review): the calling tools declare this trait with
        ``allow_none=True``; confirm that the loader accepts ``None``.
    telescope_type : TelescopeDescription
        Telescope type for which events are loaded.
    reconstructor : SKLearnReconstructor
        Reconstructor instance providing ``quality_query`` and
        ``feature_generator``. For a ``DispReconstructor`` the pointing
        frame of every chunk is checked in addition.
    feature_names : list
        Names of the columns kept in the returned table.
    rng : numpy.random.Generator
        Random generator used to sub-sample the events.
    log : logging.Logger
        Logger receiving the progress messages.
    n_events : int or None
        Number of events to sample without replacement. With ``None``, or
        when more events are requested than available, all events are
        returned.

    Returns
    -------
    astropy.table.Table
        Stacked table of the selected events, restricted to
        ``feature_names``.

    Raises
    ------
    TooFewEvents
        If the input yields no events for ``telescope_type`` or if no event
        is left after the selection.
    ValueError
        If ``reconstructor`` is a ``DispReconstructor`` and a chunk contains
        pointing information that is not in the ALTAZ frame.
    """
    # Whether the frame check is needed only depends on the reconstructor,
    # so decide once instead of once per chunk.
    needs_altaz_pointing = isinstance(reconstructor, DispReconstructor)

    chunk_iterator = loader.read_telescope_events_chunked(
        chunk_size,
        telescopes=[telescope_type],
    )

    selected_chunks = []
    n_events_in_file = 0
    n_valid_events_in_file = 0
    n_non_predictable = 0

    for chunk, (_, _, table_chunk) in enumerate(chunk_iterator):
        log.debug("Events read from chunk %d: %d", chunk, len(table_chunk))
        n_events_in_file += len(table_chunk)

        if needs_altaz_pointing and not np.all(
            table_chunk["subarray_pointing_frame"] == CoordinateFrameType.ALTAZ.value
        ):
            raise ValueError(
                "Pointing information for training data"
                " has to be provided in horizontal coordinates"
            )

        mask = reconstructor.quality_query.get_table_mask(table_chunk)
        table_chunk = table_chunk[mask]
        log.debug(
            "Events in chunk %d after applying quality_query: %d",
            chunk,
            len(table_chunk),
        )
        n_valid_events_in_file += len(table_chunk)

        table_chunk = reconstructor.feature_generator(
            table_chunk, subarray=loader.subarray
        )
        # Drop all columns that are not needed for training before the chunk
        # is kept in memory.
        table_chunk = table_chunk[feature_names]

        valid = check_valid_rows(table_chunk)
        if not np.all(valid):
            n_non_predictable += np.count_nonzero(~valid)
            table_chunk = table_chunk[valid]

        selected_chunks.append(table_chunk)

    log.info("Events read from input: %d", n_events_in_file)
    # vstack raises a plain ValueError for an empty list of tables, so an
    # input without any events has to be reported before stacking.
    if n_events_in_file == 0:
        raise TooFewEvents(
            f"Input file does not contain any events for telescope type {telescope_type}"
        )

    log.info("Events after applying quality query: %d", n_valid_events_in_file)
    table = vstack(selected_chunks)

    if len(table) == 0:
        raise TooFewEvents(
            f"No events after quality query for telescope type {telescope_type}"
        )

    if n_non_predictable > 0:
        log.warning("Dropping %d non-predictable events.", n_non_predictable)

    if n_events is not None:
        if n_events > len(table):
            log.warning(
                "Number of events in table (%d) is less"
                " than requested number of events %d",
                len(table),
                n_events,
            )
        else:
            log.info("Sampling %d events", n_events)
            # Sorting the sampled indices keeps the original event order.
            idx = rng.choice(len(table), n_events, replace=False)
            idx.sort()
            table = table[idx]

    return table