Skip to content

Commit

Permalink
Merge pull request #200 from ctlearn-project/real_data_dl2
Browse files Browse the repository at this point in the history
Transform back the predicted quantities into proper the space for real data
  • Loading branch information
TjarkMiener authored Jul 18, 2024
2 parents 2f2e91c + 7880056 commit 93f59b8
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 5 deletions.
5 changes: 3 additions & 2 deletions ctlearn/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import tensorflow as tf
import astropy.units as u


class KerasBatchGenerator(tf.keras.utils.Sequence):
Expand Down Expand Up @@ -39,12 +40,12 @@ def __init__(
self.event_list, self.obs_list = [], []
# Labels
self.prt_pos, self.enr_pos, self.drc_pos = None, None, None
self.drc_unit = None
self.drc_unit = u.deg
self.prt_labels = []
self.enr_labels = []
self.az_labels, self.alt_labels, self.sep_labels = [], [], []
self.trgpatch_labels = []
self.energy_unit = None
self.energy_unit = "log(TeV)"

for i, desc in enumerate(self.DLDataReader.example_description):
if "HWtrigger" in desc["name"]:
Expand Down
52 changes: 49 additions & 3 deletions ctlearn/output_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from astropy.table import Table
from astropy.coordinates import SkyCoord
import astropy.units as u

from ctapipe.io.pointing import PointingInterpolator

def write_output(h5file, data, rest_data, reader, predictions, tasks):
prediction_dir = h5file.replace(f'{h5file.split("/")[-1]}', "")
Expand Down Expand Up @@ -140,8 +140,54 @@ def write_output(h5file, data, rest_data, reader, predictions, tasks):
reco["true_alt"] = np.array(true_alt)
if "direction" in tasks:
if reader.fix_pointing is None:
reco["reco_az"] = np.array(predictions[:, 0])
reco["reco_alt"] = np.array(predictions[:, 1])
# Currently we only have LST-1 real observational data,
# so there is only one tel_id in the file.
# For stereo we should fix this directly in the ctapipe plugin,
# which is currently under development.
pointing_interpolator = PointingInterpolator()
tel_id_int = None
for tel_id, pointing_table in reader.telescope_pointings.items():
tel_id_int = int(tel_id.replace("tel_", ""))
pointing_interpolator.add_table(tel_id_int, pointing_table)
trigger_info = reader.tel_trigger_table[reader.tel_trigger_table["tel_id"]== tel_id_int]
# Check if the number of predictions and trigger info match
# Actually this check is redundant since the dl1dh do not allow quality cuts when processing real data
# However, it is still good to have it here in case table are not properly filled.
if len(predictions[:, 0]) != len(trigger_info):
raise ValueError(
f"The number of predictions ({len(predictions[:, 0])}) and trigger info ({len(trigger_info)}) do not match."
)
event_id, obs_id, tel_id = [], [], []
reco_az, reco_alt = [], []
pointing_az, pointing_alt, time = [], [], []
for i, (az_off, alt_off) in enumerate(zip(predictions[:, 0], predictions[:, 1])):
tel_alt, tel_az = pointing_interpolator(tel_id_int, trigger_info[i]['time'])
pointing = SkyCoord(
tel_az.to_value(data.drc_unit),
tel_alt.to_value(data.drc_unit),
frame="altaz",
unit="deg",
)
reco_direction = pointing.spherical_offsets_by(
u.Quantity(az_off, unit=u.deg),
u.Quantity(alt_off, unit=u.deg),
)
event_id.append(trigger_info[i]['event_id'])
obs_id.append(trigger_info[i]['obs_id'])
tel_id.append(trigger_info[i]['tel_id'])
time.append(trigger_info[i]['time'])
reco_az.append(reco_direction.az.to_value(data.drc_unit))
reco_alt.append(reco_direction.alt.to_value(data.drc_unit))
pointing_az.append(tel_az.to_value(u.deg))
pointing_alt.append(tel_alt.to_value(u.deg))
reco["event_id"] = np.array(event_id)
reco["obs_id"] = np.array(obs_id)
reco["tel_id"] = np.array(tel_id)
reco["time"] = np.array(time)
reco["reco_az"] = np.array(reco_az)
reco["reco_alt"] = np.array(reco_alt)
reco["pointing_az"] = np.array(pointing_az)
reco["pointing_alt"] = np.array(pointing_alt)
reco["reco_sep"] = np.array(predictions[:, 2])
else:
reco_az, reco_alt = [], []
Expand Down

0 comments on commit 93f59b8

Please sign in to comment.