Skip to content

Commit

Permalink
change config and test(primarily completed)
Browse files Browse the repository at this point in the history
  • Loading branch information
forestbat committed Jan 12, 2024
1 parent 98f75a9 commit ee2001e
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 23 deletions.
22 changes: 18 additions & 4 deletions scripts/conf/v002.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,26 @@ data_cfgs:
sampler: "WuSampler"
scaler: "GPM_GFS_Scaler"
rolling: False
"constant_cols": [
"area",
"ele_mt_smn",
"slp_dg_sav",
"sgr_dk_sav",
"for_pc_sse",
"glc_cl_smj",
"run_mm_syr",
"inu_pc_slt",
"cmi_ix_syr",
"aet_mm_syr",
"snw_pc_syr",
"swc_pc_syr",
"gwt_cm_sav",
"cly_pc_sav",
"dor_pc_pva"
]

model_cfgs:
model_name: "SPPLSTM"
model_name: "SPPLSTM2"
weight_dir: "test_data/models/model_v20.pth"
model_hyperparam:
seq_length: 168
Expand All @@ -38,8 +55,5 @@ test_cfgs:
train_period: ["2019-01-01", "2019-01-31"]
test_period: ["2019-01-01", "2019-01-31"]

gage_id:
- '21401550'

var_out: ["streamflow"]
var_t: ["tp"]
4 changes: 4 additions & 0 deletions scripts/postprocess/eval_log.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"NSE of streamflow": [0, 0],
"KGE of streamflow": [0, 0]
}
60 changes: 42 additions & 18 deletions scripts/postprocess/model_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,24 @@

def test_auto_stream():
test_config_path = os.path.join(work_dir, 'scripts/conf/v002.yml')
with open(test_config_path, 'r+') as fp:
test_conf_yml = yaml.load(fp, Loader)
# 配置文件中的weight_dir应与模型保存位置相对应
test_model_name = test_read_history(user_model_type='model', version='300')
# test_model_name = test_read_history(user_model_type='model', version='300')
eval_log, preds_xr, obss_xr = run_normal_dl(test_config_path)
with open('eval_log.json', mode='a+') as fp:
with open('eval_log.json', mode='a+', encoding='utf-8') as fp:
fp.seek(0)
last_eval_log = json.load(fp)
compare_history_report(eval_log, last_eval_log)
json.dump(eval_log, fp)
compare_history_report(eval_log, last_eval_log)
'''
fp.write(test_conf_yml['data_cfgs']['sampler'] + '\n')
fp.write(test_conf_yml['model_cfgs']['model_name'] + '\n')
fp.write(test_conf_yml['model_cfgs']['model_hyperparam'] + '\n')
'''
eval_log['sampler'] = test_conf_yml['data_cfgs']['sampler']
eval_log['model_name'] = test_conf_yml['model_cfgs']['model_name']
eval_log['model_hyperparam'] = test_conf_yml['model_cfgs']['model_hyperparam']
# json.dump(eval_log, fp)
# https://zhuanlan.zhihu.com/p/631317974
send_address = private_yml['email']['send_address']
password = private_yml['email']['authenticate_code']
Expand Down Expand Up @@ -153,17 +164,17 @@ def compare_history_report(new_eval_log, old_eval_log):
old_eval_log = {'NSE of streamflow': 0, 'KGE of streamflow': 0}
# https://doi.org/10.1016/j.envsoft.2019.05.001
# 需要再算一下洪量
if (new_eval_log['NSE of streamflow'] > old_eval_log['NSE of streamflow']) & (
new_eval_log['KGE of streamflow'] > old_eval_log['KGE of streamflow']):
if (list(new_eval_log['NSE of streamflow']) > old_eval_log['NSE of streamflow']) & (
list(new_eval_log['KGE of streamflow']) > old_eval_log['KGE of streamflow']):
new_eval_log['review'] = '比上次更好些,再接再厉'
elif (new_eval_log['NSE of streamflow'] > old_eval_log['NSE of streamflow']) & (
new_eval_log['KGE of streamflow'] < old_eval_log['KGE of streamflow']):
elif (list(new_eval_log['NSE of streamflow']) > old_eval_log['NSE of streamflow']) & (
list(new_eval_log['KGE of streamflow']) < old_eval_log['KGE of streamflow']):
new_eval_log['review'] = '拟合比以前更好,但KGE下降,对洪峰预报可能有问题'
elif (new_eval_log['NSE of streamflow'] < old_eval_log['NSE of streamflow']) & (
new_eval_log['KGE of streamflow'] > old_eval_log['KGE of streamflow']):
elif (list(new_eval_log['NSE of streamflow']) < old_eval_log['NSE of streamflow']) & (
list(new_eval_log['KGE of streamflow']) > old_eval_log['KGE of streamflow']):
new_eval_log['review'] = '拟合结果更差了,问题在哪里?KGE更好一些,也许并没有那么差'
elif (new_eval_log['NSE of streamflow'] < old_eval_log['NSE of streamflow']) & (
new_eval_log['KGE of streamflow'] < old_eval_log['KGE of streamflow']):
elif (list(new_eval_log['NSE of streamflow']) < old_eval_log['NSE of streamflow']) & (
list(new_eval_log['KGE of streamflow']) < old_eval_log['KGE of streamflow']):
new_eval_log['review'] = '白改了,下次再说吧'
else:
new_eval_log['review'] = '和上次相等,还需要再提高'
Expand All @@ -188,19 +199,30 @@ def custom_cfg(
streamflow_source_path=test_data_list[3],
rainfall_source_path=test_data_list[0:2],
attributes_path=test_data_list[2],
gfs_source_path="",
download=0,
ctx=cfgs["data_cfgs"]["ctx"],
model_name=cfgs["model_cfgs"]["model_name"],
model_hyperparam=cfgs["model_cfgs"]["model_hyperparam"],
model_hyperparam={
"seq_length": 168,
"forecast_length": 24,
"n_output": 1,
"n_hidden_states": 60,
"dropout": 0.25,
"len_c": 15,
"in_channels": 1,
"out_channels": 8
},
weight_path=os.path.join(pathlib.Path(os.path.abspath(os.curdir)).parent.parent,
'test_data/models/model_v20.pth'),
'test_data/models/best_model.pth'),
loss_func=cfgs["training_cfgs"]["loss_func"],
sampler=cfgs["data_cfgs"]["sampler"],
dataset=cfgs["data_cfgs"]["dataset"],
scaler=cfgs["data_cfgs"]["scaler"],
batch_size=cfgs["training_cfgs"]["batch_size"],
var_t=cfgs["var_t"],
var_out=cfgs["var_out"],
var_t=[["tp"]],
var_c=cfgs['data_cfgs']['constant_cols'],
var_out=["streamflow"],
# train_period=train_period,
test_period=[
{"start": "2017-07-01", "end": "2017-09-29"},
Expand All @@ -209,7 +231,7 @@ def custom_cfg(
train_epoch=cfgs["training_cfgs"]["train_epoch"],
save_epoch=cfgs["training_cfgs"]["save_epoch"],
te=cfgs["training_cfgs"]["te"],
gage_id=["86_21401550"],
gage_id=["1_02051500", "86_21401550"],
which_first_tensor=cfgs["training_cfgs"]["which_first_tensor"],
continue_train=cfgs["training_cfgs"]["continue_train"],
rolling=cfgs['data_cfgs']['rolling'],
Expand All @@ -219,6 +241,8 @@ def custom_cfg(
secret_key=private_yml['minio']['secret'],
bucket_name=bucket_name,
folder_prefix=folder_prefix,
# stat_dict_file=os.path.join(train_path, "GPM_GFS_Scaler_2_stat.json"),
user='yyy'
)
update_cfg(config_data, args)
random_seed = config_data["training_cfgs"]["random_seed"]
Expand All @@ -228,7 +252,7 @@ def custom_cfg(
data_source = data_sources_dict[data_source_name](
data_cfgs["data_path"], data_cfgs["download"]
)
return data_source, config_data, minio_obj_list
return data_source, config_data #, minio_obj_list


def run_normal_dl(cfg_path):
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"streamflow": [0.166602310282088, 0.18204515446545455, 0.1759662282540064, 0.006084235666331094], "tp": [-1.0, -1.0, -0.9896670288640039, 0.07317114182316072]}
{"streamflow": [-0.7678212555626731, -0.5347280763404353, -0.6475110721431854, 0.10803312304617622], "tp": [-1.0, -0.5265558045505923, -0.8989121995323321, 0.28418674274170364], "area": [1519.0592455152516, 2031.0667666779414, 1775.0630060965964, 320.0047007266811], "ele_mt_smn": [81.6679512929293, 105.59209248628093, 93.63002188960512, 14.952588245844765], "slp_dg_sav": [26.459792408086415, 101.46837049742417, 63.964081452755295, 46.880361305836104], "sgr_dk_sav": [35.52689407713153, 69.89437715929068, 52.710635618211114, 21.479676926349473], "for_pc_sse": [53.0068040009987, 88.54204820640355, 70.77442610370113, 22.209527628378034], "glc_cl_smj": [2.4, 5.6, 4.0, 2.0], "run_mm_syr": [236.21287726392976, 393.08950353227567, 314.6511903981027, 98.04789141771619], "inu_pc_slt": [5.326787696585166, 20.690931772026122, 13.008859734305643, 9.602590047150597], "cmi_ix_syr": [-10.698390398541736, -9.585836128343313, -10.142113263442525, 0.6953464188740135], "aet_mm_syr": [653.4992435816932, 875.5488701554125, 764.5240568685529, 138.78101660857453], "snw_pc_syr": [0.5761347714110969, 2.030249462735507, 1.303192117073302, 0.9088216820777564], "swc_pc_syr": [77.83598587846146, 80.16464770789122, 79.00031679317634, 1.455413643393598], "gwt_cm_sav": [140.19146029443505, 379.8152494666633, 260.0033548805492, 149.76486823264267], "cly_pc_sav": [20.006731173500388, 20.58023720927662, 20.293484191388504, 0.358441272360146], "dor_pc_pva": [190.10000000000002, 1710.9, 950.5, 950.5]}

0 comments on commit ee2001e

Please sign in to comment.