Skip to content

Commit

Permalink
Multiple runs now entered into TF XML
Browse files Browse the repository at this point in the history
Add a method to KernelDataset to extract run info, looping over runs.
Also noticed that some synthetic tests had been commented out; re-enabled them.
Also, tidied some code in process_mth5.

[Issue(s): #181]
  • Loading branch information
kkappler committed Jun 25, 2022
1 parent caae96e commit e362296
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 38 deletions.
45 changes: 14 additions & 31 deletions aurora/pipelines/process_mth5.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,10 +360,10 @@ def process_mth5(
#see notes labelled with ToDo TFK above

#Assign additional columns to dataset_df, populate with mth5_objs
all_mth5_objs = len(dataset_df) * [None]
mth5_obj_column = len(dataset_df) * [None]
for i, station_id in enumerate(dataset_df["station_id"]):
all_mth5_objs[i] = mth5_objs[station_id]
dataset_df["mth5_obj"] = all_mth5_objs
mth5_obj_column[i] = mth5_objs[station_id]
dataset_df["mth5_obj"] = mth5_obj_column
dataset_df["run"] = None
dataset_df["run_dataarray"] = None
dataset_df["stft"] = None
Expand All @@ -388,9 +388,12 @@ def process_mth5(
for i,row in dataset_df.iterrows():
run_xrts = row["run_dataarray"].to_dataset("channel")
run_obj = row["run"]
station_id = row.station_id
stft_obj = make_stft_objects(processing_config, i_dec_level, run_obj,
run_xrts, units, station_id)
stft_obj = make_stft_objects(processing_config,
i_dec_level,
run_obj,
run_xrts,
units,
row.station_id)

if row.station_id == processing_config.stations.local.id:
local_stfts.append(stft_obj)
Expand All @@ -403,7 +406,7 @@ def process_mth5(
# Could mute bad FCs here - Not implemented yet.
# RETURN FC_OBJECT

if processing_config.stations.remote:#reference_station_id:
if processing_config.stations.remote:
remote_merged_stft_obj = xr.concat(remote_stfts, "time")
else:
remote_merged_stft_obj = None
Expand Down Expand Up @@ -435,30 +438,10 @@ def process_mth5(
close_mths_objs(dataset_df)
return tf_collection
else:
# intended to be the default in future

#See ISSUE #181: Uncomment this once we have a mature multi-run test
# #tfk_dataset.get_station_metadata_for_tf_archive()
# #get a list of local runs:
# cond1 = dataset_df["station_id"]==processing_config.stations.local.id
# sub_df = dataset_df[cond1]
# #sanity check:
# run_ids = sub_df.run_id.unique()
# assert(len(run_ids) == len(sub_df))
# # iterate over these runs, packing metadata into
# station_metadata = None
# for i,row in sub_df.iterrows():
# local_run_obj = row.run
# if station_metadata is None:
# station_metadata = local_run_obj.station_group.metadata
# station_metadata._runs = []
# run_metadata = local_run_obj.metadata
# station_metadata.add_run(run_metadata)

station_metadata = local_run_obj.station_group.metadata
station_metadata._runs = []
run_metadata = local_run_obj.metadata
station_metadata.add_run(run_metadata)
# intended to be the default in future (return tf_cls, not tf_collection)

local_station_id = processing_config.stations.local.id
station_metadata = tfk_dataset.get_station_metadata(local_station_id)

# Need to create an issue for this as well
if len(mth5_objs) == 1:
Expand Down
32 changes: 32 additions & 0 deletions aurora/tf_kernel/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,38 @@ def is_remote_reference(self):
# Placeholder, not implemented yet: calling this always raises NotImplementedError.
# NOTE(review): judging by the name alone, this is presumably meant to trim the
# run intervals to the times that are covered simultaneously -- confirm the
# intended behaviour when implementing.
def restrict_run_intervals_to_simultaneous(self):
raise NotImplementedError

def get_station_metadata(self, local_station_id):
    """
    Helper function for archiving the TF.

    Takes the station metadata from the first run of the local station
    found in ``self.df``, resets its run list, and then adds the metadata
    of every run of that station to it.

    Parameters
    ----------
    local_station_id: str
        The name of the local station

    Returns
    -------
    station_metadata
        The ``station_group.metadata`` object of the first matching run,
        with its ``_runs`` list emptied and repopulated (via ``add_run``)
        from the ``run`` column of ``self.df``.  ``None`` when no row of
        ``self.df`` has ``station_id == local_station_id``.

    Raises
    ------
    ValueError
        If a run_id occurs more than once among the local station's rows.
    """
    # select the rows that describe runs of the local station
    is_local = self.df["station_id"] == local_station_id
    sub_df = self.df[is_local]

    # sanity check: every local run must be listed exactly once.
    # Raised explicitly instead of using ``assert`` so that the check is
    # still performed when python runs with -O.
    run_ids = sub_df["run_id"].unique()
    if len(run_ids) != len(sub_df):
        raise ValueError(
            f"station {local_station_id} has {len(sub_df)} rows but only "
            f"{len(run_ids)} unique run_id values; run ids must be unique"
        )

    # iterate over these runs, packing each run's metadata into the
    # station metadata
    station_metadata = None
    for local_run_obj in sub_df["run"]:
        if station_metadata is None:
            # the first run supplies the station metadata; drop any runs
            # it already lists so only the runs in the dataset are kept
            station_metadata = local_run_obj.station_group.metadata
            station_metadata._runs = []
        station_metadata.add_run(local_run_obj.metadata)
    return station_metadata



def restrict_to_station_list(df, station_ids, inplace=True):
Expand Down
14 changes: 7 additions & 7 deletions tests/synthetic/test_synthetic_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,13 @@ def test_process_mth5():
# process_synthetic_1_underdetermined()
# process_synthetic_1_with_nans()

# z_file_path=AURORA_RESULTS_PATH.joinpath("syn1.zss")
# tfc = process_synthetic_1(z_file_path=z_file_path)
# z_file_path=AURORA_RESULTS_PATH.joinpath("syn1_scaled.zss")
# tfc = process_synthetic_1(z_file_path=z_file_path, test_scale_factor=True)
# z_file_path=AURORA_RESULTS_PATH.joinpath("syn1_simultaneous_estimate.zss")
# tfc = process_synthetic_1(z_file_path=z_file_path,
# test_simultaneous_regression=True)
z_file_path=AURORA_RESULTS_PATH.joinpath("syn1.zss")
tfc = process_synthetic_1(z_file_path=z_file_path)
z_file_path=AURORA_RESULTS_PATH.joinpath("syn1_scaled.zss")
tfc = process_synthetic_1(z_file_path=z_file_path, test_scale_factor=True)
z_file_path=AURORA_RESULTS_PATH.joinpath("syn1_simultaneous_estimate.zss")
tfc = process_synthetic_1(z_file_path=z_file_path,
test_simultaneous_regression=True)
tfc = process_synthetic_2()
tfc = process_synthetic_rr12()

Expand Down

0 comments on commit e362296

Please sign in to comment.