Commit 77c2a0c

added logic if length of processing row is not == 1
kujaku11 committed Sep 15, 2023
1 parent 2d13cab commit 77c2a0c
Showing 1 changed file with 70 additions and 25 deletions.
aurora/pipelines/transfer_function_kernel.py (95 changes: 70 additions & 25 deletions)
@@ -72,7 +72,9 @@ def initialize_mth5s(self, mode="r"):
remote station id: mth5.mth5.MTH5
"""

- local_mth5_obj = initialize_mth5(self.config.stations.local.mth5_path, mode=mode)
+ local_mth5_obj = initialize_mth5(
+     self.config.stations.local.mth5_path, mode=mode
+ )
if self.config.stations.remote:
remote_path = self.config.stations.remote[0].mth5_path
remote_mth5_obj = initialize_mth5(remote_path, mode="r")
@@ -85,7 +87,7 @@ def initialize_mth5s(self, mode="r"):
self._mth5_objs = mth5_objs
return

- def update_dataset_df(self,i_dec_level):
+ def update_dataset_df(self, i_dec_level):
"""
This function has two different modes. The first mode initializes values in the
array, and could be placed into TFKDataset.initialize_time_series_data()
@@ -132,7 +134,11 @@ def update_dataset_df(self,i_dec_level):
run_xrds = row["run_dataarray"].to_dataset("channel")
decimation = self.config.decimations[i_dec_level].decimation
decimated_xrds = prototype_decimate(decimation, run_xrds)
self.dataset_df["run_dataarray"].at[i] = decimated_xrds.to_array("channel") # See Note 1 above
self.dataset_df["run_dataarray"].at[
i
] = decimated_xrds.to_array(
"channel"
) # See Note 1 above

print("DATASET DF UPDATED")
return
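Note: the hunk above round-trips each run between a DataArray and a Dataset so the decimator can operate per channel. A minimal sketch of that pattern, with simple slicing standing in for aurora's prototype_decimate (the toy data and sizes are illustrative only):

```python
import numpy as np
import xarray as xr

# Toy run with a "channel" dimension, standing in for row["run_dataarray"]
da = xr.DataArray(
    np.random.randn(4, 1024),
    dims=("channel", "time"),
    coords={"channel": ["ex", "ey", "hx", "hy"]},
)

run_xrds = da.to_dataset("channel")  # one data variable per channel
# Stand-in for prototype_decimate: keep every 4th sample
decimated_xrds = run_xrds.isel(time=slice(None, None, 4))
da_back = decimated_xrds.to_array("channel")  # back to one DataArray
```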
@@ -177,7 +183,11 @@ def check_if_fc_levels_already_exist(self):
Modifies self.dataset_df inplace, assigning bools to the "fc" column
"""
- groupby = ['survey', 'station_id', 'run_id',]
+ groupby = [
+     "survey",
+     "station_id",
+     "run_id",
+ ]
grouper = self.processing_summary.groupby(groupby)

for (survey_id, station_id, run_id), df in grouper:
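Note: grouping on the three-key list means the loop above receives one (survey, station, run) tuple per group. A self-contained sketch with made-up rows:

```python
import pandas as pd

summary = pd.DataFrame(
    {
        "survey": ["s1", "s1", "s1"],
        "station_id": ["mt01", "mt01", "mt02"],
        "run_id": ["001", "002", "001"],
        "dec_level": [0, 0, 0],
    }
)
# Grouping on a key list yields one (survey, station, run) tuple per group
for (survey_id, station_id, run_id), df in summary.groupby(
    ["survey", "station_id", "run_id"]
):
    print(survey_id, station_id, run_id, len(df))
```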
@@ -188,10 +198,12 @@

if len(associated_run_sub_df) > 1:
# See Note #4
print("Warning -- not all runs will processed as a continuous chunk -- in future may need to loop over runlets to check for FCs")
print(

Check warning on line 201 in aurora/pipelines/transfer_function_kernel.py

View check run for this annotation

Codecov / codecov/patch

aurora/pipelines/transfer_function_kernel.py#L201

Added line #L201 was not covered by tests
"Warning -- not all runs will processed as a continuous chunk -- in future may need to loop over runlets to check for FCs"
)

dataset_df_indices = np.r_[associated_run_sub_df.index]
- #dataset_df_indices = associated_run_sub_df.index.to_numpy()
+ # dataset_df_indices = associated_run_sub_df.index.to_numpy()
run_row = associated_run_sub_df.iloc[0]
row_ssr_str = f"survey: {run_row.survey}, station_id: {run_row.station_id}, run_id: {run_row.run_id}"

@@ -204,34 +216,48 @@ def check_if_fc_levels_already_exist(self):
print(msg)
self.dataset_df.loc[dataset_df_indices, "fc"] = False
else:
print("Prebuilt Fourier Coefficients detected -- checking if they satisfy processing requirements...")
print(
"Prebuilt Fourier Coefficients detected -- checking if they satisfy processing requirements..."
)
# Assume FC Groups are keyed by run_id, check if there is a relevant group
try:
- fc_group = station_obj.fourier_coefficients_group.get_fc_group(run_id)
+ fc_group = (
+     station_obj.fourier_coefficients_group.get_fc_group(
+         run_id
+     )
+ )
except MTH5Error:
self.dataset_df.loc[dataset_df_indices, "fc"] = False
print(f"Run ID {run_id} not found in FC Groups, -- will need to build them ")
print(

Check warning on line 231 in aurora/pipelines/transfer_function_kernel.py

View check run for this annotation

Codecov / codecov/patch

aurora/pipelines/transfer_function_kernel.py#L231

Added line #L231 was not covered by tests
f"Run ID {run_id} not found in FC Groups, -- will need to build them "
)
continue

- if len(fc_group.groups_list) < self.processing_config.num_decimation_levels:
+ if (
+     len(fc_group.groups_list)
+     < self.processing_config.num_decimation_levels
+ ):
self.dataset_df.loc[dataset_df_indices, "fc"] = False
print(f"Not enough FC Groups available for {row_ssr_str} -- will need to build them ")
print(

Check warning on line 241 in aurora/pipelines/transfer_function_kernel.py

View check run for this annotation

Codecov / codecov/patch

aurora/pipelines/transfer_function_kernel.py#L241

Added line #L241 was not covered by tests
f"Not enough FC Groups available for {row_ssr_str} -- will need to build them "
)
continue

# Can check time periods here if desired, but unique (survey, station, run) should make this unneeded
# processing_run = self.processing_config.stations.local.get_run(run_id)
# for tp in processing_run.time_periods:
# assert tp in fc_group time periods


# See note #2
- fcs_already_there = fc_group.supports_aurora_processing_config(self.processing_config,
-                                                                run_row.remote)
- self.dataset_df.loc[dataset_df_indices, "fc"] = fcs_already_there
+ fcs_already_there = fc_group.supports_aurora_processing_config(
+     self.processing_config, run_row.remote
+ )
+ self.dataset_df.loc[
+     dataset_df_indices, "fc"
+ ] = fcs_already_there

return
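Note: each branch above ends by writing one boolean into the "fc" column for a whole block of dataset rows at once via .loc; a toy sketch of that assignment pattern (row indices are made up):

```python
import pandas as pd

df = pd.DataFrame({"fc": [None, None, None]})
df.loc[[0, 2], "fc"] = True   # these rows have usable prebuilt FCs
df.loc[[1], "fc"] = False     # this run still needs FCs built
print(df)
```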


def make_processing_summary(self):
"""
Melt the decimation levels over the run summary. Add columns to estimate
@@ -250,11 +276,15 @@
decimation_info = self.config.decimation_info()
for i_dec, dec_factor in decimation_info.items():
tmp[i_dec] = dec_factor
- tmp = tmp.melt(id_vars=id_vars, value_name="dec_factor", var_name="dec_level")
+ tmp = tmp.melt(
+     id_vars=id_vars, value_name="dec_factor", var_name="dec_level"
+ )
sortby = ["survey", "station_id", "run_id", "start", "dec_level"]
tmp.sort_values(by=sortby, inplace=True)
tmp.reset_index(drop=True, inplace=True)
- tmp.drop("sample_rate", axis=1, inplace=True)  # not valid for decimated data
+ tmp.drop(
+     "sample_rate", axis=1, inplace=True
+ )  # not valid for decimated data
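Note: melt turns the per-level decimation columns into (dec_level, dec_factor) rows, one per run and level; a minimal sketch with a single hypothetical run and three levels:

```python
import pandas as pd

tmp = pd.DataFrame(
    {"station_id": ["mt01"], "run_id": ["001"], 0: [1], 1: [4], 2: [4]}
)
melted = tmp.melt(
    id_vars=["station_id", "run_id"],
    value_name="dec_factor",
    var_name="dec_level",
)
print(melted)
#   station_id run_id dec_level  dec_factor
# 0       mt01    001         0           1
# 1       mt01    001         1           4
# 2       mt01    001         2           4
```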

# Add window info
group_by = [
@@ -269,11 +299,15 @@
print(group)
print(df)
try:
- assert (df.dec_level.diff()[1:] == 1).all()  # dec levels increment by 1
+ assert (
+     df.dec_level.diff()[1:] == 1
+ ).all()  # dec levels increment by 1
assert df.dec_factor.iloc[0] == 1
assert df.dec_level.iloc[0] == 0
except AssertionError:
- raise AssertionError("Decimation levels not structured as expected")
+ raise AssertionError(
+     "Decimation levels not structured as expected"
+ )
# df.sample_rate /= np.cumprod(df.dec_factor) # better to take from config
window_params_df = self.config.window_scheme(as_type="df")
df.reset_index(inplace=True, drop=True)
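Note: the try/except above enforces that each group's decimation levels start at level 0 with factor 1 and increment by one; the same checks pass on this well-formed toy frame:

```python
import pandas as pd

df = pd.DataFrame({"dec_level": [0, 1, 2, 3], "dec_factor": [1, 4, 4, 4]})
assert (df.dec_level.diff()[1:] == 1).all()  # levels increment by one
assert df.dec_factor.iloc[0] == 1            # level 0 is undecimated
assert df.dec_level.iloc[0] == 0             # levels start at zero
```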
@@ -324,8 +358,14 @@ def validate_decimation_scheme_and_dataset_compatability(
"""
if min_num_stft_windows is None:
- min_stft_window_info = {x.decimation.level: x.min_num_stft_windows for x in self.processing_config.decimations}
- min_stft_window_list = [min_stft_window_info[x] for x in self.processing_summary.dec_level]
+ min_stft_window_info = {
+     x.decimation.level: x.min_num_stft_windows
+     for x in self.processing_config.decimations
+ }
+ min_stft_window_list = [
+     min_stft_window_info[x]
+     for x in self.processing_summary.dec_level
+ ]
min_num_stft_windows = pd.Series(min_stft_window_list)
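Note: the dict-then-list-comprehension above broadcasts each decimation level's minimum STFT window count onto every row of the processing summary; a sketch with invented numbers:

```python
import pandas as pd

min_stft_window_info = {0: 4, 1: 4, 2: 2}  # dec_level -> min windows
dec_levels = [0, 0, 1, 2]                  # one entry per summary row
min_num_stft_windows = pd.Series(
    [min_stft_window_info[x] for x in dec_levels]
)
print(min_num_stft_windows.tolist())  # [4, 4, 4, 2]
```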

self.processing_summary["valid"] = (
@@ -373,8 +413,12 @@ def valid_decimations(self):
valid_levels = tmp.dec_level.unique()

dec_levels = [x for x in self.config.decimations]
- dec_levels = [x for x in dec_levels if x.decimation.level in valid_levels]
- print(f"After validation there are {len(dec_levels)} valid decimation levels")
+ dec_levels = [
+     x for x in dec_levels if x.decimation.level in valid_levels
+ ]
+ print(
+     f"After validation there are {len(dec_levels)} valid decimation levels"
+ )
return dec_levels

def is_valid_dataset(self, row, i_dec):
Expand Down Expand Up @@ -404,7 +448,8 @@ def is_valid_dataset(self, row, i_dec):

cond = cond1 & cond2 & cond3 & cond4 & cond5
processing_row = self.processing_summary[cond]
- assert len(processing_row) == 1
+ if len(processing_row) != 1:
+     return False

is_valid = processing_row.valid.iloc[0]
return is_valid
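Note: this hunk is the commit's behavioral change: the exact-one-match requirement moves from a hard assert to a graceful return False. A standalone sketch of the before/after contract (function and inputs are illustrative, not aurora's API):

```python
def is_valid(matching_rows):
    # Before this commit: assert len(matching_rows) == 1, which raised
    # an AssertionError on zero or multiple matches.
    if len(matching_rows) != 1:
        return False  # now: ambiguous or missing rows are just invalid
    return bool(matching_rows[0])

print(is_valid([]))       # False (no matching processing row)
print(is_valid([True]))   # True
print(is_valid([1, 2]))   # False (ambiguous match)
```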

