change config and test (mostly completed)
forestbat committed Jan 30, 2024
1 parent ee2001e commit 7099df9
Showing 3 changed files with 93 additions and 47 deletions.
8 changes: 4 additions & 4 deletions scripts/conf/v002.yml
@@ -30,7 +30,7 @@ data_cfgs:

model_cfgs:
  model_name: "SPPLSTM2"
-  weight_dir: "test_data/models/model_v20.pth"
+  weight_dir: "test_data/models/best_model.pth"
  model_hyperparam:
    seq_length: 168
    forecast_length: 24
@@ -50,10 +50,10 @@ training_cfgs:
  continue_train: False

test_cfgs:
-  metrics: ['NSE', 'KGE']
+  metrics: ['NSE', 'KGE', 'Bias', 'RMSE']

-train_period: ["2019-01-01", "2019-01-31"]
-test_period: ["2019-01-01", "2019-01-31"]
+train_period: ["2017-07-01", "2017-09-29"]
+test_period: ["2017-07-01", "2017-09-29"]

var_out: ["streamflow"]
var_t: ["tp"]
113 changes: 79 additions & 34 deletions scripts/postprocess/model_stream.py
@@ -1,65 +1,109 @@
# pytest model_stream.py::test_auto_stream
import json
import os.path
import pathlib
import pathlib as pl
import smtplib
from email.mime.text import MIMEText
-import hydrodataset as hds
from email.mime.multipart import MIMEMultipart

import geopandas as gpd
+import hydrodataset as hds
import intake as itk
import numpy as np
import pandas as pd
import s3fs
import urllib3 as ur
import xarray as xr
import yaml
from scipy import signal
from torchhydro.configs.config import default_config_file, update_cfg, cmd
from torchhydro.datasets.data_dict import data_sources_dict
from torchhydro.trainers.deep_hydro import DeepHydro
from torchhydro.trainers.trainer import set_random_seed
from xarray import Dataset
-from yaml import load, Loader

+from yaml import load, Loader, Dumper

work_dir = pl.Path(os.path.abspath(os.curdir)).parent.parent
with open(os.path.join(work_dir, 'test_data/privacy_config.yml'), 'r') as fp:
    private_str = fp.read()
private_yml = yaml.load(private_str, Loader)
storage_option = {'key': private_yml['minio']['access_key'], 'secret': private_yml['minio']['secret'],
                  'client_kwargs': {'endpoint_url': private_yml['minio']['client_endpoint']}}
mc_fs = s3fs.S3FileSystem(endpoint_url=storage_option['client_kwargs']['endpoint_url'],
                          key=storage_option['key'], secret=storage_option['secret'])


def test_auto_stream():
    test_config_path = os.path.join(work_dir, 'scripts/conf/v002.yml')
    with open(test_config_path, 'r+') as fp:
        test_conf_yml = yaml.load(fp, Loader)
-    # weight_dir in the config file should match where the model is saved
+    # weight_dir in the config file should match where the model is saved; for now the model path is given directly rather than picking the latest one
    # test_model_name = test_read_history(user_model_type='model', version='300')
    eval_log, preds_xr, obss_xr = run_normal_dl(test_config_path)
    with open('eval_log.json', mode='a+', encoding='utf-8') as fp:
        fp.seek(0)
        last_eval_log = json.load(fp)
        # compare_history_report(eval_log, last_eval_log)  # disabled: compare_history_report is commented out below
        '''
        fp.write(test_conf_yml['data_cfgs']['sampler'] + '\n')
        fp.write(test_conf_yml['model_cfgs']['model_name'] + '\n')
        fp.write(test_conf_yml['model_cfgs']['model_hyperparam'] + '\n')
        '''
        eval_log['sampler'] = test_conf_yml['data_cfgs']['sampler']
        eval_log['model_name'] = test_conf_yml['model_cfgs']['model_name']
        eval_log['model_hyperparam'] = test_conf_yml['model_cfgs']['model_hyperparam']
        # json.dump(eval_log, fp)
    preds_xr_sf_np = preds_xr['streamflow'].to_numpy().T
    obss_xr_sf_np = obss_xr['streamflow'].to_numpy().T
    eval_log['Metrics'] = {}
    eval_log['Config'] = {}
    eval_log['Basin'] = obss_xr['basin'].to_numpy().tolist()
    eval_log['Metrics']['NSE'] = eval_log['NSE of streamflow'].tolist()
    eval_log.pop('NSE of streamflow')
    eval_log['Metrics']['MAE'] = eval_log['Bias of streamflow'].tolist()
    eval_log.pop('Bias of streamflow')
    eval_log['Metrics']['KGE'] = eval_log['KGE of streamflow'].tolist()
    eval_log.pop('KGE of streamflow')
    eval_log['Metrics']['RMSE'] = eval_log['RMSE of streamflow'].tolist()
    eval_log.pop('RMSE of streamflow')
    eval_log['Metrics']['Bias of peak height(mm/h)'] = {}
    eval_log['Metrics']['Bias of peak appearance(h)'] = {}
    eval_log['Reports'] = {}
    eval_log['Reports']['Total streamflow(mm/h)'] = {}
    eval_log['Reports']['Peak rainfall(mm)'] = {}
    eval_log['Reports']['Peak streamflow(mm/h)'] = {}
    eval_log['Reports']['Streamflow peak appearance'] = {}
    for i in range(0, preds_xr_sf_np.shape[0]):
        basin = obss_xr['basin'].to_numpy()[i]
        pred_peaks_index = signal.argrelmax(preds_xr_sf_np[i])
        pred_peaks_time = (preds_xr['time_now'].to_numpy())[pred_peaks_index]
        obs_peaks_index = signal.argrelmax(obss_xr_sf_np[i])
        obss_peaks_time = (obss_xr['time_now'].to_numpy())[obs_peaks_index]
        # mean absolute difference of matched peak heights, truncated to the shorter peak list
        eval_log['Metrics']['Bias of peak height(mm/h)'][basin] = np.mean(
            [abs(obs - pred) for obs, pred in
             zip(obss_xr_sf_np[i][obs_peaks_index], preds_xr_sf_np[i][pred_peaks_index])]).tolist()
        # mean absolute peak-timing difference, converted from nanoseconds to hours
        eval_log['Metrics']['Bias of peak appearance(h)'][basin] = np.mean(
            [abs(obs_t - pred_t) for obs_t, pred_t in
             zip(obss_peaks_time, pred_peaks_time)]).tolist() / 3.6e12
        # total flood volume of all predicted values within [0, forecast_length]
        eval_log['Reports']['Total streamflow(mm/h)'][basin] = np.sum(
            preds_xr_sf_np[i][0: test_conf_yml['model_cfgs']['model_hyperparam']['forecast_length']]).tolist()
        # rainfall is a prior input for this model and its role is "subtle"; with no good place to insert it, leave a placeholder for now
        eval_log['Reports']['Peak rainfall(mm)'][basin] = 200
        eval_log['Reports']['Peak streamflow(mm/h)'][basin] = np.max(
            preds_xr_sf_np[i][0: test_conf_yml['model_cfgs']['model_hyperparam']['forecast_length']]).tolist()
        eval_log['Reports']['Streamflow peak appearance'][basin] = np.datetime_as_string(pred_peaks_time,
                                                                                         unit='s').tolist()
    eval_log['Config']['model_name'] = test_conf_yml['model_cfgs']['model_name']
    eval_log['Config']['model_hyperparam'] = test_conf_yml['model_cfgs']['model_hyperparam']
    eval_log['Config']['weight_path'] = test_conf_yml['model_cfgs']['weight_dir']
    eval_log['Config']['t_range_train'] = test_conf_yml['train_period']
    eval_log['Config']['t_range_test'] = test_conf_yml['test_period']
    eval_log['Config']['dataset'] = test_conf_yml['data_cfgs']['dataset']
    eval_log['Config']['sampler'] = test_conf_yml['data_cfgs']['sampler']
    eval_log['Config']['scaler'] = test_conf_yml['data_cfgs']['scaler']
    # https://zhuanlan.zhihu.com/p/631317974
    send_address = private_yml['email']['send_address']
    password = private_yml['email']['authenticate_code']
    server = smtplib.SMTP_SSL('smtp.qq.com', 465)
    login_result = server.login(send_address, password)
    if login_result == (235, b'Authentication successful'):
-        content = str(eval_log)
+        content = yaml.dump(data=eval_log, Dumper=Dumper)
        # https://service.mail.qq.com/detail/124/995
-        msg = MIMEText(content, 'plain', 'utf-8')
        # https://stackoverflow.com/questions/58223773/send-a-list-of-dictionaries-formatted-with-indents-as-a-string-through-email-u
+        msg = MIMEMultipart()
        msg['From'] = 'nickname<' + send_address + '>'
        # join the recipients into a single address-list header string
        msg['To'] = ';'.join('nickname<' + addr + '>' for addr in private_yml['email']['to_address'])
        msg['Subject'] = 'model_report'
        msg.attach(MIMEText(content, 'plain'))
        server.sendmail(send_address, private_yml['email']['to_address'], msg.as_string())
        print('sent successfully')
    else:
@@ -93,10 +137,6 @@ def test_read_history(user_model_type='wasted', version='1'):


def test_read_valid_data(minio_obj_array, need_cache=False):
-    storage_option = {'key': private_yml['minio']['access_key'], 'secret': private_yml['minio']['secret'],
-                      'client_kwargs': {'endpoint_url': private_yml['minio']['client_endpoint']}}
-    mc_fs = s3fs.S3FileSystem(endpoint_url=storage_option['client_kwargs']['endpoint_url'],
-                              key=storage_option['key'], secret=storage_option['secret'])
    # https://intake.readthedocs.io/en/latest/plugin-directory.html
    data_obj_array = []
    for obj in minio_obj_array:
@@ -159,6 +199,7 @@ def read_yaml(version):
    return conf_yaml


+'''
def compare_history_report(new_eval_log, old_eval_log):
    if old_eval_log is None:
        old_eval_log = {'NSE of streamflow': 0, 'KGE of streamflow': 0}
@@ -178,6 +219,7 @@ def compare_history_report(new_eval_log, old_eval_log):
        new_eval_log['review'] = 'The changes were for nothing; try again next time'
    else:
        new_eval_log['review'] = 'Same as last time; still needs improvement'
+'''


@@ -186,19 +228,21 @@ def custom_cfg(
    f = open(cfgs_path, encoding="utf-8")
    cfgs = yaml.load(f.read(), Loader=yaml.FullLoader)
    config_data = default_config_file()
+    '''
    remote_obj_array = ['1_02051500.nc', '86_21401550.nc', 'camelsus_attributes.nc', 'merge_streamflow.nc']
    bucket_name = 'forestbat-private'
    folder_prefix = 'predicate_data'
-    minio_obj_list = ['s3://'+bucket_name+'/'+folder_prefix+'/'+i for i in remote_obj_array]
+    minio_obj_list = ['s3://' + bucket_name + '/' + folder_prefix + '/' + i for i in remote_obj_array]
    test_data_list = test_read_valid_data(minio_obj_list)
+    '''
    args = cmd(
        sub=cfgs["data_cfgs"]["sub"],
        source=cfgs["data_cfgs"]["source"],
        source_region=cfgs["data_cfgs"]["source_region"],
        source_path=hds.ROOT_DIR,
-        streamflow_source_path=test_data_list[3],
-        rainfall_source_path=test_data_list[0:2],
-        attributes_path=test_data_list[2],
+        streamflow_source_path=os.path.join(hds.ROOT_DIR, 'merge_streamflow.nc'),
+        rainfall_source_path=hds.ROOT_DIR,
+        attributes_path=os.path.join(hds.ROOT_DIR, 'camelsus_attributes.nc'),
        gfs_source_path="",
        download=0,
        ctx=cfgs["data_cfgs"]["ctx"],
@@ -214,7 +258,7 @@ def custom_cfg(
"out_channels": 8
},
weight_path=os.path.join(pathlib.Path(os.path.abspath(os.curdir)).parent.parent,
'test_data/models/best_model.pth'),
cfgs['model_cfgs']['weight_dir']),
loss_func=cfgs["training_cfgs"]["loss_func"],
sampler=cfgs["data_cfgs"]["sampler"],
dataset=cfgs["data_cfgs"]["dataset"],
@@ -224,6 +268,7 @@ def custom_cfg(
        var_c=cfgs['data_cfgs']['constant_cols'],
        var_out=["streamflow"],
        # train_period=train_period,
+        # the test_period dict conflicts somewhat with the periods used when concatenating the data
        test_period=[
            {"start": "2017-07-01", "end": "2017-09-29"},
        ],  # this is the precipitation time range; streamflow is shifted back 24h as a whole
@@ -239,10 +284,10 @@ def custom_cfg(
        endpoint_url=private_yml['minio']['server_url'],
        access_key=private_yml['minio']['access_key'],
        secret_key=private_yml['minio']['secret'],
-        bucket_name=bucket_name,
-        folder_prefix=folder_prefix,
+        # bucket_name=bucket_name,
+        # folder_prefix=folder_prefix,
        # stat_dict_file=os.path.join(train_path, "GPM_GFS_Scaler_2_stat.json"),
-        user='yyy'
+        user='zxw'
    )
    update_cfg(config_data, args)
    random_seed = config_data["training_cfgs"]["random_seed"]
@@ -252,7 +297,7 @@ def custom_cfg(
    data_source = data_sources_dict[data_source_name](
        data_cfgs["data_path"], data_cfgs["download"]
    )
-    return data_source, config_data #, minio_obj_list
+    return data_source, config_data  # , minio_obj_list


@@ -261,4 +306,4 @@ def run_normal_dl(cfg_path):
    # preds_xr.to_netcdf(os.path.join("results", "v002_test", "preds.nc"))
    # obss_xr.to_netcdf(os.path.join("results", "v002_test", "obss.nc"))
    # print(eval_log)
-    return eval_log, preds_xr, obss_xr
\ No newline at end of file
+    return eval_log, preds_xr, obss_xr
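
The peak metrics in test_auto_stream hinge on scipy.signal.argrelmax, which returns a tuple of index arrays marking local maxima. A self-contained sketch of that pattern on made-up data (the arrays are illustrative, not from the test):

import numpy as np
from scipy import signal

# Synthetic hourly hydrographs for one basin: observed vs. predicted streamflow.
obs = np.array([0.1, 0.5, 2.0, 1.2, 0.8, 3.1, 1.0, 0.4])
pred = np.array([0.2, 0.4, 1.7, 1.4, 0.9, 2.6, 1.2, 0.5])

# argrelmax returns a 1-tuple of index arrays for 1-D input.
obs_peaks = signal.argrelmax(obs)[0]    # array([2, 5])
pred_peaks = signal.argrelmax(pred)[0]  # array([2, 5])

# Mean absolute difference of matched peak heights, truncated to the shorter list,
# the same pairing used for 'Bias of peak height(mm/h)' above.
peak_height_bias = np.mean([abs(o - p) for o, p in zip(obs[obs_peaks], pred[pred_peaks])])
print(peak_height_bias)  # 0.4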
19 changes: 10 additions & 9 deletions scripts/preprocess/gfs_read.py
@@ -15,7 +15,7 @@ def generate_forecast_times_updated(date_str, hour_str, num):

    # Define the forecasting hours
    forecast_hours = [0, 6, 12, 18]

    # Find the closest forecast hour at or before the given hour
    closest_forecast_hour = max([hour for hour in forecast_hours if hour <= given_hour])

@@ -43,11 +43,11 @@ def generate_forecast_times_updated(date_str, hour_str, num):
# Combining both functions to fetch the latest data points

def fetch_latest_data(
-        date_np = np.datetime64("2017-01-01"),
-        time_str = "00",
-        bbbox = (-125, 25, -66, 50),
-        num = 3
-):
+        date_np=np.datetime64("2017-01-01"),
+        time_str="00",
+        bbbox=(-125, 25, -66, 50),
+        num=3
+):
    forecast_times = generate_forecast_times_updated(date_np, time_str, num)
    gfs_reader = minio.GFSReader()
    time = forecast_times[0]
@@ -79,16 +79,17 @@ def fetch_latest_data(
        data = data.rename({'valid_time': 'time'})
        latest_data = xr.concat([latest_data, data], dim='time')
        # print(latest_data)

    latest_data = latest_data.to_dataset()
    latest_data = latest_data.transpose('time', 'lon', 'lat')
    # print(latest_data)
    return latest_data


# Testing the combined function
# mask = xr.open_dataset('/home/xushuolong1/flood_data_preprocess/GPM_data_preprocess/mask_GFS/05584500.nc')
mask = xr.open_dataset(path_to_your_nc_file)
-box = (mask.coords["lon"][0], mask.coords["lat"][0],mask.coords["lon"][-1], mask.coords["lat"][-1])
-test_data = fetch_latest_data(date_np = "2017-01-01", time_str = "23", bbbox = box, num = 3)
+box = (mask.coords["lon"][0], mask.coords["lat"][0], mask.coords["lon"][-1], mask.coords["lat"][-1])
+test_data = fetch_latest_data(date_np="2017-01-01", time_str="23", bbbox=box, num=3)
# print(test_data)
test_data.to_netcdf('test_data.nc')
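
For reference, the cycle-snapping step in generate_forecast_times_updated reduces to choosing the latest GFS issue hour (00/06/12/18) at or before the requested hour. A small sketch reconstructed only from the closest_forecast_hour line visible above (the helper name is hypothetical):

def closest_gfs_cycle(given_hour: int) -> int:
    # GFS issues forecasts at these UTC hours.
    forecast_hours = [0, 6, 12, 18]
    # Latest cycle at or before the given hour, as in the diff above.
    return max(hour for hour in forecast_hours if hour <= given_hour)

print(closest_gfs_cycle(23))  # 18 -- matches the time_str="23" test call above
print(closest_gfs_cycle(5))   # 0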
