Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix168 #182

Merged
merged 8 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CHANGELOG.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ v23.2.1
spline table's basis function indices has been separated from the spline table
evaluation which provides an accelerated PDF evaluation.

- The core.dataset.Dataset.get_aux_data method got the new optional argument
"default" to specify a default return value if the auxiliary data is not
defined for that dataset.

- The core.signal_generator.MCMultiDatasetSignalGenerator get the new optional
constructor argument "valid_event_field_ranges_dict_list=None" to specify
valid event field ranges, e.g. for the declination. If generated signal events
do not match these valid event field ranges, those signal events will be
redrawn.

v23.2.0
=======
- Complete overhaul of SkyLLH for more generic handling of parameters
Expand Down
9 changes: 8 additions & 1 deletion skyllh/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2256,6 +2256,7 @@ def add_aux_data(
def get_aux_data(
self,
name,
default=None,
):
"""Retrieves the auxiliary data that is stored in this data set under
the given name.
Expand All @@ -2264,6 +2265,9 @@ def get_aux_data(
----------
name : str
The name under which the auxiliary data is stored.
default : any | None
If not ``None``, it specifies the returned default value when the
auxiliary data does not exists.

Returns
-------
Expand All @@ -2273,13 +2277,16 @@ def get_aux_data(
Raises
------
KeyError
If no auxiliary data is stored with the given name.
If no auxiliary data is stored with the given name and no default
value was specified.
"""
name = str_cast(
name,
'The name argument must be cast-able to type str!')

if name not in self._aux_data:
if default is not None:
return default
raise KeyError(
f'The auxiliary data "{name}" is not defined for dataset '
f'"{self.name}"!')
Expand Down
206 changes: 199 additions & 7 deletions skyllh/core/signal_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Livetime,
)
from skyllh.core.py import (
classname,
issequenceof,
float_cast,
int_cast,
Expand Down Expand Up @@ -180,7 +181,7 @@ def generate_signal_events(
n_signal : int
The number of generated signal events.
signal_events_dict : dict of DataFieldRecordArray
The dictionary holding the DataFieldRecordArray instancs with the
The dictionary holding the DataFieldRecordArray instances with the
generated signal events. Each key of this dictionary represents the
dataset index for which the signal events have been generated.
"""
Expand Down Expand Up @@ -366,11 +367,11 @@ def generate_signal_events(
mean = rss.random.poisson(
float_cast(
mean,
'The mean argument must be castable to type of float!'))
'The mean argument must be cast-able to type of float!'))

n_signal = int_cast(
mean,
'The mean argument must be castable to type of int!')
'The mean argument must be cast-able to type of int!')

src_detsigyield_weights_service =\
self.ds_sig_weight_factors_service.src_detsigyield_weights_service
Expand Down Expand Up @@ -431,6 +432,7 @@ def __init__(
shg_mgr,
dataset_list,
data_list,
valid_event_field_ranges_dict_list=None,
**kwargs,
):
"""Constructs a new signal generator instance.
Expand All @@ -446,17 +448,69 @@ def __init__(
data_list : list of DatasetData instances
The list of DatasetData instances holding the actual data of each
dataset. The order must match the order of ``dataset_list``.
kwargs
A typical keyword argument is the instance of MultiDatasetTCLLHRatio.
valid_event_field_ranges_dict_list : list of dict | None
If not ``None``, it specifies for each dataset event fields (key)
and their valid value range as a 2-element tuple (value). If a
generated signal event does not fall into a given field range, the
signal event will be discarded and a new signal event will be drawn.
"""
super().__init__(
shg_mgr=shg_mgr,
dataset_list=dataset_list,
data_list=data_list,
**kwargs)

if valid_event_field_ranges_dict_list is None:
valid_event_field_ranges_dict_list = [dict()]*len(self.dataset_list)
if not isinstance(valid_event_field_ranges_dict_list, list):
raise TypeError(
'The `valid_event_field_ranges_dict_list` argument must be a list.'
)
if len(valid_event_field_ranges_dict_list) != len(self.dataset_list):
raise ValueError(
'The valid_event_field_ranges_dict_list argument must be a '
f'list of length {len(self.dataset_list)}, but it is of length '
f'{len(valid_event_field_ranges_dict_list)}!')
self.valid_event_field_ranges_dict_list =\
valid_event_field_ranges_dict_list

self._construct_signal_candidates()

@property
def valid_event_field_ranges_dict_list(self):
"""The list of dictionary holding the event data fields (key) and their
valid value range as 2-element tuple (value).
"""
return self._valid_event_field_ranges_dict_list

@valid_event_field_ranges_dict_list.setter
def valid_event_field_ranges_dict_list(self, dict_list):
if not isinstance(dict_list, list):
raise TypeError(
'The valid_event_field_ranges_dict_list must be an instance of '
'list! '
f'Its current type is {classname(dict_list)}!')
for d in dict_list:
for (k, v) in d.items():
if not isinstance(k, str):
raise TypeError(
'Each key of the dictionary of the '
'valid_event_field_ranges_dict property must be an '
'instance of str! '
f'But the type of one of the keys is {classname(k)}!')
if not isinstance(v, tuple):
raise TypeError(
'Each value of the dictionary of the '
'valid_event_field_ranges_dict property must be an '
'instance of tuple! '
f'But the value type for the event field "{k}" is '
f'{classname(v)}!')
if len(v) != 2:
raise ValueError(
f'The tuple for the event field {k} must be of length '
f'2! Its current length is {len(v)}!')
self._valid_event_field_ranges_dict_list = dict_list

def _construct_signal_candidates(self):
"""Constructs an array holding pointer information of signal candidate
events pointing into the real MC dataset(s).
Expand Down Expand Up @@ -529,6 +583,121 @@ def _construct_signal_candidates(self):
items=self._sig_candidates,
probabilities=self._sig_candidates['weight'])

def _get_invalid_events_mask(
self,
events,
valid_event_field_ranges_dict,
):
"""Determines a boolean mask to select invalid events, which do not
fulfill the given valid event field ranges.

Parameters
----------
events : instance of DataFieldRecordArray
The instance of DataFieldRecordArray of length N_events holding the
events to check.
valid_event_field_ranges_dict : dict
The dictionary holding the data field names (key) and their valid
value ranges (value).

Raises
------
KeyError
If one of the event field does not exist in ``events``.

Returns
-------
mask : instance of numpy.ndarray
The (N_events,)-shaped numpy.ndarray of bool, holding the mask of
the invalid events.
"""
mask = np.zeros((len(events),), dtype=np.bool_)

for (field_name, min_max) in valid_event_field_ranges_dict.items():
if field_name not in events:
raise KeyError(
f'The event data field "{field_name}" specified in the '
'valid_event_field_ranges_dict does not exist in the event '
'data!')
field_values = events[field_name]
mask |= (field_values < min_max[0]) | (field_values > min_max[1])

return mask

def _draw_valid_sig_events_for_dataset_and_shg(
self,
rss,
mc,
n_signal,
ds_idx,
valid_event_field_ranges_dict,
shg,
shg_idx,
):
"""Draws n_signal valid signal events for the given dataset and source
hypothesis group.

Signal events will be drawn until all events match the event field
ranges, i.e. are valid signal events.

Parameters
----------
rss : instance of RandomStateService
The instance of RandomStateService which should be used to draw
random numbers from.
mc : instance of DataFieldRecordArray
The instance of DataFieldRecordArray holding the monte-carlo events.
n_signal : int
The number of signal events to draw.
ds_idx : int
The index of the dataset.
valid_event_field_ranges_dict : dict
The dictionary holding the data field names (key) and their valid
value ranges (value) for the requested dataset.
shg : instance of SourceHypothesisGroup
The instance of SourceHypothesisGroup for which signal events should
get drawn.
shg_idx : int
The index of the source hypothesis group.

Returns
-------
sig_events : instance of DataFieldRecordArray
The instance of DataFieldRecordArray holding the drawn valid signal
events.
"""
sig_events = None

n = 0
while n < n_signal:
events_meta = self._sig_candidates_random_choice(
rss=rss,
size=n_signal-n,
)
m = (events_meta['ds_idx'] == ds_idx) &\
(events_meta['shg_idx'] == shg_idx)
events = mc[events_meta['ev_idx'][m]]
if len(events) > 0:
events = shg.sig_gen_method.\
signal_event_post_sampling_processing(
shg, events_meta, events)

valid_events_mask = np.invert(
self._get_invalid_events_mask(
events,
valid_event_field_ranges_dict,
)
)
events = events[valid_events_mask]
if len(events) > 0:
if sig_events is None:
sig_events = events
else:
sig_events.append(events)
n = len(sig_events)

return sig_events

def change_shg_mgr(
self,
shg_mgr):
Expand Down Expand Up @@ -643,11 +812,11 @@ def generate_signal_events(
mean = rss.random.poisson(
float_cast(
mean,
'The mean argument must be castable to type of float!'))
'The mean argument must be cast-able to type of float!'))

n_signal = int_cast(
mean,
'The mean argument must be castable to type of int!')
'The mean argument must be cast-able to type of int!')

# Draw n_signal signal candidates according to their weight.
sig_events_meta = self._sig_candidates_random_choice(
Expand All @@ -665,6 +834,8 @@ def generate_signal_events(
signal_events_dict = dict()
ds_idxs = np.unique(sig_events_meta['ds_idx'])
for ds_idx in ds_idxs:
valid_event_field_ranges_dict =\
self.valid_event_field_ranges_dict_list[ds_idx]
mc = self._data_list[ds_idx].mc
ds_mask = sig_events_meta['ds_idx'] == ds_idx
n_sig_events_ds = np.count_nonzero(ds_mask)
Expand Down Expand Up @@ -701,6 +872,27 @@ def generate_signal_events(
signal_event_post_sampling_processing(
shg, shg_sig_events_meta, shg_sig_events)

# Determine the signal events, which do not fulfill the valid
# event field ranges for this dataset.
invalid_events_mask = self._get_invalid_events_mask(
shg_sig_events,
valid_event_field_ranges_dict)
n_redraw_events = np.count_nonzero(invalid_events_mask)
if n_redraw_events > 0:
# Re-draw n_redraw_events signal events for this dataset
# and SHG.
redrawn_shg_sig_events =\
self._draw_valid_sig_events_for_dataset_and_shg(
rss=rss,
mc=mc,
n_signal=n_redraw_events,
ds_idx=ds_idx,
valid_event_field_ranges_dict=valid_event_field_ranges_dict,
shg=shg,
shg_idx=shg_idx,
)
shg_sig_events[invalid_events_mask] = redrawn_shg_sig_events

indices = np.indices((n_shg_sig_events,))[0] + fill_start_idx
sig_events.set_selection(indices, shg_sig_events)

Expand Down