From 6f38d22c873aee0eef5791deb012d1724b2e4a2b Mon Sep 17 00:00:00 2001 From: Lukas Beiske Date: Thu, 26 Oct 2023 16:14:35 +0200 Subject: [PATCH] Chunked loading for training particle clf and training disp reco --- ctapipe/tools/train_disp_reconstructor.py | 68 +++++++++++++++------- ctapipe/tools/train_particle_classifier.py | 67 +++++++++++++++------ 2 files changed, 97 insertions(+), 38 deletions(-) diff --git a/ctapipe/tools/train_disp_reconstructor.py b/ctapipe/tools/train_disp_reconstructor.py index 1803b59e60d..b5843770cd7 100644 --- a/ctapipe/tools/train_disp_reconstructor.py +++ b/ctapipe/tools/train_disp_reconstructor.py @@ -1,5 +1,6 @@ import astropy.units as u import numpy as np +from astropy.table import vstack from ctapipe.core import Tool from ctapipe.core.traits import Bool, Int, IntTelescopeParameter, Path @@ -47,6 +48,12 @@ class TrainDispReconstructor(Tool): ), ).tag(config=True) + chunk_size = Int( + default_value=100000, + allow_none=True, + help="How many subarray events to load at once before training on n_events.", + ).tag(config=True) + random_seed = Int( default_value=0, help="Random seed for sampling and cross validation" ).tag(config=True) @@ -112,33 +119,54 @@ def start(self): self.log.info("done") def _read_table(self, telescope_type): - table = self.loader.read_telescope_events([telescope_type]) - self.log.info("Events read from input: %d", len(table)) - if len(table) == 0: - raise TooFewEvents( - f"Input file does not contain any events for telescope type {telescope_type}" + chunk_iterator = self.loader.read_telescope_events_chunked( + self.chunk_size, + telescopes=[telescope_type], + ) + table = [] + n_events_in_file = 0 + n_valid_events_in_file = 0 + n_non_predictable = 0 + + for chunk, (_, _, table_chunk) in enumerate(chunk_iterator): + self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk)) + n_events_in_file += len(table_chunk) + + mask = self.models.quality_query.get_table_mask(table_chunk) + table_chunk = table_chunk[mask] + self.log.debug( + "Events in chunk %d after applying quality_query: %d", + chunk, + len(table_chunk), ) + n_valid_events_in_file += len(table_chunk) - mask = self.models.quality_query.get_table_mask(table) - table = table[mask] - self.log.info("Events after applying quality query: %d", len(table)) - if len(table) == 0: - raise TooFewEvents( - f"No events after quality query for telescope type {telescope_type}" + table_chunk = self.models.feature_generator( + table_chunk, subarray=self.loader.subarray ) + table_chunk[self.models.target] = self._get_true_disp(table_chunk) + # Add true energy for energy-dependent performance plots + columns = self.models.features + [self.models.target, "true_energy"] + table_chunk = table_chunk[columns] - table = self.models.feature_generator(table, subarray=self.loader.subarray) + valid = check_valid_rows(table_chunk) + if not np.all(valid): + n_non_predictable += np.sum(valid) + table_chunk = table_chunk[valid] - table[self.models.target] = self._get_true_disp(table) + table.append(table_chunk) - # Add true energy for energy-dependent performance plots - columns = self.models.features + [self.models.target, "true_energy"] - table = table[columns] + table = vstack(table) + self.log.info("Events read from input: %d", n_events_in_file) + self.log.info("Events after applying quality query: %d", n_valid_events_in_file) + + if len(table) == 0: + raise TooFewEvents( + f"No events after quality query for telescope type {telescope_type}" + ) - valid = check_valid_rows(table) - if not np.all(valid): - self.log.warning("Dropping non-predicable events.") - table = table[valid] + if n_non_predictable > 0: + self.log.warning("Dropping %d non-predictable events.", n_non_predictable) n_events = self.n_events.tel[telescope_type] if n_events is not None: diff --git a/ctapipe/tools/train_particle_classifier.py b/ctapipe/tools/train_particle_classifier.py index 21c5e21fdd3..d3d1d8074db 100644 --- a/ctapipe/tools/train_particle_classifier.py +++ b/ctapipe/tools/train_particle_classifier.py @@ -78,6 +78,15 @@ class TrainParticleClassifier(Tool): ), ).tag(config=True) + chunk_size = Int( + default_value=100000, + allow_none=True, + help=( + "How many subarray events to load at once before training on" + " n_signal and n_background events." + ), + ).tag(config=True) + random_seed = Int( default_value=0, help="Random number seed for sampling and the cross validation splitting", @@ -162,31 +171,53 @@ def start(self): self.log.info("done") def _read_table(self, telescope_type, loader, n_events=None): - table = loader.read_telescope_events([telescope_type]) - self.log.info("Events read from input: %d", len(table)) - if len(table) == 0: - raise TooFewEvents( - f"Input file does not contain any events for telescope type {telescope_type}" + chunk_iterator = loader.read_telescope_events_chunked( + self.chunk_size, + telescopes=[telescope_type], + ) + table = [] + n_events_in_file = 0 + n_valid_events_in_file = 0 + n_non_predictable = 0 + + for chunk, (_, _, table_chunk) in enumerate(chunk_iterator): + self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk)) + n_events_in_file += len(table_chunk) + + mask = self.classifier.quality_query.get_table_mask(table_chunk) + table_chunk = table_chunk[mask] + self.log.debug( + "Events in chunk %d after applying quality_query: %d", + chunk, + len(table_chunk), ) + n_valid_events_in_file += len(table_chunk) + + table_chunk = self.classifier.feature_generator( + table_chunk, subarray=self.subarray + ) + # Add true energy for energy-dependent performance plots + columns = self.classifier.features + [self.classifier.target, "true_energy"] + table_chunk = table_chunk[columns] + + valid = check_valid_rows(table_chunk) + if not np.all(valid): + n_non_predictable += np.sum(valid) + table_chunk = table_chunk[valid] + + table.append(table_chunk) + + table = vstack(table) + self.log.info("Events read from input: %d", n_events_in_file) + self.log.info("Events after applying quality query: %d", n_valid_events_in_file) - mask = self.classifier.quality_query.get_table_mask(table) - table = table[mask] - self.log.info("Events after applying quality query: %d", len(table)) if len(table) == 0: raise TooFewEvents( f"No events after quality query for telescope type {telescope_type}" ) - table = self.classifier.feature_generator(table, subarray=self.subarray) - - # Add true energy for energy-dependent performance plots - columns = self.classifier.features + [self.classifier.target, "true_energy"] - table = table[columns] - - valid = check_valid_rows(table) - if not np.all(valid): - self.log.warning("Dropping non-predictable events.") - table = table[valid] + if n_non_predictable > 0: + self.log.warning("Dropping %d non-predictable events.", n_non_predictable) if n_events is not None: if n_events > len(table):