-
Notifications
You must be signed in to change notification settings - Fork 51
/
Copy pathforecast_to_xarr.py
378 lines (323 loc) · 11.1 KB
/
forecast_to_xarr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
# Standard library
import os
# Third-party
import numpy as np
import torch
import xarray as xa
from tqdm import tqdm
# First-party
from neural_lam import constants
from neural_lam.models.graph_efm import GraphEFM
# Directory (relative to the working directory) in which forecast zarr
# stores are written
FC_DIR_PATH = "saved_forecasts"
def get_var_dims(save_ensemble):
    """
    Build the dimension-name tuples used for forecast variables

    save_ensemble: if truthy, a leading "realization" dimension is added

    Returns: (atm_dims, sur_dims), tuples of dimension names for atmospheric
        and surface variables. Atmospheric variables have the same dimensions
        as surface variables, followed by "level".
    """
    horizontal_dims = (
        "time",
        "prediction_timedelta",
        "longitude",
        "latitude",
    )
    with_level_dims = horizontal_dims + ("level",)

    if not save_ensemble:
        return with_level_dims, horizontal_dims

    # Ensemble member dimension goes first
    member_dim = ("realization",)
    return member_dim + with_level_dims, member_dim + horizontal_dims
def forecast_to_xds(
    forecast_tensor,
    batch_init_times,
    coords,
    var_filter_list,
    level_filter_list,
    time_enc_unit,
):
    """
    Convert a batch of forecasts from a pytorch tensor to an xarray.Dataset

    forecast_tensor: (B, (S), pred_steps, num_grid_nodes, d_f), where the
        optional S-dimension holds ensemble members
    batch_init_times: values used as the "time" coordinate of the batch
    coords: iterable of (name, coordinate) pairs; every entry except "time"
        is copied into the dataset through its `.values`
    var_filter_list: variable names to keep, or None to keep all
    level_filter_list: levels to keep, or None to keep all
    time_enc_unit: value stored as "units" encoding of the time coordinate

    Returns: the filtered xarray.Dataset, with a wind speed variable added
        for each pair of wind components still present after filtering
    """
    num_sur_vars = len(constants.SURFACE_PARAMS)
    num_atm_vars = len(constants.ATMOSPHERIC_PARAMS)

    # A 5-dimensional tensor means that the ensemble (S) dimension exists
    is_ensemble = forecast_tensor.ndim == 5

    # Move to CPU and numpy, (B, (S), pred_steps, num_grid_nodes, d_f)
    fc_np = forecast_tensor.cpu().numpy()

    # Unflatten the grid node dimension; only the last two dimensions are
    # touched, so this works with or without the S-dimension
    leading_shape = fc_np.shape[:-2]
    num_features = fc_np.shape[-1]
    fc_grid = fc_np.reshape(
        *leading_shape, *constants.GRID_SHAPE, num_features
    )  # (B, (S), pred_steps, num_lon, num_lat, d_f)

    if is_ensemble:
        # Swap first dimensions, so they are (realization, time, ...)
        fc_grid = np.moveaxis(fc_grid, 1, 0)

    # Surface variables are the trailing features, atmospheric ones the rest
    sur_block = fc_grid[..., -num_sur_vars:]  # (..., num_sur_vars)
    atm_block = fc_grid[..., :-num_sur_vars]
    # (..., num_atm_vars * num_levels)

    # One array per surface variable, ((S), B, pred_steps, num_lon, num_lat)
    sur_fields = [
        np.squeeze(part, axis=-1)
        for part in np.split(sur_block, num_sur_vars, axis=-1)
    ]
    # One array per atmospheric variable,
    # ((S), B, pred_steps, num_lon, num_lat, num_levels)
    atm_fields = np.split(atm_block, num_atm_vars, axis=-1)

    # Collect all variables, atmospheric first and then surface
    atm_dims, sur_dims = get_var_dims(is_ensemble)
    data_vars = {
        var_name: (atm_dims, var_vals)
        for var_name, var_vals in zip(constants.ATMOSPHERIC_PARAMS, atm_fields)
    }
    data_vars.update(
        (var_name, (sur_dims, var_vals))
        for var_name, var_vals in zip(constants.SURFACE_PARAMS, sur_fields)
    )

    # Time coordinate comes from this batch, all others from given coords
    ds_coords = {"time": batch_init_times}
    for coord_name, coord in coords:
        if coord_name != "time":
            ds_coords[coord_name] = coord.values

    batch_xds = xa.Dataset(data_vars, coords=ds_coords)
    batch_xds.time.encoding["units"] = time_enc_unit

    # Keep only requested variables and levels
    result_xds = filter_xdataset(batch_xds, var_filter_list, level_filter_list)

    # Add wind speed wherever both wind components are still present
    wind_components = (
        ("u_component_of_wind", "v_component_of_wind", "wind_speed"),
        (
            "10m_u_component_of_wind",
            "10m_v_component_of_wind",
            "10m_wind_speed",
        ),
    )
    for u_name, v_name, speed_name in wind_components:
        if u_name in result_xds and v_name in result_xds:
            result_xds[speed_name] = np.sqrt(
                result_xds[u_name] ** 2 + result_xds[v_name] ** 2
            )

    return result_xds
def parse_filters(var_filter_str, level_filter_str):
    """
    Turn comma-separated filter strings into validated lists

    var_filter_str: comma-separated short variable names, or None
    level_filter_str: comma-separated integer pressure levels, or None

    Returns: (var_list, level_list), where var_list holds the full variable
        names matching the given short names and level_list the levels as
        integers. An entry is None if the corresponding filter string is None.

    An assertion fails for any short name not listed in
    constants.ATMOSPHERIC_PARAMS_SHORT or constants.SURFACE_PARAMS_SHORT, and
    for any level not in constants.PRESSURE_LEVELS.
    """
    # Variable filter
    var_list = None
    if var_filter_str is not None:
        short_names = [part.strip() for part in var_filter_str.split(",")]

        # All requested variables must be among the forecasted ones
        for short_name in short_names:
            assert (
                short_name in constants.ATMOSPHERIC_PARAMS_SHORT
                or short_name in constants.SURFACE_PARAMS_SHORT
            ), f"Can not save unknown variable: {short_name}"

        # Map short names to full names, atmospheric entries added last
        short_to_full = dict(
            zip(constants.SURFACE_PARAMS_SHORT, constants.SURFACE_PARAMS)
        )
        short_to_full.update(
            zip(
                constants.ATMOSPHERIC_PARAMS_SHORT,
                constants.ATMOSPHERIC_PARAMS,
            )
        )
        var_list = [short_to_full[short_name] for short_name in short_names]

    # Level filter
    level_list = None
    if level_filter_str is not None:
        level_list = [int(part.strip()) for part in level_filter_str.split(",")]
        for level in level_list:
            assert (
                level in constants.PRESSURE_LEVELS
            ), f"Can not save unknown pressure level: {level}"

    return var_list, level_list
def filter_xdataset(xds, var_filter_list, level_filter_list):
    """
    Select a subset of variables and levels from an xarray.Dataset

    xds: dataset to filter
    var_filter_list: variable names to keep, or None to keep all variables
    level_filter_list: values of "level" to keep, or None to keep all levels

    Returns: the filtered dataset (xds itself if both filters are None)
    """
    selected = xds
    if var_filter_list is not None:
        selected = selected[var_filter_list]

    if level_filter_list is None:
        return selected

    # Original author's note: nearest method is needed to keep surface
    # variables
    return selected.sel(level=level_filter_list, method="nearest")
@torch.no_grad()
def forecast_to_xarr(
    model,
    dataloader,
    name,
    device_name,
    var_filter=None,
    level_filter=None,
    ens_size=5,
):
    """
    Produce forecasts for each sample in the dataloader, using model, and
    save them batch by batch to a zarr store under FC_DIR_PATH

    model: model to produce forecasts with. If it is a GraphEFM, ensemble
        forecasts are sampled and saved with a "realization" dimension
    dataloader: non-shuffling dataloader for evaluation set. Its dataset
        must provide data_mean, data_std, atm_xda and pred_length
    name: name to save zarr as (without .zarr)
    device_name: name of device to use for forecasting
    var_filter: string, comma-separated list of variables to save,
        or None to save all
    level_filter: string, comma-separated list of pressure levels to save,
        or None to save all
    ens_size: number of ensemble members to sample, only used for
        ensemble forecasting

    An existing zarr store with the same name is overwritten.
    """
    # Parse var_filter and level_filter
    var_filter_list, level_filter_list = parse_filters(var_filter, level_filter)

    # Set up device, need to handle manually here
    device = torch.device(device_name)
    model = model.to(device)

    # Set up save path
    os.makedirs(FC_DIR_PATH, exist_ok=True)
    fc_path = os.path.join(FC_DIR_PATH, f"{name}.zarr")

    # Get coordinates from array used in dataset
    dataset = dataloader.dataset
    data_mean = dataset.data_mean.to(device)
    data_std = dataset.data_std.to(device)
    ds_xda = dataset.atm_xda

    # Set up xarray with zarr backend
    # Lead times are 6 h, 12 h, ..., one entry per prediction step
    pred_hours = 6 * (np.arange(dataset.pred_length) + 1)
    pred_timedeltas = [
        np.timedelta64(dh, "h").astype("timedelta64[ns]") for dh in pred_hours
    ]

    # Figure out if we should do ensemble forecasting
    save_ensemble = isinstance(model, GraphEFM)

    # Set up dimensions for dataset
    # Time dimension starts out empty (length 0), batches are appended later
    atm_dims, sur_dims = get_var_dims(save_ensemble)
    atm_empty_shape = (
        0,
        dataset.pred_length,
        *constants.GRID_SHAPE,
        len(constants.PRESSURE_LEVELS),
    )
    sur_empty_shape = (
        0,
        dataset.pred_length,
        *constants.GRID_SHAPE,
    )
    ds_coords = {
        "time": np.array([], dtype="datetime64[ns]"),
        "prediction_timedelta": pred_timedeltas,
        "longitude": ds_xda.coords["longitude"].values,
        "latitude": ds_xda.coords["latitude"].values,
        "level": ds_xda.coords["level"].values,
    }
    if save_ensemble:
        # Add on realization (ens. member) dim.
        atm_empty_shape = (ens_size,) + atm_empty_shape
        sur_empty_shape = (ens_size,) + sur_empty_shape
        ds_coords["realization"] = np.arange(ens_size)

    forecast_xds = xa.Dataset(
        {
            var_name: (
                atm_dims,
                np.zeros(atm_empty_shape),
            )
            for var_name in constants.ATMOSPHERIC_PARAMS
        }
        | {  # Dict union
            var_name: (
                sur_dims,
                np.zeros(sur_empty_shape),
            )
            for var_name in constants.SURFACE_PARAMS
        },
        coords=ds_coords,
    )
    # Need to set this encoding to save/load correct times from disk
    time_enc_unit = "nanoseconds since 1970-01-01"
    forecast_xds.time.encoding["units"] = time_enc_unit

    # Filter to selected
    filtered_xds = filter_xdataset(
        forecast_xds, var_filter_list, level_filter_list
    )

    # Set up wind variables
    if (
        "u_component_of_wind" in filtered_xds
        and "v_component_of_wind" in filtered_xds
    ):
        # Use same empty shape
        filtered_xds["wind_speed"] = filtered_xds["u_component_of_wind"]
    if (
        "10m_u_component_of_wind" in filtered_xds
        and "10m_v_component_of_wind" in filtered_xds
    ):
        # Use same empty shape
        filtered_xds["10m_wind_speed"] = filtered_xds["10m_u_component_of_wind"]

    # Set up chunking, one chunk per init time
    atm_chunking = (1, -1, -1, -1, -1)
    sur_chunking = (1, -1, -1, -1)
    if save_ensemble:
        # All members in same chunk
        atm_chunking = (-1,) + atm_chunking
        sur_chunking = (-1,) + sur_chunking

    # Variables with a level dimension; built once, and as a set so that it
    # works no matter if ATMOSPHERIC_PARAMS is a list or a tuple
    atm_chunked_vars = set(constants.ATMOSPHERIC_PARAMS) | {"wind_speed"}
    chunk_encoding = {
        var_name: {
            "chunks": (
                atm_chunking if var_name in atm_chunked_vars else sur_chunking
            )
        }
        for var_name in filtered_xds
    }

    # Overwrite if exists
    filtered_xds.to_zarr(fc_path, mode="w", encoding=chunk_encoding)

    # Compute all init times, one per sample and spaced 12 h apart
    # NOTE(review): starts from the second time entry of atm_xda, presumably
    # because the first entry is only used as an earlier initial state
    # - confirm against the dataset class
    start_init_time = ds_xda.coords["time"].values[1]
    end_init_time = start_init_time + np.timedelta64(len(dataset) * 12, "h")
    init_times = np.arange(
        start_init_time, end_init_time, np.timedelta64(12, "h")
    ).astype("datetime64[ns]")

    # Iterate over dataset and produce forecasts
    for batch in tqdm(dataloader):
        # Send to device
        batch = tuple(t.to(device) for t in batch)

        # Forecast
        if save_ensemble:
            init_states, target_states, forcing_features = batch
            batch_forecast, _ = model.sample_trajectories(
                init_states,
                forcing_features,
                target_states,
                ens_size,
            )
            # (B, S, pred_steps, num_grid_nodes, d_f)
        else:
            batch_forecast, _, _ = model.common_step(batch)

        # Rescale to original data scaling
        batch_forecast_rescaled = batch_forecast * data_std + data_mean
        # (B, (S), pred_steps, num_grid_nodes, d_f)

        # Get init times for batch
        batch_size = batch_forecast.shape[0]
        batch_init_times = init_times[:batch_size]
        init_times = init_times[batch_size:]  # Drop used times

        batch_xds = forecast_to_xds(
            batch_forecast_rescaled,
            batch_init_times,
            forecast_xds.coords.items(),
            var_filter_list,
            level_filter_list,
            time_enc_unit,
        )

        # Save to existing zarr using append_dim="time"
        batch_xds.to_zarr(fc_path, append_dim="time")