Skip to content

Commit

Permalink
Merge pull request #68 from openclimatefix/rm-ocf-datapipes
Browse files Browse the repository at this point in the history
Rm ocf datapipes
  • Loading branch information
peterdudfield authored Oct 21, 2024
2 parents 83ec958 + 78f618d commit c16c9d6
Show file tree
Hide file tree
Showing 14 changed files with 237 additions and 70 deletions.
2 changes: 1 addition & 1 deletion ocf_data_sampler/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing_extensions import Self

from pydantic import BaseModel, Field, RootModel, field_validator, ValidationInfo, model_validator
from ocf_datapipes.utils.consts import NWP_PROVIDERS
from ocf_data_sampler.constants import NWP_PROVIDERS

logger = logging.getLogger(__name__)

Expand Down
135 changes: 135 additions & 0 deletions ocf_data_sampler/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import xarray as xr
import numpy as np


NWP_PROVIDERS = [
"ukv",
"ecmwf",
]


def _to_data_array(d):
return xr.DataArray(
[d[k] for k in d.keys()],
coords={"channel": [k for k in d.keys()]},
).astype(np.float32)


class NWPStatDict(dict):
"""Custom dictionary class to hold NWP normalization stats"""

def __getitem__(self, key):
if key not in NWP_PROVIDERS:
raise KeyError(f"{key} is not a supported NWP provider - {NWP_PROVIDERS}")
elif key in self.keys():
return super().__getitem__(key)
else:
raise KeyError(
f"Values for {key} not yet available in ocf-data-sampler {list(self.keys())}"
)

# ------ UKV
# Means and std computed WITH version_7 and higher, MetOffice values
UKV_STD = {
"cdcb": 2126.99350113,
"lcc": 39.33210726,
"mcc": 41.91144559,
"hcc": 38.07184418,
"sde": 0.1029753,
"hcct": 18382.63958991,
"dswrf": 190.47216887,
"dlwrf": 39.45988077,
"h": 1075.77812282,
"t": 4.38818501,
"r": 11.45012499,
"dpt": 4.57250482,
"vis": 21578.97975625,
"si10": 3.94718813,
"wdir10": 94.08407495,
"prmsl": 1252.71790539,
"prate": 0.00021497,
}
UKV_MEAN = {
"cdcb": 1412.26599062,
"lcc": 50.08362643,
"mcc": 40.88984494,
"hcc": 29.11949682,
"sde": 0.00289545,
"hcct": -18345.97478167,
"dswrf": 111.28265039,
"dlwrf": 325.03130139,
"h": 2096.51991356,
"t": 283.64913206,
"r": 81.79229501,
"dpt": 280.54379901,
"vis": 32262.03285118,
"si10": 6.88348448,
"wdir10": 199.41891636,
"prmsl": 101321.61574029,
"prate": 3.45793433e-05,
}

UKV_STD = _to_data_array(UKV_STD)
UKV_MEAN = _to_data_array(UKV_MEAN)

# ------ ECMWF
# These were calculated from 100 random init times of UK data from 2020-2023
ECMWF_STD = {
"dlwrf": 15855867.0,
"dswrf": 13025427.0,
"duvrs": 1445635.25,
"hcc": 0.42244860529899597,
"lcc": 0.3791404366493225,
"mcc": 0.38039860129356384,
"prate": 9.81039775069803e-05,
"sde": 0.000913831521756947,
"sr": 16294988.0,
"t2m": 3.692270040512085,
"tcc": 0.37487083673477173,
"u10": 5.531515598297119,
"u100": 7.2320556640625,
"u200": 8.049470901489258,
"v10": 5.411230564117432,
"v100": 6.944501876831055,
"v200": 7.561611652374268,
"diff_dlwrf": 131942.03125,
"diff_dswrf": 715366.3125,
"diff_duvrs": 81605.25,
"diff_sr": 818950.6875,
}
ECMWF_MEAN = {
"dlwrf": 27187026.0,
"dswrf": 11458988.0,
"duvrs": 1305651.25,
"hcc": 0.3961029052734375,
"lcc": 0.44901806116104126,
"mcc": 0.3288780450820923,
"prate": 3.108070450252853e-05,
"sde": 8.107526082312688e-05,
"sr": 12905302.0,
"t2m": 283.48333740234375,
"tcc": 0.7049227356910706,
"u10": 1.7677178382873535,
"u100": 2.393547296524048,
"u200": 2.7963004112243652,
"v10": 0.985887885093689,
"v100": 1.4244288206100464,
"v200": 1.6010299921035767,
"diff_dlwrf": 1136464.0,
"diff_dswrf": 420584.6875,
"diff_duvrs": 48265.4765625,
"diff_sr": 469169.5,
}

ECMWF_STD = _to_data_array(ECMWF_STD)
ECMWF_MEAN = _to_data_array(ECMWF_MEAN)

NWP_STDS = NWPStatDict(
ukv=UKV_STD,
ecmwf=ECMWF_STD,
)
NWP_MEANS = NWPStatDict(
ukv=UKV_MEAN,
ecmwf=ECMWF_MEAN,
)

29 changes: 21 additions & 8 deletions ocf_data_sampler/numpy_batch/gsp.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,33 @@
"""Convert GSP to Numpy Batch"""

import xarray as xr
from ocf_datapipes.batch import BatchKey, NumpyBatch


def convert_gsp_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> NumpyBatch:
class GSPBatchKey:

gsp = 'gsp'
gsp_nominal_capacity_mwp = 'gsp_nominal_capacity_mwp'
gsp_effective_capacity_mwp = 'gsp_effective_capacity_mwp'
gsp_time_utc = 'gsp_time_utc'
gsp_t0_idx = 'gsp_t0_idx'
gsp_solar_azimuth = 'gsp_solar_azimuth'
gsp_solar_elevation = 'gsp_solar_elevation'
gsp_id = 'gsp_id'
gsp_x_osgb = 'gsp_x_osgb'
gsp_y_osgb = 'gsp_y_osgb'


def convert_gsp_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
"""Convert from Xarray to NumpyBatch"""

example: NumpyBatch = {
BatchKey.gsp: da.values,
BatchKey.gsp_nominal_capacity_mwp: da.isel(time_utc=0)["nominal_capacity_mwp"].values,
BatchKey.gsp_effective_capacity_mwp: da.isel(time_utc=0)["effective_capacity_mwp"].values,
BatchKey.gsp_time_utc: da["time_utc"].values.astype(float),
example = {
GSPBatchKey.gsp: da.values,
GSPBatchKey.gsp_nominal_capacity_mwp: da.isel(time_utc=0)["nominal_capacity_mwp"].values,
GSPBatchKey.gsp_effective_capacity_mwp: da.isel(time_utc=0)["effective_capacity_mwp"].values,
GSPBatchKey.gsp_time_utc: da["time_utc"].values.astype(float),
}

if t0_idx is not None:
example[BatchKey.gsp_t0_idx] = t0_idx
example[GSPBatchKey.gsp_t0_idx] = t0_idx

return example
16 changes: 13 additions & 3 deletions ocf_data_sampler/numpy_batch/nwp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,23 @@
import pandas as pd
import xarray as xr

from ocf_datapipes.batch import NWPBatchKey, NWPNumpyBatch

class NWPBatchKey:

def convert_nwp_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> NWPNumpyBatch:
nwp = 'nwp'
nwp_channel_names = 'nwp_channel_names'
nwp_init_time_utc = 'nwp_init_time_utc'
nwp_step = 'nwp_step'
nwp_target_time_utc = 'nwp_target_time_utc'
nwp_t0_idx = 'nwp_t0_idx'
nwp_y_osgb = 'nwp_y_osgb'
nwp_x_osgb = 'nwp_x_osgb'


def convert_nwp_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
"""Convert from Xarray to NWP NumpyBatch"""

example: NWPNumpyBatch = {
example = {
NWPBatchKey.nwp: da.values,
NWPBatchKey.nwp_channel_names: da.channel.values,
NWPBatchKey.nwp_init_time_utc: da.init_time_utc.values.astype(float),
Expand Down
23 changes: 15 additions & 8 deletions ocf_data_sampler/numpy_batch/satellite.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
"""Convert Satellite to NumpyBatch"""
import xarray as xr

from ocf_datapipes.batch import BatchKey, NumpyBatch

class SatelliteBatchKey:

def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> NumpyBatch:
satellite_actual = 'satellite_actual'
satellite_time_utc = 'satellite_time_utc'
satellite_x_geostationary = 'satellite_x_geostationary'
satellite_y_geostationary = 'satellite_y_geostationary'
satellite_t0_idx = 'satellite_t0_idx'


def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict:
"""Convert from Xarray to NumpyBatch"""
example: NumpyBatch = {
BatchKey.satellite_actual: da.values,
BatchKey.satellite_time_utc: da.time_utc.values.astype(float),
example = {
SatelliteBatchKey.satellite_actual: da.values,
SatelliteBatchKey.satellite_time_utc: da.time_utc.values.astype(float),
}

for batch_key, dataset_key in (
(BatchKey.satellite_x_geostationary, "x_geostationary"),
(BatchKey.satellite_y_geostationary, "y_geostationary"),
(SatelliteBatchKey.satellite_x_geostationary, "x_geostationary"),
(SatelliteBatchKey.satellite_y_geostationary, "y_geostationary"),
):
example[batch_key] = da[dataset_key].values

if t0_idx is not None:
example[BatchKey.satellite_t0_idx] = t0_idx
example[SatelliteBatchKey.satellite_t0_idx] = t0_idx

return example
11 changes: 5 additions & 6 deletions ocf_data_sampler/numpy_batch/sun_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pvlib
import numpy as np
import pandas as pd
from ocf_datapipes.batch import BatchKey, NumpyBatch


def calculate_azimuth_and_elevation(
Expand Down Expand Up @@ -37,8 +36,8 @@ def make_sun_position_numpy_batch(
datetimes: pd.DatetimeIndex,
lon: float,
lat: float,
key_preffix: str = "gsp"
) -> NumpyBatch:
key_prefix: str = "gsp"
) -> dict:
"""Creates NumpyBatch with standardized solar coordinates
Args:
Expand All @@ -58,9 +57,9 @@ def make_sun_position_numpy_batch(
elevation = elevation / 180 + 0.5

# Make NumpyBatch
sun_numpy_batch: NumpyBatch = {
BatchKey[key_preffix + "_solar_azimuth"]: azimuth,
BatchKey[key_preffix + "_solar_elevation"]: elevation,
sun_numpy_batch = {
key_prefix + "_solar_azimuth": azimuth,
key_prefix + "_solar_elevation": elevation,
}

return sun_numpy_batch
33 changes: 15 additions & 18 deletions ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,14 @@


from ocf_data_sampler.config import Configuration, load_yaml_configuration
from ocf_datapipes.batch import BatchKey, NumpyBatch
from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey

from ocf_datapipes.utils.location import Location
from ocf_datapipes.utils.geospatial import osgb_to_lon_lat
from ocf_data_sampler.select.location import Location
from ocf_data_sampler.select.geospatial import osgb_to_lon_lat

from ocf_datapipes.utils.consts import (
NWP_MEANS,
NWP_STDS,
)
from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS

from ocf_datapipes.training.common import concat_xr_time_utc, normalize_gsp



Expand Down Expand Up @@ -343,7 +340,7 @@ def slice_datasets_by_time(
return sliced_datasets_dict


def fill_nans_in_arrays(batch: NumpyBatch) -> NumpyBatch:
def fill_nans_in_arrays(batch: dict) -> dict:
"""Fills all NaN values in each np.ndarray in the batch dictionary with zeros.
Operation is performed in-place on the batch.
Expand Down Expand Up @@ -375,7 +372,7 @@ def process_and_combine_datasets(
config: Configuration,
t0: pd.Timedelta,
location: Location,
) -> NumpyBatch:
) -> dict:
"""Normalize and convert data to numpy arrays"""

numpy_modalities = []
Expand All @@ -392,7 +389,7 @@ def process_and_combine_datasets(
nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)

# Combine the NWPs into NumpyBatch
numpy_modalities.append({BatchKey.nwp: nwp_numpy_modalities})
numpy_modalities.append({NWPBatchKey.nwp: nwp_numpy_modalities})

if "sat" in dataset_dict:
# Satellite is already in the range [0-1] so no need to standardise
Expand All @@ -404,8 +401,8 @@ def process_and_combine_datasets(
gsp_config = config.input_data.gsp

if "gsp" in dataset_dict:
da_gsp = concat_xr_time_utc([dataset_dict["gsp"], dataset_dict["gsp_future"]])
da_gsp = normalize_gsp(da_gsp)
da_gsp = xr.concat([dataset_dict["gsp"], dataset_dict["gsp_future"]], dim="time_utc")
da_gsp = da_gsp / da_gsp.effective_capacity_mwp

numpy_modalities.append(
convert_gsp_to_numpy_batch(
Expand All @@ -428,9 +425,9 @@ def process_and_combine_datasets(
# Add coordinate data
# TODO: Do we need all of these?
numpy_modalities.append({
BatchKey.gsp_id: location.id,
BatchKey.gsp_x_osgb: location.x,
BatchKey.gsp_y_osgb: location.y,
GSPBatchKey.gsp_id: location.id,
GSPBatchKey.gsp_x_osgb: location.x,
GSPBatchKey.gsp_y_osgb: location.y,
})

# Combine all the modalities and fill NaNs
Expand Down Expand Up @@ -538,7 +535,7 @@ def __len__(self):
return len(self.index_pairs)


def _get_sample(self, t0: pd.Timestamp, location: Location) -> NumpyBatch:
def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
"""Generate the PVNet sample for given coordinates
Args:
Expand All @@ -565,7 +562,7 @@ def __getitem__(self, idx):
return self._get_sample(t0, location)


def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> NumpyBatch:
def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> dict:
"""Generate a sample for the given coordinates.
Useful for users to generate samples by GSP ID.
Expand Down
Loading

0 comments on commit c16c9d6

Please sign in to comment.