Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transform back the predicted quantities into proper the space for real data #200

Merged
merged 4 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ctlearn/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import tensorflow as tf
import astropy.units as u


class KerasBatchGenerator(tf.keras.utils.Sequence):
Expand Down Expand Up @@ -39,12 +40,12 @@ def __init__(
self.event_list, self.obs_list = [], []
# Labels
self.prt_pos, self.enr_pos, self.drc_pos = None, None, None
self.drc_unit = None
self.drc_unit = u.deg
self.prt_labels = []
self.enr_labels = []
self.az_labels, self.alt_labels, self.sep_labels = [], [], []
self.trgpatch_labels = []
self.energy_unit = None
self.energy_unit = "log(TeV)"

for i, desc in enumerate(self.DLDataReader.example_description):
if "HWtrigger" in desc["name"]:
Expand Down
52 changes: 49 additions & 3 deletions ctlearn/output_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from astropy.table import Table
from astropy.coordinates import SkyCoord
import astropy.units as u

from ctapipe.io.pointing import PointingInterpolator

def write_output(h5file, data, rest_data, reader, predictions, tasks):
prediction_dir = h5file.replace(f'{h5file.split("/")[-1]}', "")
Expand Down Expand Up @@ -140,8 +140,54 @@ def write_output(h5file, data, rest_data, reader, predictions, tasks):
reco["true_alt"] = np.array(true_alt)
if "direction" in tasks:
if reader.fix_pointing is None:
reco["reco_az"] = np.array(predictions[:, 0])
reco["reco_alt"] = np.array(predictions[:, 1])
# Currently we only have LST-1 real observational data,
# so there is only one tel_id in the file.
# For stereo we should fix this directly in the ctapipe plugin,
# which is currently under development.
pointing_interpolator = PointingInterpolator()
tel_id_int = None
for tel_id, pointing_table in reader.telescope_pointings.items():
tel_id_int = int(tel_id.replace("tel_", ""))
pointing_interpolator.add_table(tel_id_int, pointing_table)
trigger_info = reader.tel_trigger_table[reader.tel_trigger_table["tel_id"]== tel_id_int]
# Check if the number of predictions and trigger info match
# Actually this check is redundant since the dl1dh do not allow quality cuts when processing real data
# However, it is still good to have it here in case table are not properly filled.
if len(predictions[:, 0]) != len(trigger_info):
raise ValueError(
f"The number of predictions ({len(predictions[:, 0])}) and trigger info ({len(trigger_info)}) do not match."
)
event_id, obs_id, tel_id = [], [], []
reco_az, reco_alt = [], []
pointing_az, pointing_alt, time = [], [], []
for i, (az_off, alt_off) in enumerate(zip(predictions[:, 0], predictions[:, 1])):
tel_alt, tel_az = pointing_interpolator(tel_id_int, trigger_info[i]['time'])
pointing = SkyCoord(
tel_az.to_value(data.drc_unit),
tel_alt.to_value(data.drc_unit),
frame="altaz",
unit="deg",
)
reco_direction = pointing.spherical_offsets_by(
u.Quantity(az_off, unit=u.deg),
u.Quantity(alt_off, unit=u.deg),
)
event_id.append(trigger_info[i]['event_id'])
obs_id.append(trigger_info[i]['obs_id'])
tel_id.append(trigger_info[i]['tel_id'])
time.append(trigger_info[i]['time'])
reco_az.append(reco_direction.az.to_value(data.drc_unit))
reco_alt.append(reco_direction.alt.to_value(data.drc_unit))
pointing_az.append(tel_az.to_value(u.deg))
pointing_alt.append(tel_alt.to_value(u.deg))
reco["event_id"] = np.array(event_id)
reco["obs_id"] = np.array(obs_id)
reco["tel_id"] = np.array(tel_id)
reco["time"] = np.array(time)
reco["reco_az"] = np.array(reco_az)
reco["reco_alt"] = np.array(reco_alt)
reco["pointing_az"] = np.array(pointing_az)
reco["pointing_alt"] = np.array(pointing_alt)
reco["reco_sep"] = np.array(predictions[:, 2])
else:
reco_az, reco_alt = [], []
Expand Down