Skip to content

Commit

Permalink
Chunked loading for training particle clf and training disp reco
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasBeiske committed Oct 26, 2023
1 parent 1fa5ea1 commit 6f38d22
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 38 deletions.
68 changes: 48 additions & 20 deletions ctapipe/tools/train_disp_reconstructor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import astropy.units as u
import numpy as np
from astropy.table import vstack

from ctapipe.core import Tool
from ctapipe.core.traits import Bool, Int, IntTelescopeParameter, Path
Expand Down Expand Up @@ -47,6 +48,12 @@ class TrainDispReconstructor(Tool):
),
).tag(config=True)

chunk_size = Int(
default_value=100000,
allow_none=True,
help="How many subarray events to load at once before training on n_events.",
).tag(config=True)

random_seed = Int(
default_value=0, help="Random seed for sampling and cross validation"
).tag(config=True)
Expand Down Expand Up @@ -112,33 +119,54 @@ def start(self):
self.log.info("done")

def _read_table(self, telescope_type):
table = self.loader.read_telescope_events([telescope_type])
self.log.info("Events read from input: %d", len(table))
if len(table) == 0:
raise TooFewEvents(
f"Input file does not contain any events for telescope type {telescope_type}"
chunk_iterator = self.loader.read_telescope_events_chunked(
self.chunk_size,
telescopes=[telescope_type],
)
table = []
n_events_in_file = 0
n_valid_events_in_file = 0
n_non_predictable = 0

for chunk, (_, _, table_chunk) in enumerate(chunk_iterator):
self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk))
n_events_in_file += len(table_chunk)

mask = self.models.quality_query.get_table_mask(table_chunk)
table_chunk = table_chunk[mask]
self.log.debug(
"Events in chunk %d after applying quality_query: %d",
chunk,
len(table_chunk),
)
n_valid_events_in_file += len(table_chunk)

mask = self.models.quality_query.get_table_mask(table)
table = table[mask]
self.log.info("Events after applying quality query: %d", len(table))
if len(table) == 0:
raise TooFewEvents(
f"No events after quality query for telescope type {telescope_type}"
table_chunk = self.models.feature_generator(
table_chunk, subarray=self.loader.subarray
)
table_chunk[self.models.target] = self._get_true_disp(table_chunk)
# Add true energy for energy-dependent performance plots
columns = self.models.features + [self.models.target, "true_energy"]
table_chunk = table_chunk[columns]

table = self.models.feature_generator(table, subarray=self.loader.subarray)
valid = check_valid_rows(table_chunk)
if not np.all(valid):
n_non_predictable += np.sum(valid)
table_chunk = table_chunk[valid]

table[self.models.target] = self._get_true_disp(table)
table.append(table_chunk)

# Add true energy for energy-dependent performance plots
columns = self.models.features + [self.models.target, "true_energy"]
table = table[columns]
table = vstack(table)
self.log.info("Events read from input: %d", n_events_in_file)
self.log.info("Events after applying quality query: %d", n_valid_events_in_file)

if len(table) == 0:
raise TooFewEvents(
f"No events after quality query for telescope type {telescope_type}"
)

valid = check_valid_rows(table)
if not np.all(valid):
self.log.warning("Dropping non-predicable events.")
table = table[valid]
if n_non_predictable > 0:
self.log.warning("Dropping %d non-predictable events.", n_non_predictable)

n_events = self.n_events.tel[telescope_type]
if n_events is not None:
Expand Down
67 changes: 49 additions & 18 deletions ctapipe/tools/train_particle_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ class TrainParticleClassifier(Tool):
),
).tag(config=True)

chunk_size = Int(
default_value=100000,
allow_none=True,
help=(
"How many subarray events to load at once before training on"
" n_signal and n_background events."
),
).tag(config=True)

random_seed = Int(
default_value=0,
help="Random number seed for sampling and the cross validation splitting",
Expand Down Expand Up @@ -162,31 +171,53 @@ def start(self):
self.log.info("done")

def _read_table(self, telescope_type, loader, n_events=None):
table = loader.read_telescope_events([telescope_type])
self.log.info("Events read from input: %d", len(table))
if len(table) == 0:
raise TooFewEvents(
f"Input file does not contain any events for telescope type {telescope_type}"
chunk_iterator = loader.read_telescope_events_chunked(
self.chunk_size,
telescopes=[telescope_type],
)
table = []
n_events_in_file = 0
n_valid_events_in_file = 0
n_non_predictable = 0

for chunk, (_, _, table_chunk) in enumerate(chunk_iterator):
self.log.debug("Events read from chunk %d: %d", chunk, len(table_chunk))
n_events_in_file += len(table_chunk)

mask = self.classifier.quality_query.get_table_mask(table_chunk)
table_chunk = table_chunk[mask]
self.log.debug(
"Events in chunk %d after applying quality_query: %d",
chunk,
len(table_chunk),
)
n_valid_events_in_file += len(table_chunk)

table_chunk = self.classifier.feature_generator(
table_chunk, subarray=self.subarray
)
# Add true energy for energy-dependent performance plots
columns = self.classifier.features + [self.classifier.target, "true_energy"]
table_chunk = table_chunk[columns]

valid = check_valid_rows(table_chunk)
if not np.all(valid):
n_non_predictable += np.sum(valid)
table_chunk = table_chunk[valid]

table.append(table_chunk)

table = vstack(table)
self.log.info("Events read from input: %d", n_events_in_file)
self.log.info("Events after applying quality query: %d", n_valid_events_in_file)

mask = self.classifier.quality_query.get_table_mask(table)
table = table[mask]
self.log.info("Events after applying quality query: %d", len(table))
if len(table) == 0:
raise TooFewEvents(
f"No events after quality query for telescope type {telescope_type}"
)

table = self.classifier.feature_generator(table, subarray=self.subarray)

# Add true energy for energy-dependent performance plots
columns = self.classifier.features + [self.classifier.target, "true_energy"]
table = table[columns]

valid = check_valid_rows(table)
if not np.all(valid):
self.log.warning("Dropping non-predictable events.")
table = table[valid]
if n_non_predictable > 0:
self.log.warning("Dropping %d non-predictable events.", n_non_predictable)

if n_events is not None:
if n_events > len(table):
Expand Down

0 comments on commit 6f38d22

Please sign in to comment.