diff --git a/CHANGELOG.txt b/CHANGELOG.txt index 40cd5e0f96..d23e8a8e8e 100644 --- a/CHANGELOG.txt +++ b/CHANGELOG.txt @@ -30,6 +30,16 @@ v23.2.1 spline table's basis function indices has been separated from the spline table evaluation which provides an accelerated PDF evaluation. +- The core.dataset.Dataset.get_aux_data method got the new optional argument + "default" to specify a default return value if the auxiliary data is not + defined for that dataset. + +- The core.signal_generator.MCMultiDatasetSignalGenerator get the new optional + constructor argument "valid_event_field_ranges_dict_list=None" to specify + valid event field ranges, e.g. for the declination. If generated signal events + do not match these valid event field ranges, those signal events will be + redrawn. + v23.2.0 ======= - Complete overhaul of SkyLLH for more generic handling of parameters diff --git a/skyllh/core/dataset.py b/skyllh/core/dataset.py index c284695d1f..4420ce96b0 100644 --- a/skyllh/core/dataset.py +++ b/skyllh/core/dataset.py @@ -2256,6 +2256,7 @@ def add_aux_data( def get_aux_data( self, name, + default=None, ): """Retrieves the auxiliary data that is stored in this data set under the given name. @@ -2264,6 +2265,9 @@ def get_aux_data( ---------- name : str The name under which the auxiliary data is stored. + default : any | None + If not ``None``, it specifies the returned default value when the + auxiliary data does not exists. Returns ------- @@ -2273,13 +2277,16 @@ def get_aux_data( Raises ------ KeyError - If no auxiliary data is stored with the given name. + If no auxiliary data is stored with the given name and no default + value was specified. """ name = str_cast( name, 'The name argument must be cast-able to type str!') if name not in self._aux_data: + if default is not None: + return default raise KeyError( f'The auxiliary data "{name}" is not defined for dataset ' f'"{self.name}"!') diff --git a/skyllh/core/signal_generator.py b/skyllh/core/signal_generator.py index be1a267a0a..e58371bace 100644 --- a/skyllh/core/signal_generator.py +++ b/skyllh/core/signal_generator.py @@ -20,6 +20,7 @@ Livetime, ) from skyllh.core.py import ( + classname, issequenceof, float_cast, int_cast, @@ -180,7 +181,7 @@ def generate_signal_events( n_signal : int The number of generated signal events. signal_events_dict : dict of DataFieldRecordArray - The dictionary holding the DataFieldRecordArray instancs with the + The dictionary holding the DataFieldRecordArray instances with the generated signal events. Each key of this dictionary represents the dataset index for which the signal events have been generated. """ @@ -366,11 +367,11 @@ def generate_signal_events( mean = rss.random.poisson( float_cast( mean, - 'The mean argument must be castable to type of float!')) + 'The mean argument must be cast-able to type of float!')) n_signal = int_cast( mean, - 'The mean argument must be castable to type of int!') + 'The mean argument must be cast-able to type of int!') src_detsigyield_weights_service =\ self.ds_sig_weight_factors_service.src_detsigyield_weights_service @@ -431,6 +432,7 @@ def __init__( shg_mgr, dataset_list, data_list, + valid_event_field_ranges_dict_list=None, **kwargs, ): """Constructs a new signal generator instance. @@ -446,8 +448,11 @@ def __init__( data_list : list of DatasetData instances The list of DatasetData instances holding the actual data of each dataset. The order must match the order of ``dataset_list``. - kwargs - A typical keyword argument is the instance of MultiDatasetTCLLHRatio. + valid_event_field_ranges_dict_list : list of dict | None + If not ``None``, it specifies for each dataset event fields (key) + and their valid value range as a 2-element tuple (value). If a + generated signal event does not fall into a given field range, the + signal event will be discarded and a new signal event will be drawn. """ super().__init__( shg_mgr=shg_mgr, @@ -455,8 +460,57 @@ def __init__( data_list=data_list, **kwargs) + if valid_event_field_ranges_dict_list is None: + valid_event_field_ranges_dict_list = [dict()]*len(self.dataset_list) + if not isinstance(valid_event_field_ranges_dict_list, list): + raise TypeError( + 'The `valid_event_field_ranges_dict_list` argument must be a list.' + ) + if len(valid_event_field_ranges_dict_list) != len(self.dataset_list): + raise ValueError( + 'The valid_event_field_ranges_dict_list argument must be a ' + f'list of length {len(self.dataset_list)}, but it is of length ' + f'{len(valid_event_field_ranges_dict_list)}!') + self.valid_event_field_ranges_dict_list =\ + valid_event_field_ranges_dict_list + self._construct_signal_candidates() + @property + def valid_event_field_ranges_dict_list(self): + """The list of dictionary holding the event data fields (key) and their + valid value range as 2-element tuple (value). + """ + return self._valid_event_field_ranges_dict_list + + @valid_event_field_ranges_dict_list.setter + def valid_event_field_ranges_dict_list(self, dict_list): + if not isinstance(dict_list, list): + raise TypeError( + 'The valid_event_field_ranges_dict_list must be an instance of ' + 'list! ' + f'Its current type is {classname(dict_list)}!') + for d in dict_list: + for (k, v) in d.items(): + if not isinstance(k, str): + raise TypeError( + 'Each key of the dictionary of the ' + 'valid_event_field_ranges_dict property must be an ' + 'instance of str! ' + f'But the type of one of the keys is {classname(k)}!') + if not isinstance(v, tuple): + raise TypeError( + 'Each value of the dictionary of the ' + 'valid_event_field_ranges_dict property must be an ' + 'instance of tuple! ' + f'But the value type for the event field "{k}" is ' + f'{classname(v)}!') + if len(v) != 2: + raise ValueError( + f'The tuple for the event field {k} must be of length ' + f'2! Its current length is {len(v)}!') + self._valid_event_field_ranges_dict_list = dict_list + def _construct_signal_candidates(self): """Constructs an array holding pointer information of signal candidate events pointing into the real MC dataset(s). @@ -529,6 +583,121 @@ def _construct_signal_candidates(self): items=self._sig_candidates, probabilities=self._sig_candidates['weight']) + def _get_invalid_events_mask( + self, + events, + valid_event_field_ranges_dict, + ): + """Determines a boolean mask to select invalid events, which do not + fulfill the given valid event field ranges. + + Parameters + ---------- + events : instance of DataFieldRecordArray + The instance of DataFieldRecordArray of length N_events holding the + events to check. + valid_event_field_ranges_dict : dict + The dictionary holding the data field names (key) and their valid + value ranges (value). + + Raises + ------ + KeyError + If one of the event field does not exist in ``events``. + + Returns + ------- + mask : instance of numpy.ndarray + The (N_events,)-shaped numpy.ndarray of bool, holding the mask of + the invalid events. + """ + mask = np.zeros((len(events),), dtype=np.bool_) + + for (field_name, min_max) in valid_event_field_ranges_dict.items(): + if field_name not in events: + raise KeyError( + f'The event data field "{field_name}" specified in the ' + 'valid_event_field_ranges_dict does not exist in the event ' + 'data!') + field_values = events[field_name] + mask |= (field_values < min_max[0]) | (field_values > min_max[1]) + + return mask + + def _draw_valid_sig_events_for_dataset_and_shg( + self, + rss, + mc, + n_signal, + ds_idx, + valid_event_field_ranges_dict, + shg, + shg_idx, + ): + """Draws n_signal valid signal events for the given dataset and source + hypothesis group. + + Signal events will be drawn until all events match the event field + ranges, i.e. are valid signal events. + + Parameters + ---------- + rss : instance of RandomStateService + The instance of RandomStateService which should be used to draw + random numbers from. + mc : instance of DataFieldRecordArray + The instance of DataFieldRecordArray holding the monte-carlo events. + n_signal : int + The number of signal events to draw. + ds_idx : int + The index of the dataset. + valid_event_field_ranges_dict : dict + The dictionary holding the data field names (key) and their valid + value ranges (value) for the requested dataset. + shg : instance of SourceHypothesisGroup + The instance of SourceHypothesisGroup for which signal events should + get drawn. + shg_idx : int + The index of the source hypothesis group. + + Returns + ------- + sig_events : instance of DataFieldRecordArray + The instance of DataFieldRecordArray holding the drawn valid signal + events. + """ + sig_events = None + + n = 0 + while n < n_signal: + events_meta = self._sig_candidates_random_choice( + rss=rss, + size=n_signal-n, + ) + m = (events_meta['ds_idx'] == ds_idx) &\ + (events_meta['shg_idx'] == shg_idx) + events = mc[events_meta['ev_idx'][m]] + if len(events) > 0: + events = shg.sig_gen_method.\ + signal_event_post_sampling_processing( + shg, events_meta, events) + + valid_events_mask = np.invert( + self._get_invalid_events_mask( + events, + valid_event_field_ranges_dict, + ) + ) + events = events[valid_events_mask] + if len(events) > 0: + if sig_events is None: + sig_events = events + else: + sig_events.append(events) + n = len(sig_events) + + return sig_events + def change_shg_mgr( self, shg_mgr): @@ -643,11 +812,11 @@ def generate_signal_events( mean = rss.random.poisson( float_cast( mean, - 'The mean argument must be castable to type of float!')) + 'The mean argument must be cast-able to type of float!')) n_signal = int_cast( mean, - 'The mean argument must be castable to type of int!') + 'The mean argument must be cast-able to type of int!') # Draw n_signal signal candidates according to their weight. sig_events_meta = self._sig_candidates_random_choice( @@ -665,6 +834,8 @@ def generate_signal_events( signal_events_dict = dict() ds_idxs = np.unique(sig_events_meta['ds_idx']) for ds_idx in ds_idxs: + valid_event_field_ranges_dict =\ + self.valid_event_field_ranges_dict_list[ds_idx] mc = self._data_list[ds_idx].mc ds_mask = sig_events_meta['ds_idx'] == ds_idx n_sig_events_ds = np.count_nonzero(ds_mask) @@ -701,6 +872,27 @@ def generate_signal_events( signal_event_post_sampling_processing( shg, shg_sig_events_meta, shg_sig_events) + # Determine the signal events, which do not fulfill the valid + # event field ranges for this dataset. + invalid_events_mask = self._get_invalid_events_mask( + shg_sig_events, + valid_event_field_ranges_dict) + n_redraw_events = np.count_nonzero(invalid_events_mask) + if n_redraw_events > 0: + # Re-draw n_redraw_events signal events for this dataset + # and SHG. + redrawn_shg_sig_events =\ + self._draw_valid_sig_events_for_dataset_and_shg( + rss=rss, + mc=mc, + n_signal=n_redraw_events, + ds_idx=ds_idx, + valid_event_field_ranges_dict=valid_event_field_ranges_dict, + shg=shg, + shg_idx=shg_idx, + ) + shg_sig_events[invalid_events_mask] = redrawn_shg_sig_events + indices = np.indices((n_shg_sig_events,))[0] + fill_start_idx sig_events.set_selection(indices, shg_sig_events)