From 77c2a0c06b8d4fc30dd32e8aa4eee94adf4a8728 Mon Sep 17 00:00:00 2001 From: JP Date: Fri, 15 Sep 2023 14:54:07 -0700 Subject: [PATCH] added logic if length of processing row is not == 1 --- aurora/pipelines/transfer_function_kernel.py | 95 ++++++++++++++------ 1 file changed, 70 insertions(+), 25 deletions(-) diff --git a/aurora/pipelines/transfer_function_kernel.py b/aurora/pipelines/transfer_function_kernel.py index 4852ce0c..f0f5951e 100644 --- a/aurora/pipelines/transfer_function_kernel.py +++ b/aurora/pipelines/transfer_function_kernel.py @@ -72,7 +72,9 @@ def initialize_mth5s(self, mode="r"): remote station id: mth5.mth5.MTH5 """ - local_mth5_obj = initialize_mth5(self.config.stations.local.mth5_path, mode=mode) + local_mth5_obj = initialize_mth5( + self.config.stations.local.mth5_path, mode=mode + ) if self.config.stations.remote: remote_path = self.config.stations.remote[0].mth5_path remote_mth5_obj = initialize_mth5(remote_path, mode="r") @@ -85,7 +87,7 @@ def initialize_mth5s(self, mode="r"): self._mth5_objs = mth5_objs return - def update_dataset_df(self,i_dec_level): + def update_dataset_df(self, i_dec_level): """ This function has two different modes. The first mode initializes values in the array, and could be placed into TFKDataset.initialize_time_series_data() @@ -132,7 +134,11 @@ def update_dataset_df(self,i_dec_level): run_xrds = row["run_dataarray"].to_dataset("channel") decimation = self.config.decimations[i_dec_level].decimation decimated_xrds = prototype_decimate(decimation, run_xrds) - self.dataset_df["run_dataarray"].at[i] = decimated_xrds.to_array("channel") # See Note 1 above + self.dataset_df["run_dataarray"].at[ + i + ] = decimated_xrds.to_array( + "channel" + ) # See Note 1 above print("DATASET DF UPDATED") return @@ -177,7 +183,11 @@ def check_if_fc_levels_already_exist(self): Modifies self.dataset_df inplace, assigning bools to the "fc" column """ - groupby = ['survey', 'station_id', 'run_id',] + groupby = [ + "survey", + "station_id", + "run_id", + ] grouper = self.processing_summary.groupby(groupby) for (survey_id, station_id, run_id), df in grouper: @@ -188,10 +198,12 @@ def check_if_fc_levels_already_exist(self): if len(associated_run_sub_df) > 1: # See Note #4 - print("Warning -- not all runs will processed as a continuous chunk -- in future may need to loop over runlets to check for FCs") + print( + "Warning -- not all runs will processed as a continuous chunk -- in future may need to loop over runlets to check for FCs" + ) dataset_df_indices = np.r_[associated_run_sub_df.index] - #dataset_df_indices = associated_run_sub_df.index.to_numpy() + # dataset_df_indices = associated_run_sub_df.index.to_numpy() run_row = associated_run_sub_df.iloc[0] row_ssr_str = f"survey: {run_row.survey}, station_id: {run_row.station_id}, run_id: {run_row.run_id}" @@ -204,18 +216,31 @@ def check_if_fc_levels_already_exist(self): print(msg) self.dataset_df.loc[dataset_df_indices, "fc"] = False else: - print("Prebuilt Fourier Coefficients detected -- checking if they satisfy processing requirements...") + print( + "Prebuilt Fourier Coefficients detected -- checking if they satisfy processing requirements..." 
+ ) # Assume FC Groups are keyed by run_id, check if there is a relevant group try: - fc_group = station_obj.fourier_coefficients_group.get_fc_group(run_id) + fc_group = ( + station_obj.fourier_coefficients_group.get_fc_group( + run_id + ) + ) except MTH5Error: self.dataset_df.loc[dataset_df_indices, "fc"] = False - print(f"Run ID {run_id} not found in FC Groups, -- will need to build them ") + print( + f"Run ID {run_id} not found in FC Groups, -- will need to build them " + ) continue - if len(fc_group.groups_list) < self.processing_config.num_decimation_levels: + if ( + len(fc_group.groups_list) + < self.processing_config.num_decimation_levels + ): self.dataset_df.loc[dataset_df_indices, "fc"] = False - print(f"Not enough FC Groups available for {row_ssr_str} -- will need to build them ") + print( + f"Not enough FC Groups available for {row_ssr_str} -- will need to build them " + ) continue # Can check time periods here if desired, but unique (survey, station, run) should make this unneeded @@ -223,15 +248,16 @@ def check_if_fc_levels_already_exist(self): # for tp in processing_run.time_periods: # assert tp in fc_group time periods - # See note #2 - fcs_already_there = fc_group.supports_aurora_processing_config(self.processing_config, - run_row.remote) - self.dataset_df.loc[dataset_df_indices, "fc"] = fcs_already_there + fcs_already_there = fc_group.supports_aurora_processing_config( + self.processing_config, run_row.remote + ) + self.dataset_df.loc[ + dataset_df_indices, "fc" + ] = fcs_already_there return - def make_processing_summary(self): """ Melt the decimation levels over the run summary. Add columns to estimate @@ -250,11 +276,15 @@ def make_processing_summary(self): decimation_info = self.config.decimation_info() for i_dec, dec_factor in decimation_info.items(): tmp[i_dec] = dec_factor - tmp = tmp.melt(id_vars=id_vars, value_name="dec_factor", var_name="dec_level") + tmp = tmp.melt( + id_vars=id_vars, value_name="dec_factor", var_name="dec_level" + ) sortby = ["survey", "station_id", "run_id", "start", "dec_level"] tmp.sort_values(by=sortby, inplace=True) tmp.reset_index(drop=True, inplace=True) - tmp.drop("sample_rate", axis=1, inplace=True) # not valid for decimated data + tmp.drop( + "sample_rate", axis=1, inplace=True + ) # not valid for decimated data # Add window info group_by = [ @@ -269,11 +299,15 @@ def make_processing_summary(self): print(group) print(df) try: - assert (df.dec_level.diff()[1:] == 1).all() # dec levels increment by 1 + assert ( + df.dec_level.diff()[1:] == 1 + ).all() # dec levels increment by 1 assert df.dec_factor.iloc[0] == 1 assert df.dec_level.iloc[0] == 0 except AssertionError: - raise AssertionError("Decimation levels not structured as expected") + raise AssertionError( + "Decimation levels not structured as expected" + ) # df.sample_rate /= np.cumprod(df.dec_factor) # better to take from config window_params_df = self.config.window_scheme(as_type="df") df.reset_index(inplace=True, drop=True) @@ -324,8 +358,14 @@ def validate_decimation_scheme_and_dataset_compatability( """ if min_num_stft_windows is None: - min_stft_window_info = {x.decimation.level: x.min_num_stft_windows for x in self.processing_config.decimations} - min_stft_window_list = [min_stft_window_info[x] for x in self.processing_summary.dec_level] + min_stft_window_info = { + x.decimation.level: x.min_num_stft_windows + for x in self.processing_config.decimations + } + min_stft_window_list = [ + min_stft_window_info[x] + for x in self.processing_summary.dec_level + ] 
min_num_stft_windows = pd.Series(min_stft_window_list) self.processing_summary["valid"] = ( @@ -373,8 +413,12 @@ def valid_decimations(self): valid_levels = tmp.dec_level.unique() dec_levels = [x for x in self.config.decimations] - dec_levels = [x for x in dec_levels if x.decimation.level in valid_levels] - print(f"After validation there are {len(dec_levels)} valid decimation levels") + dec_levels = [ + x for x in dec_levels if x.decimation.level in valid_levels + ] + print( + f"After validation there are {len(dec_levels)} valid decimation levels" + ) return dec_levels def is_valid_dataset(self, row, i_dec): @@ -404,7 +448,8 @@ def is_valid_dataset(self, row, i_dec): cond = cond1 & cond2 & cond3 & cond4 & cond5 processing_row = self.processing_summary[cond] - assert len(processing_row) == 1 + if len(processing_row) != 1: + return False is_valid = processing_row.valid.iloc[0] return is_valid
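
Note on the behavioral change: aside from black-style reformatting, the one
substantive edit in this patch is in is_valid_dataset(), where the hard
assert len(processing_row) == 1 is replaced by an early return False. A dataset
row whose (survey, station, run, decimation level) lookup does not yield exactly
one processing-summary row is now reported as invalid instead of aborting the
pipeline with an AssertionError. A minimal sketch of the resulting check follows;
the simplified signature and column names are illustrative stand-ins rather than
the exact aurora API.

    import pandas as pd

    def is_valid_dataset(processing_summary: pd.DataFrame, row, i_dec: int) -> bool:
        """Sketch: valid only when exactly one matching summary row exists and is flagged valid."""
        cond = (
            (processing_summary.survey == row.survey)
            & (processing_summary.station_id == row.station_id)
            & (processing_summary.run_id == row.run_id)
            & (processing_summary.dec_level == i_dec)
        )
        processing_row = processing_summary[cond]
        # Previously: assert len(processing_row) == 1  (AssertionError on 0 or >1 matches).
        # Now: anything other than a unique match is treated as an invalid dataset row.
        if len(processing_row) != 1:
            return False
        return bool(processing_row.valid.iloc[0])

With this guard the caller can simply skip an unmatched or ambiguous row, which
appears to be the intent behind the commit subject.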
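
Note on the reformatted FC pre-check: the remaining hunks only reformat
check_if_fc_levels_already_exist(), which decides, per (survey, station, run),
whether prebuilt Fourier coefficients already satisfy the processing config and
records the answer in the "fc" column of dataset_df. A rough sketch of that
per-run decision is below; the mth5 attribute and method names are copied from
the patch itself, the helper name is hypothetical, and the MTH5Error import path
is assumed rather than verified.

    from mth5.utils.exceptions import MTH5Error  # import path assumed, not verified

    def fc_status_for_run(station_obj, run_id, processing_config, remote):
        """Sketch of the per-run decision made in check_if_fc_levels_already_exist()."""
        try:
            fc_group = station_obj.fourier_coefficients_group.get_fc_group(run_id)
        except MTH5Error:
            return False  # no prebuilt FC group for this run -- it will need to be built
        if len(fc_group.groups_list) < processing_config.num_decimation_levels:
            return False  # too few stored decimation levels to cover the processing config
        return fc_group.supports_aurora_processing_config(processing_config, remote)

A True result is written into the "fc" column so that later stages can presumably
reuse the stored Fourier coefficients instead of recomputing them.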