diff --git a/ocf_data_sampler/config/model.py b/ocf_data_sampler/config/model.py
index 34ca90c..dfe9d31 100644
--- a/ocf_data_sampler/config/model.py
+++ b/ocf_data_sampler/config/model.py
@@ -15,7 +15,7 @@ from typing_extensions import Self
 
 from pydantic import BaseModel, Field, RootModel, field_validator, ValidationInfo, model_validator
 
-from ocf_datapipes.utils.consts import NWP_PROVIDERS
+from ocf_data_sampler.constants import NWP_PROVIDERS
 
 logger = logging.getLogger(__name__)
diff --git a/ocf_data_sampler/constants.py b/ocf_data_sampler/constants.py
new file mode 100644
index 0000000..d0c9a18
--- /dev/null
+++ b/ocf_data_sampler/constants.py
@@ -0,0 +1,135 @@
+import xarray as xr
+import numpy as np
+
+
+NWP_PROVIDERS = [
+    "ukv",
+    "ecmwf",
+]
+
+
+def _to_data_array(d):
+    return xr.DataArray(
+        [d[k] for k in d.keys()],
+        coords={"channel": [k for k in d.keys()]},
+    ).astype(np.float32)
+
+
+class NWPStatDict(dict):
+    """Custom dictionary class to hold NWP normalization stats"""
+
+    def __getitem__(self, key):
+        if key not in NWP_PROVIDERS:
+            raise KeyError(f"{key} is not a supported NWP provider - {NWP_PROVIDERS}")
+        elif key in self.keys():
+            return super().__getitem__(key)
+        else:
+            raise KeyError(
+                f"Values for {key} not yet available in ocf-data-sampler {list(self.keys())}"
+            )
+
+
+# ------ UKV
+# Means and std computed WITH version_7 and higher, MetOffice values
+UKV_STD = {
+    "cdcb": 2126.99350113,
+    "lcc": 39.33210726,
+    "mcc": 41.91144559,
+    "hcc": 38.07184418,
+    "sde": 0.1029753,
+    "hcct": 18382.63958991,
+    "dswrf": 190.47216887,
+    "dlwrf": 39.45988077,
+    "h": 1075.77812282,
+    "t": 4.38818501,
+    "r": 11.45012499,
+    "dpt": 4.57250482,
+    "vis": 21578.97975625,
+    "si10": 3.94718813,
+    "wdir10": 94.08407495,
+    "prmsl": 1252.71790539,
+    "prate": 0.00021497,
+}
+UKV_MEAN = {
+    "cdcb": 1412.26599062,
+    "lcc": 50.08362643,
+    "mcc": 40.88984494,
+    "hcc": 29.11949682,
+    "sde": 0.00289545,
+    "hcct": -18345.97478167,
+    "dswrf": 111.28265039,
+    "dlwrf": 325.03130139,
+    "h": 2096.51991356,
+    "t": 283.64913206,
+    "r": 81.79229501,
+    "dpt": 280.54379901,
+    "vis": 32262.03285118,
+    "si10": 6.88348448,
+    "wdir10": 199.41891636,
+    "prmsl": 101321.61574029,
+    "prate": 3.45793433e-05,
+}
+
+UKV_STD = _to_data_array(UKV_STD)
+UKV_MEAN = _to_data_array(UKV_MEAN)
+
+# ------ ECMWF
+# These were calculated from 100 random init times of UK data from 2020-2023
+ECMWF_STD = {
+    "dlwrf": 15855867.0,
+    "dswrf": 13025427.0,
+    "duvrs": 1445635.25,
+    "hcc": 0.42244860529899597,
+    "lcc": 0.3791404366493225,
+    "mcc": 0.38039860129356384,
+    "prate": 9.81039775069803e-05,
+    "sde": 0.000913831521756947,
+    "sr": 16294988.0,
+    "t2m": 3.692270040512085,
+    "tcc": 0.37487083673477173,
+    "u10": 5.531515598297119,
+    "u100": 7.2320556640625,
+    "u200": 8.049470901489258,
+    "v10": 5.411230564117432,
+    "v100": 6.944501876831055,
+    "v200": 7.561611652374268,
+    "diff_dlwrf": 131942.03125,
+    "diff_dswrf": 715366.3125,
+    "diff_duvrs": 81605.25,
+    "diff_sr": 818950.6875,
+}
+ECMWF_MEAN = {
+    "dlwrf": 27187026.0,
+    "dswrf": 11458988.0,
+    "duvrs": 1305651.25,
+    "hcc": 0.3961029052734375,
+    "lcc": 0.44901806116104126,
+    "mcc": 0.3288780450820923,
+    "prate": 3.108070450252853e-05,
+    "sde": 8.107526082312688e-05,
+    "sr": 12905302.0,
+    "t2m": 283.48333740234375,
+    "tcc": 0.7049227356910706,
+    "u10": 1.7677178382873535,
+    "u100": 2.393547296524048,
+    "u200": 2.7963004112243652,
+    "v10": 0.985887885093689,
+    "v100": 1.4244288206100464,
+    "v200": 1.6010299921035767,
+    "diff_dlwrf": 1136464.0,
+    "diff_dswrf": 420584.6875,
+    "diff_duvrs": 48265.4765625,
+    "diff_sr": 469169.5,
+}
+
+ECMWF_STD = _to_data_array(ECMWF_STD)
+ECMWF_MEAN = _to_data_array(ECMWF_MEAN)
+
+NWP_STDS = NWPStatDict(
+    ukv=UKV_STD,
+    ecmwf=ECMWF_STD,
+)
+NWP_MEANS = NWPStatDict(
+    ukv=UKV_MEAN,
+    ecmwf=ECMWF_MEAN,
+)
+
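Usage sketch (not part of the diff) of how the per-provider stats above are meant to be applied. The DataArray is synthetic and only its "channel" coordinate names need to match keys defined in constants.py; alignment on the channel coordinate does the per-channel standardisation.

import numpy as np
import xarray as xr

from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS

# Synthetic NWP-like array with two UKV channels
da_nwp = xr.DataArray(
    np.random.rand(4, 2),
    dims=("step", "channel"),
    coords={"channel": ["t", "dswrf"]},
)

# xarray aligns on the "channel" coordinate, so each channel is standardised
# with its own mean and standard deviation
da_norm = (da_nwp - NWP_MEANS["ukv"]) / NWP_STDS["ukv"]

# Asking for an unknown provider raises a KeyError from NWPStatDict, e.g.
# NWP_MEANS["gfs"]  # -> KeyError: not a supported NWP provider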
"diff_duvrs": 48265.4765625, + "diff_sr": 469169.5, +} + +ECMWF_STD = _to_data_array(ECMWF_STD) +ECMWF_MEAN = _to_data_array(ECMWF_MEAN) + +NWP_STDS = NWPStatDict( + ukv=UKV_STD, + ecmwf=ECMWF_STD, +) +NWP_MEANS = NWPStatDict( + ukv=UKV_MEAN, + ecmwf=ECMWF_MEAN, +) + diff --git a/ocf_data_sampler/numpy_batch/gsp.py b/ocf_data_sampler/numpy_batch/gsp.py index f69ddad..f65516c 100644 --- a/ocf_data_sampler/numpy_batch/gsp.py +++ b/ocf_data_sampler/numpy_batch/gsp.py @@ -1,20 +1,33 @@ """Convert GSP to Numpy Batch""" import xarray as xr -from ocf_datapipes.batch import BatchKey, NumpyBatch -def convert_gsp_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> NumpyBatch: +class GSPBatchKey: + + gsp = 'gsp' + gsp_nominal_capacity_mwp = 'gsp_nominal_capacity_mwp' + gsp_effective_capacity_mwp = 'gsp_effective_capacity_mwp' + gsp_time_utc = 'gsp_time_utc' + gsp_t0_idx = 'gsp_t0_idx' + gsp_solar_azimuth = 'gsp_solar_azimuth' + gsp_solar_elevation = 'gsp_solar_elevation' + gsp_id = 'gsp_id' + gsp_x_osgb = 'gsp_x_osgb' + gsp_y_osgb = 'gsp_y_osgb' + + +def convert_gsp_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict: """Convert from Xarray to NumpyBatch""" - example: NumpyBatch = { - BatchKey.gsp: da.values, - BatchKey.gsp_nominal_capacity_mwp: da.isel(time_utc=0)["nominal_capacity_mwp"].values, - BatchKey.gsp_effective_capacity_mwp: da.isel(time_utc=0)["effective_capacity_mwp"].values, - BatchKey.gsp_time_utc: da["time_utc"].values.astype(float), + example = { + GSPBatchKey.gsp: da.values, + GSPBatchKey.gsp_nominal_capacity_mwp: da.isel(time_utc=0)["nominal_capacity_mwp"].values, + GSPBatchKey.gsp_effective_capacity_mwp: da.isel(time_utc=0)["effective_capacity_mwp"].values, + GSPBatchKey.gsp_time_utc: da["time_utc"].values.astype(float), } if t0_idx is not None: - example[BatchKey.gsp_t0_idx] = t0_idx + example[GSPBatchKey.gsp_t0_idx] = t0_idx return example diff --git a/ocf_data_sampler/numpy_batch/nwp.py b/ocf_data_sampler/numpy_batch/nwp.py index 062f047..a3af3b0 100644 --- a/ocf_data_sampler/numpy_batch/nwp.py +++ b/ocf_data_sampler/numpy_batch/nwp.py @@ -3,13 +3,23 @@ import pandas as pd import xarray as xr -from ocf_datapipes.batch import NWPBatchKey, NWPNumpyBatch +class NWPBatchKey: -def convert_nwp_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> NWPNumpyBatch: + nwp = 'nwp' + nwp_channel_names = 'nwp_channel_names' + nwp_init_time_utc = 'nwp_init_time_utc' + nwp_step = 'nwp_step' + nwp_target_time_utc = 'nwp_target_time_utc' + nwp_t0_idx = 'nwp_t0_idx' + nwp_y_osgb = 'nwp_y_osgb' + nwp_x_osgb = 'nwp_x_osgb' + + +def convert_nwp_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> dict: """Convert from Xarray to NWP NumpyBatch""" - example: NWPNumpyBatch = { + example = { NWPBatchKey.nwp: da.values, NWPBatchKey.nwp_channel_names: da.channel.values, NWPBatchKey.nwp_init_time_utc: da.init_time_utc.values.astype(float), diff --git a/ocf_data_sampler/numpy_batch/satellite.py b/ocf_data_sampler/numpy_batch/satellite.py index c4c8d5b..6cdb270 100644 --- a/ocf_data_sampler/numpy_batch/satellite.py +++ b/ocf_data_sampler/numpy_batch/satellite.py @@ -1,23 +1,30 @@ """Convert Satellite to NumpyBatch""" import xarray as xr -from ocf_datapipes.batch import BatchKey, NumpyBatch +class SatelliteBatchKey: -def convert_satellite_to_numpy_batch(da: xr.DataArray, t0_idx: int | None = None) -> NumpyBatch: + satellite_actual = 'satellite_actual' + satellite_time_utc = 'satellite_time_utc' + satellite_x_geostationary = 'satellite_x_geostationary' + 
diff --git a/ocf_data_sampler/numpy_batch/sun_position.py b/ocf_data_sampler/numpy_batch/sun_position.py
index 0866801..3b6cb05 100644
--- a/ocf_data_sampler/numpy_batch/sun_position.py
+++ b/ocf_data_sampler/numpy_batch/sun_position.py
@@ -2,7 +2,6 @@
 import pvlib
 import numpy as np
 import pandas as pd
-from ocf_datapipes.batch import BatchKey, NumpyBatch
 
 
 def calculate_azimuth_and_elevation(
@@ -37,8 +36,8 @@ def make_sun_position_numpy_batch(
     datetimes: pd.DatetimeIndex,
     lon: float,
     lat: float,
-    key_preffix: str = "gsp"
-) -> NumpyBatch:
+    key_prefix: str = "gsp"
+) -> dict:
     """Creates NumpyBatch with standardized solar coordinates
 
     Args:
@@ -58,9 +57,9 @@ def make_sun_position_numpy_batch(
     elevation = elevation / 180 + 0.5
 
     # Make NumpyBatch
-    sun_numpy_batch: NumpyBatch = {
-        BatchKey[key_preffix + "_solar_azimuth"]: azimuth,
-        BatchKey[key_preffix + "_solar_elevation"]: elevation,
+    sun_numpy_batch = {
+        key_prefix + "_solar_azimuth": azimuth,
+        key_prefix + "_solar_elevation": elevation,
     }
 
     return sun_numpy_batch
\ No newline at end of file
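Quick sketch (not part of the diff) of the renamed key_prefix argument: the returned dict is keyed by plain strings such as "gsp_solar_azimuth", elevation is scaled via elevation / 180 + 0.5 as shown above, and the tests further down check that both angles land in [0, 1].

import pandas as pd

from ocf_data_sampler.numpy_batch.sun_position import make_sun_position_numpy_batch

datetimes = pd.date_range("2024-06-20 12:00", "2024-06-20 14:00", freq="30min")
batch = make_sun_position_numpy_batch(datetimes, lon=0.0, lat=51.5, key_prefix="gsp")

# Plain-string keys, no BatchKey enum lookup needed
assert set(batch) == {"gsp_solar_azimuth", "gsp_solar_elevation"}
assert ((batch["gsp_solar_elevation"] >= 0) & (batch["gsp_solar_elevation"] <= 1)).all()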
diff --git a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
index ecd952a..d62f84c 100644
--- a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
+++ b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
@@ -28,17 +28,14 @@
 from ocf_data_sampler.config import Configuration, load_yaml_configuration
 
-from ocf_datapipes.batch import BatchKey, NumpyBatch
+from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
+from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
 
-from ocf_datapipes.utils.location import Location
-from ocf_datapipes.utils.geospatial import osgb_to_lon_lat
+from ocf_data_sampler.select.location import Location
+from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
 
-from ocf_datapipes.utils.consts import (
-    NWP_MEANS,
-    NWP_STDS,
-)
+from ocf_data_sampler.constants import NWP_MEANS, NWP_STDS
 
-from ocf_datapipes.training.common import concat_xr_time_utc, normalize_gsp
 
@@ -343,7 +340,7 @@ def slice_datasets_by_time(
     return sliced_datasets_dict
 
 
-def fill_nans_in_arrays(batch: NumpyBatch) -> NumpyBatch:
+def fill_nans_in_arrays(batch: dict) -> dict:
     """Fills all NaN values in each np.ndarray in the batch dictionary with zeros.
 
     Operation is performed in-place on the batch.
@@ -375,7 +372,7 @@ def process_and_combine_datasets(
     config: Configuration,
     t0: pd.Timedelta,
     location: Location,
-) -> NumpyBatch:
+) -> dict:
     """Normalize and convert data to numpy arrays"""
 
     numpy_modalities = []
@@ -392,7 +389,7 @@ def process_and_combine_datasets(
             nwp_numpy_modalities[nwp_key] = convert_nwp_to_numpy_batch(da_nwp)
 
         # Combine the NWPs into NumpyBatch
-        numpy_modalities.append({BatchKey.nwp: nwp_numpy_modalities})
+        numpy_modalities.append({NWPBatchKey.nwp: nwp_numpy_modalities})
 
     if "sat" in dataset_dict:
         # Satellite is already in the range [0-1] so no need to standardise
@@ -404,8 +401,8 @@ def process_and_combine_datasets(
     gsp_config = config.input_data.gsp
 
     if "gsp" in dataset_dict:
-        da_gsp = concat_xr_time_utc([dataset_dict["gsp"], dataset_dict["gsp_future"]])
-        da_gsp = normalize_gsp(da_gsp)
+        da_gsp = xr.concat([dataset_dict["gsp"], dataset_dict["gsp_future"]], dim="time_utc")
+        da_gsp = da_gsp / da_gsp.effective_capacity_mwp
 
         numpy_modalities.append(
             convert_gsp_to_numpy_batch(
@@ -428,9 +425,9 @@ def process_and_combine_datasets(
     # Add coordinate data
     # TODO: Do we need all of these?
     numpy_modalities.append({
-        BatchKey.gsp_id: location.id,
-        BatchKey.gsp_x_osgb: location.x,
-        BatchKey.gsp_y_osgb: location.y,
+        GSPBatchKey.gsp_id: location.id,
+        GSPBatchKey.gsp_x_osgb: location.x,
+        GSPBatchKey.gsp_y_osgb: location.y,
     })
 
     # Combine all the modalities and fill NaNs
@@ -538,7 +535,7 @@ def __len__(self):
 
         return len(self.index_pairs)
 
-    def _get_sample(self, t0: pd.Timestamp, location: Location) -> NumpyBatch:
+    def _get_sample(self, t0: pd.Timestamp, location: Location) -> dict:
         """Generate the PVNet sample for given coordinates
 
         Args:
@@ -565,7 +562,7 @@ def __getitem__(self, idx):
 
         return self._get_sample(t0, location)
 
-    def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> NumpyBatch:
+    def get_sample(self, t0: pd.Timestamp, gsp_id: int) -> dict:
         """Generate a sample for the given coordinates.
 
         Useful for users to generate samples by GSP ID.
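Toy sketch (not part of the diff) of the inlined GSP normalisation that replaces concat_xr_time_utc / normalize_gsp from ocf_datapipes: concatenate the past and future GSP windows along time_utc, then divide by the effective_capacity_mwp coordinate. The helper and values below are hypothetical and only mimic the shape of the real GSP DataArrays.

import numpy as np
import pandas as pd
import xarray as xr

def _toy_gsp(times, capacity=10.0):
    # Hypothetical stand-in for a GSP window with a scalar capacity coordinate
    return xr.DataArray(
        np.linspace(0.0, capacity, len(times)),
        dims=("time_utc",),
        coords={"time_utc": times, "effective_capacity_mwp": capacity},
    )

da_gsp_past = _toy_gsp(pd.date_range("2024-01-01 10:00", periods=4, freq="30min"))
da_gsp_future = _toy_gsp(pd.date_range("2024-01-01 12:00", periods=4, freq="30min"))

# Same two steps as in the hunk above
da_gsp = xr.concat([da_gsp_past, da_gsp_future], dim="time_utc")
da_gsp = da_gsp / da_gsp.effective_capacity_mwp  # now expressed as a fraction of capacity

assert float(da_gsp.max()) <= 1.0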
diff --git a/pyproject.toml b/pyproject.toml
index 532ca08..c17c11b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,14 +18,19 @@ maintainers = [
 ]
 dependencies = [
     # Migration from requirements.txt
+    "torch",
     "numpy",
     "pandas",
     "xarray",
     "zarr",
     "dask",
     "ocf_blosc2",
-    "ocf_datapipes==3.3.39",
-    "pvlib"
+    "pvlib",
+    "pydantic",
+    "pyproj",
+    "pathy",
+    "pyaml_env",
+    "pyresample"
 ]
 
 keywords = [ # I've added some keywords, but please provide feedback if you'd like them changed!
diff --git a/tests/numpy_batch/test_gsp.py b/tests/numpy_batch/test_gsp.py
index 7803dd6..2afa94d 100644
--- a/tests/numpy_batch/test_gsp.py
+++ b/tests/numpy_batch/test_gsp.py
@@ -1,8 +1,7 @@
-from ocf_datapipes.batch import BatchKey
 from ocf_data_sampler.load.gsp import open_gsp
 
 from ocf_data_sampler.numpy_batch import convert_gsp_to_numpy_batch
-
+from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
 
 
 def test_convert_gsp_to_numpy_batch(uk_gsp_zarr_path):
@@ -19,5 +18,5 @@ def test_convert_gsp_to_numpy_batch(uk_gsp_zarr_path):
     assert isinstance(numpy_batch, dict)
 
     # Assert the shape of the numpy batch
-    assert (numpy_batch[BatchKey.gsp] == da.values).all()
+    assert (numpy_batch[GSPBatchKey.gsp] == da.values).all()
 
diff --git a/tests/numpy_batch/test_nwp.py b/tests/numpy_batch/test_nwp.py
index 54c922e..50b0c14 100644
--- a/tests/numpy_batch/test_nwp.py
+++ b/tests/numpy_batch/test_nwp.py
@@ -6,7 +6,7 @@
 from ocf_data_sampler.numpy_batch import convert_nwp_to_numpy_batch
 
-from ocf_datapipes.batch import NWPBatchKey
+from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
 
 
 @pytest.fixture(scope="module")
 def da_nwp_like():
diff --git a/tests/numpy_batch/test_satellite.py b/tests/numpy_batch/test_satellite.py
index 7da39e8..9be5a62 100644
--- a/tests/numpy_batch/test_satellite.py
+++ b/tests/numpy_batch/test_satellite.py
@@ -7,7 +7,7 @@
 
 from ocf_data_sampler.numpy_batch import convert_satellite_to_numpy_batch
 
-from ocf_datapipes.batch import BatchKey
+from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
 
 
 @pytest.fixture(scope="module")
@@ -39,4 +39,4 @@ def test_convert_satellite_to_numpy_batch(da_sat_like):
     assert isinstance(numpy_batch, dict)
 
     # Assert the shape of the numpy batch
-    assert (numpy_batch[BatchKey.satellite_actual] == da_sat_like.values).all()
\ No newline at end of file
+    assert (numpy_batch[SatelliteBatchKey.satellite_actual] == da_sat_like.values).all()
\ No newline at end of file
diff --git a/tests/numpy_batch/test_sun_position.py b/tests/numpy_batch/test_sun_position.py
index 6077261..2c55119 100644
--- a/tests/numpy_batch/test_sun_position.py
+++ b/tests/numpy_batch/test_sun_position.py
@@ -6,7 +6,7 @@
     calculate_azimuth_and_elevation, make_sun_position_numpy_batch
 )
 
-from ocf_datapipes.batch import NumpyBatch, BatchKey
+from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
 
 
 @pytest.mark.parametrize("lat", [0, 5, 10, 23.5])
@@ -69,13 +69,13 @@ def test_make_sun_position_numpy_batch():
     datetimes = pd.date_range("2024-06-20 12:00", "2024-06-20 16:00", freq="30min")
     lon, lat = 0, 51.5
 
-    batch = make_sun_position_numpy_batch(datetimes, lon, lat, key_preffix="gsp")
+    batch = make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix="gsp")
 
-    assert BatchKey.gsp_solar_elevation in batch
-    assert BatchKey.gsp_solar_azimuth in batch
+    assert GSPBatchKey.gsp_solar_elevation in batch
+    assert GSPBatchKey.gsp_solar_azimuth in batch
 
     # The solar coords are normalised in the function
-    assert (batch[BatchKey.gsp_solar_elevation]>=0).all()
-    assert (batch[BatchKey.gsp_solar_elevation]<=1).all()
-    assert (batch[BatchKey.gsp_solar_azimuth]>=0).all()
-    assert (batch[BatchKey.gsp_solar_azimuth]<=1).all()
+    assert (batch[GSPBatchKey.gsp_solar_elevation]>=0).all()
+    assert (batch[GSPBatchKey.gsp_solar_elevation]<=1).all()
+    assert (batch[GSPBatchKey.gsp_solar_azimuth]>=0).all()
+    assert (batch[GSPBatchKey.gsp_solar_azimuth]<=1).all()
 
diff --git a/tests/select/test_select_spatial_slice.py b/tests/select/test_select_spatial_slice.py
index 2d7d5bc..cdb6735 100644
--- a/tests/select/test_select_spatial_slice.py
+++ b/tests/select/test_select_spatial_slice.py
@@ -1,6 +1,6 @@
 import numpy as np
 import xarray as xr
-from ocf_datapipes.utils import Location
+from ocf_data_sampler.select.location import Location
 import pytest
 
 from ocf_data_sampler.select.select_spatial_slice import (
diff --git a/tests/torch_datasets/test_pvnet_uk_regional.py b/tests/torch_datasets/test_pvnet_uk_regional.py
index 490c7b5..1dfe268 100644
--- a/tests/torch_datasets/test_pvnet_uk_regional.py
+++ b/tests/torch_datasets/test_pvnet_uk_regional.py
@@ -3,7 +3,9 @@
 from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
 from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration
 
-from ocf_datapipes.batch import BatchKey, NWPBatchKey
+from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
+from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
+from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
 
 
 @pytest.fixture()
@@ -36,24 +38,24 @@ def test_pvnet(pvnet_config_filename):
     assert isinstance(sample, dict)
 
     for key in [
-        BatchKey.nwp, BatchKey.satellite_actual, BatchKey.gsp,
-        BatchKey.gsp_solar_azimuth, BatchKey.gsp_solar_elevation,
+        NWPBatchKey.nwp, SatelliteBatchKey.satellite_actual, GSPBatchKey.gsp,
+        GSPBatchKey.gsp_solar_azimuth, GSPBatchKey.gsp_solar_elevation,
     ]:
         assert key in sample
 
     for nwp_source in ["ukv"]:
-        assert nwp_source in sample[BatchKey.nwp]
+        assert nwp_source in sample[NWPBatchKey.nwp]
 
     # check the shape of the data is correct
     # 30 minutes of 5 minute data (inclusive), one channel, 2x2 pixels
-    assert sample[BatchKey.satellite_actual].shape == (7, 1, 2, 2)
+    assert sample[SatelliteBatchKey.satellite_actual].shape == (7, 1, 2, 2)
     # 3 hours of 60 minute data (inclusive), one channel, 2x2 pixels
-    assert sample[BatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)
+    assert sample[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp].shape == (4, 1, 2, 2)
    # 3 hours of 30 minute data (inclusive)
-    assert sample[BatchKey.gsp].shape == (7,)
+    assert sample[GSPBatchKey.gsp].shape == (7,)
     # Solar angles have same shape as GSP data
-    assert sample[BatchKey.gsp_solar_azimuth].shape == (7,)
-    assert sample[BatchKey.gsp_solar_elevation].shape == (7,)
+    assert sample[GSPBatchKey.gsp_solar_azimuth].shape == (7,)
+    assert sample[GSPBatchKey.gsp_solar_elevation].shape == (7,)
 
 
 def test_pvnet_no_gsp(pvnet_config_filename):
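End-to-end sketch (not part of the diff) of how a returned sample is addressed after this change: plain-string keys from the *BatchKey namespaces, with per-provider NWP data nested under NWPBatchKey.nwp. Constructing the dataset from a configuration file path is assumed here to mirror the tests; "pvnet_config.yaml" is a placeholder.

from ocf_data_sampler.numpy_batch.gsp import GSPBatchKey
from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
from ocf_data_sampler.numpy_batch.satellite import SatelliteBatchKey
from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset

dataset = PVNetUKRegionalDataset("pvnet_config.yaml")  # placeholder config path
sample = dataset[0]  # __getitem__ maps an index to a (t0, location) pair

gsp = sample[GSPBatchKey.gsp]                            # identical to sample["gsp"]
sat = sample[SatelliteBatchKey.satellite_actual]
ukv = sample[NWPBatchKey.nwp]["ukv"][NWPBatchKey.nwp]    # per-provider NWP nesting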