From aea9a31e0ad274613f59ee16484f96d076173286 Mon Sep 17 00:00:00 2001
From: devsjc <47188100+devsjc@users.noreply.github.com>
Date: Tue, 13 Aug 2024 13:08:30 +0100
Subject: [PATCH] feat(sat-container): Start adding tests

---
 containers/sat/Containerfile | 2 +-
 containers/sat/download_process_sat.py | 182 ++++++++++----------
 containers/sat/test_download_process_sat.py | 113 ++++++++++++
 3 files changed, 209 insertions(+), 88 deletions(-)
 create mode 100644 containers/sat/test_download_process_sat.py

diff --git a/containers/sat/Containerfile b/containers/sat/Containerfile
index 03c792e..6e76b21 100644
--- a/containers/sat/Containerfile
+++ b/containers/sat/Containerfile
@@ -5,7 +5,7 @@ FROM quay.io/condaforge/miniforge3:latest AS build-venv
 RUN apt -qq update && apt -qq install -y build-essential
 RUN conda create -p /venv python=3.12
 RUN /venv/bin/pip install --upgrade -q pip wheel setuptools
-RUN conda install -p /venv -c conda-forge -y cartopy satpy[all] numpy
+RUN conda install -p /venv -c conda-forge -y cartopy satpy[all]=0.50.0 numpy
 ENV GDAL_CONFIG=/venv/bin/gdal-config
 
 # Build the virtualenv
diff --git a/containers/sat/download_process_sat.py b/containers/sat/download_process_sat.py
index d5bea9a..7adb63c 100644
--- a/containers/sat/download_process_sat.py
+++ b/containers/sat/download_process_sat.py
@@ -2,11 +2,13 @@
 
 Consolidates the old cli_downloader, backfill_hrv and backfill_nonhrv scripts.
 """
+
 import argparse
 import dataclasses
 import datetime as dt
 import itertools
 import json
+import traceback
 import logging
 import os
 import pathlib
@@ -15,15 +17,17 @@
 from multiprocessing import Pool, cpu_count
 from typing import Literal
 
-import diskcache as dc
 import eumdac
 import eumdac.cli
 import numpy as np
 import pandas as pd
 import pyproj
 import pyresample
+import satpy.dataset.dataid
 import xarray as xr
 import yaml
+import dask.delayed
+import dask.distributed
 import zarr
 from ocf_blosc2 import Blosc2
 from satpy import Scene
@@ -230,15 +234,12 @@ def process_scans(
     """
     # Check zarr file exists for the year
     zarr_path: pathlib.Path = folder.parent / start.strftime(sat_config.zarr_fmtstr[dstype])
+    zarr_times: list[dt.datetime] = []
     if zarr_path.exists():
-        zarr_times: list[dt.datetime] = xr.open_zarr(zarr_path).sortby("time").time.values.tolist()
-        last_zarr_time: dt.datetime = zarr_times[-1]
+        zarr_times = xr.open_zarr(zarr_path).sortby("time").time.values.tolist()
         log.debug(f"Zarr store already exists at {zarr_path} for {zarr_times[0]}-{zarr_times[-1]}")
     else:
-        # Set dummy values for times already in zarr
-        last_zarr_time = dt.datetime(1970, 1, 1, tzinfo=dt.UTC)
-        zarr_times = [last_zarr_time, last_zarr_time]
-        log.debug(f"Zarr store does not exist at {zarr_path}, setting dummy times")
+        log.debug(f"Zarr store does not exist at {zarr_path}")
 
     # Get native files in order
     native_files: list[pathlib.Path] = list(folder.glob("*.nat"))
@@ -268,29 +269,42 @@ def process_scans(
         else:
             log.debug(f"Creating new zarr store at {zarr_path}")
             mode = "w"
+        concat_ds: xr.Dataset = xr.concat(datasets, dim="time")
         _write_to_zarr(
-            xr.concat(datasets, dim="time"),
+            concat_ds,
             zarr_path.as_posix(),
             mode,
-            chunks={"time": 12},
+            chunks={
+                "time": 12,
+            },
         )
         datasets = []
 
         log.info(f"Process loop [{dstype}]: {i+1}/{len(native_files)}")
 
     # Consolidate zarr metadata
-    _rewrite_zarr_times(zarr_path.as_posix())
+    if pathlib.Path(zarr_path).exists():
+        _rewrite_zarr_times(zarr_path.as_posix())
 
     return dstype
 
 
+def _gen_token() -> eumdac.AccessToken:
+    """Generates an access token from environment variables."""
+    consumer_key: str = os.environ["EUMETSAT_CONSUMER_KEY"]
+    consumer_secret: str = os.environ["EUMETSAT_CONSUMER_SECRET"]
+    token = eumdac.AccessToken(credentials=(consumer_key, consumer_secret))
+
+    return token
+
+
 def _convert_scene_to_dataarray(
     scene: Scene,
     band: str,
     area: str,
     calculate_osgb: bool = True,
 ) -> xr.DataArray:
-    """Convertes a Scene with satellite data into a data array.
+    """Converts a Scene with satellite data into a data array.
 
     Args:
         scene: The satpy.Scene containing the satellite data
@@ -314,14 +328,24 @@
         log.debug("Finished resample")
     scene = scene.crop(ll_bbox=GEOGRAPHIC_BOUNDS[area])
     log.debug("Finished crop")
-    # Remove acq time from all bands because it is not useful, and can actually
-    # get in the way of combining multiple Zarr datasets.
+
+    # Update the dataarray attributes based off of the satpy scene attributes
     data_attrs = {}
     for channel in scene.wishlist:
+        # Remove acq time from all bands because it is not useful, and can actually
+        # get in the way of combining multiple Zarr datasets.
         scene[channel] = scene[channel].drop_vars("acq_time", errors="ignore")
         for attr in scene[channel].attrs:
             new_name = channel["name"] + "_" + attr
-            data_attrs[new_name] = scene[channel].attrs[attr]
+            # Ignore the "area" and "_satpy_id" scene attributes as they are not serializable
+            # and their data is already present in other scene attrs anyway.
+            if attr not in ["area", "_satpy_id"]:
+                try:
+                    serialized_value = json.dumps(scene[channel].attrs[attr])
+                    data_attrs[new_name] = serialized_value
+                except Exception as e:
+                    log.warning(f"Could not serialize scene attribute {new_name}: {e}")
+
     dataset: xr.Dataset = scene.to_xarray_dataset()
     dataarray = dataset.to_array()
     log.debug("Converted to dataarray")
@@ -348,8 +372,8 @@
             dataarray[name].attrs["coordinate_reference_system"] = "geostationary"
     log.debug("Calculated OSGB")
     # Round to the nearest 5 minutes
-    dataarray.attrs.update(data_attrs)
-    dataarray.attrs["end_time"] = pd.Timestamp(dataarray.attrs["end_time"]).round("5 min")
+    data_attrs["end_time"] = pd.Timestamp(dataarray.attrs["end_time"]).round("5 min").__str__()
+    dataarray.attrs = data_attrs
 
     # Rename x and y to make clear the coordinate system they are in
     dataarray = dataarray.rename({"x": "x_geostationary", "y": "y_geostationary"})
@@ -363,41 +387,7 @@
     return dataarray
 
 
-def _serialize_attrs(attrs: dict) -> dict:
-    """Ensure each value of dict can be serialized.
-
-    This is required before saving to Zarr because Zarr represents attrs values in a
-    JSON file (.zmetadata).
-
-    The `area` field (which is a `pyresample.geometry.AreaDefinition` object gets turned
-    into a YAML string, which can be loaded again using
-    `area_definition = pyresample.area_config.load_area_from_string(data_array.attrs['area'])`
-
-    Returns attrs dict where every value has been made serializable.
- """ - for key, value in attrs.items(): - # Convert Dicts - if isinstance(value, dict): - # Convert np.float32 to Python floats (otherwise yaml.dump complains) - for inner_key in value: - inner_value = value[inner_key] - if isinstance(inner_value, np.floating): - value[inner_key] = float(inner_value) - attrs[key] = yaml.dump(value) - # Convert Numpy bools - if isinstance(value, bool | np.bool_): - attrs[key] = str(value) - # Convert area - if isinstance(value, pyresample.geometry.AreaDefinition): - attrs[key] = value.dump() - # Convert datetimes - if isinstance(value, dt.datetime): - attrs[key] = value.isoformat() - - return attrs - - -def _rescale(dataarray: xr.DataArray, channels: list[Channel]) -> xr.DataArray | None: +def _rescale(dataarray: xr.DataArray, channels: list[Channel]) -> xr.DataArray: """Rescale Xarray DataArray so all values lie in the range [0, 1]. Warning: The original `dataarray` will be modified in-place. @@ -420,11 +410,12 @@ def _rescale(dataarray: xr.DataArray, channels: list[Channel]) -> xr.DataArray | "variable", ) + # For each channel, subtract the minimum and divide by the range dataarray -= [c.minimum for c in channels] dataarray /= [c.maximum - c.minimum for c in channels] + # Since the mins and maxes are approximations, clip the values to [0, 1] dataarray = dataarray.clip(min=0, max=1) dataarray = dataarray.astype(np.float32) - dataarray.attrs = _serialize_attrs(dataarray.attrs) # Must be serializable return dataarray @@ -433,30 +424,45 @@ def _open_and_scale_data( f: str, dstype: Literal["hrv", "nonhrv"], ) -> xr.Dataset | None: - """Opens a raw file and converts it to a normalised xarray dataset.""" + """Opens a raw file and converts it to a normalised xarray dataset. + + Args: + zarr_times: List of times already in the zarr store. + f: Path to the file to open. + dstype: Type of data to process (hrv or nonhrv). 
+ """ # The reader is the same for each satellite as the sensor is the same - # * Hence "severi" in all cases + # * Hence "seviri" in all cases scene = Scene(filenames={"seviri_l1b_native": [f]}) scene.load([c.variable for c in CHANNELS[dstype]]) - da: xr.DataArray = _convert_scene_to_dataarray( - scene, - band=CHANNELS[dstype][0].variable, - area="RSS", - calculate_osgb=False, - ) - # Rescale the data, update the attributes, save as dataset - attrs = _serialize_attrs(da.attrs) - da = _rescale(da, CHANNELS[dstype]) - da.attrs.update(attrs) - da = da.transpose("time", "y_geostationary", "x_geostationary", "variable") - ds: xr.Dataset = da.to_dataset(name="data") - ds["data"] = ds.data.astype(np.float16) + try: + da: xr.DataArray = _convert_scene_to_dataarray( + scene, + band=CHANNELS[dstype][0].variable, + area="RSS", + calculate_osgb=False, + ) + except Exception as e: + log.error(f"Error converting scene to dataarray: {e}") + return None - if ds.time.values[0] in zarr_times: - log.debug(f"Skipping: {ds.time.values[0]}") + # Don't proceed if the dataarray time is already present in the zarr store + if da.time.values[0] in zarr_times: + log.debug(f"Skipping: {da.time.values[0]}") return None + # Rescale the data, save as dataset + try: + da = _rescale(da, CHANNELS[dstype]) + except Exception as e: + log.error(f"Error rescaling dataarray: {e}") + return None + + da = da.transpose("time", "y_geostationary", "x_geostationary", "variable") + ds: xr.Dataset = da.to_dataset(name="data", promote_attrs=True) + ds["data"] = ds["data"].astype(np.float16) + return ds @@ -465,12 +471,12 @@ def _preprocess_function(xr_data: xr.Dataset) -> xr.Dataset: attrs = xr_data.attrs y_coords = xr_data.coords["y_geostationary"].values x_coords = xr_data.coords["x_geostationary"].values - x_dataarray = xr.DataArray( + x_dataarray: xr.DataArray = xr.DataArray( data=np.expand_dims(xr_data.coords["x_geostationary"].values, axis=0), dims=["time", "x_geostationary"], coords={"time": xr_data.coords["time"].values, "x_geostationary": x_coords}, ) - y_dataarray = xr.DataArray( + y_dataarray: xr.DataArray = xr.DataArray( data=np.expand_dims(xr_data.coords["y_geostationary"].values, axis=0), dims=["time", "y_geostationary"], coords={"time": xr_data.coords["time"].values, "y_geostationary": y_coords}, @@ -498,21 +504,25 @@ def _write_to_zarr(dataset: xr.Dataset, zarr_name: str, mode: str, chunks: dict) extra_kwargs = mode_extra_kwargs[mode] sliced_ds: xr.Dataset = dataset.isel(x_geostationary=slice(0, 5548)).chunk(chunks) try: - sliced_ds.to_zarr( + write_job: dask.delayed.Delayed = sliced_ds.to_zarr( store=zarr_name, - compute=True, + compute=False, + **extra_kwargs, consolidated=True, mode=mode, - **extra_kwargs, ) + write_job = write_job.persist() + dask.distributed.progress(write_job, notebook=False) except Exception as e: - log.error(f"Error writing to zarr: {e}") + log.error(f"Error writing dataset to zarr store {zarr_name} with mode {mode}: {e}") + traceback.print_tb(e.__traceback__) + return None def _rewrite_zarr_times(output_name: str) -> None: """Rewrites the time coordinates in the given zarr store.""" # Combine time coords - ds = xr.open_zarr(output_name) + ds = xr.open_zarr(output_name, consolidated=False) # Prevent numcodecs string error # See https://github.com/pydata/xarray/issues/3476#issuecomment-1205346130 @@ -524,6 +534,10 @@ def _rewrite_zarr_times(output_name: str) -> None: ds[v].encoding.clear() del ds["data"] + if "x_geostationary_coordinates" in ds: + del ds["x_geostationary_coordinates"] + if 
"y_geostationary_coordinates" in ds: + del ds["y_geostationary_coordinates"] # Need to remove these encodings to avoid chunking del ds.time.encoding["chunks"] del ds.time.encoding["preferred_chunks"] @@ -541,7 +555,7 @@ def _rewrite_zarr_times(output_name: str) -> None: data["metadata"]["time/.zarray"] = coord_data["metadata"]["time/.zarray"] with open(f"{output_name}/.zmetadata", "w") as f: json.dump(data, f) - zarr.consolidate_metadata(output_name) + # zarr.consolidate_metadata(output_name) parser = argparse.ArgumentParser( @@ -583,9 +597,6 @@ def _rewrite_zarr_times(output_name: str) -> None: args = parser.parse_args() folder: pathlib.Path = args.path / args.sat - # Create a reusable cache - cache = dc.Cache(folder / ".cache/{args.sat}") - log.info(f"{prog_start!s}: Running with args: {args}") # Get config for desired satellite sat_config = CONFIGS[args.sat] @@ -599,17 +610,14 @@ def _rewrite_zarr_times(output_name: str) -> None: freq=sat_config.cadence, ).tolist() - # Get average runtime from cache - secs_per_scan = cache.get("secs_per_scan", default=90) + # Estimate average runtime + secs_per_scan: int = 90 expected_runtime = pd.Timedelta(secs_per_scan * len(scan_times), "seconds") log.info(f"Downloading {len(scan_times)} scans. Expected runtime: {expected_runtime!s}") # Download data # We only parallelize if we have a number of files larger than the cpu count - consumer_key: str = os.environ["EUMETSAT_CONSUMER_KEY"] - consumer_secret: str = os.environ["EUMETSAT_CONSUMER_SECRET"] - token = eumdac.AccessToken(credentials=(consumer_key, consumer_secret)) - + token = _gen_token() results: list[pathlib.Path] = [] if len(scan_times) > cpu_count(): log.debug(f"Concurrency: {cpu_count()}") @@ -653,5 +661,5 @@ def _rewrite_zarr_times(output_name: str) -> None: new_average_secs_per_scan: int = int( (secs_per_scan + (runtime.total_seconds() / len(scan_times))) / 2, ) - cache.set("secs_per_scan", new_average_secs_per_scan) log.info(f"Completed archive for args: {args}. ({new_average_secs_per_scan} seconds per scan).") + diff --git a/containers/sat/test_download_process_sat.py b/containers/sat/test_download_process_sat.py new file mode 100644 index 0000000..1a4d127 --- /dev/null +++ b/containers/sat/test_download_process_sat.py @@ -0,0 +1,113 @@ +"""Tests for the satellite processing pipeline. + +Note that, since the files from EUMETSAT are so large, +they must be downloaded prior to running the tests - they +are too biug to include in the repository. 
As such, environment +variables must be set to authenticate with EUMETSAT +""" + +import datetime as dt +import pathlib +import unittest + +import download_process_sat as dps +import numpy as np +import pandas as pd +import xarray as xr +from satpy import Scene + + +class TestDownloadProcessSat(unittest.TestCase): + paths: list[pathlib.Path] + test_dataarrays: dict[str, xr.DataArray] + + @classmethod + def setUpClass(cls) -> None: + TIMESTAMP = pd.Timestamp("2024-01-01T00:00:00Z") + + token = dps._gen_token() + + paths = dps.download_scans( + sat_config=dps.CONFIGS["iodc"], + folder=pathlib.Path("/tmp/test_sat_data"), + scan_time=TIMESTAMP, + token=token, + ) + cls.paths = paths + + attrs: dict = { + "end_time": TIMESTAMP + pd.Timedelta("15m"), + "modifiers": (), + "orbital_parameters": {"projection_longitude": 45.5, "projection_latitude": 0.0, + "projection_altitude": 35785831.0, "satellite_nominal_longitude": 45.5, + "satellite_nominal_latitude": 0.0, "satellite_actual_longitude": 45.703605543834364, + "satellite_actual_latitude": 7.281469039541501, + "satellite_actual_altitude": 35788121.627292305}, + "reader": "seviri_l1b_native", + "sensor": "seviri", + "resolution": 3000.403165817, + "start_time": dt.datetime(2024, 1, 1, 0, 0, tzinfo=dt.UTC), + "platform_name": "Meteosat-9", "area": "Area ID: msg_seviri_iodc_3km", + } + + cls.test_dataarrays = { + "hrv": xr.DataArray( + data=np.random.random((1, 1, 3712, 3712)), + dims=["time", "variable", "x_geostationary", "y_geostationary"], + coords={ + "time": [pd.Timestamp("2024-01-01T00:00:00Z")], + "variable": ["HRV"], + "x_geostationary": np.arange(3712), + "y_geostationary": np.arange(3712), + }, + attrs=attrs, + ), + "nonhrv": xr.DataArray( + data=np.random.random((1, 11, 3712, 3712)), + dims=["time", "variable", "x_geostationary", "y_geostationary"], + coords={ + "time": [pd.Timestamp("2024-01-01T00:00:00Z")], + "variable": [c.variable for c in dps.CHANNELS["nonhrv"]], + "x_geostationary": np.arange(3712), + "y_geostationary": np.arange(3712), + }, + attrs=attrs, + ), + } + + def test_download_scans(self) -> None: + self.assertGreater(len(self.paths), 0) + + def test_convert_scene_to_dataarray(self) -> None: + scene = Scene(filenames={"seviri_l1b_native": [self.paths[0].as_posix()]}) + scene.load([c.variable for c in dps.CHANNELS["nonhrv"]]) + da = dps._convert_scene_to_dataarray( + scene, + band=dps.CHANNELS["nonhrv"][0].variable, + area="RSS", + calculate_osgb=False, + ) + + with self.subTest("Returned dataarray is correct shape"): + self.assertDictEqual( + dict(da.sizes), + {"time": 1, "variable": 11, "x_geostationary": 3712, "y_geostationary": 3712}, + ) + self.assertIn("end_time", da.attrs) + + def test_rescale(self) -> None: + da: xr.DataArray = dps._rescale(self.test_dataarrays["nonhrv"], channels=dps.CHANNELS["nonhrv"]) + + self.assertGreater(da.values.max(), 0) + self.assertLess(da.values.min(), 1) + self.assertEqual(da.attrs, self.test_dataarrays["nonhrv"].attrs) + + def test_open_and_scale_data(self) -> None: + ds: xr.Dataset | None = dps._open_and_scale_data([], self.paths[0].as_posix(), "nonhrv") + + if ds is None: + self.fail("Dataset is None") + + ds.to_zarr("/tmp/test_sat_data/test.zarr", mode="w", consolidated=True) + ds2 = xr.open_zarr("/tmp/test_sat_data/test.zarr") + self.assertDictEqual(dict(ds.sizes), dict(ds2.sizes))
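
The new test module expects EUMETSAT credentials in the EUMETSAT_CONSUMER_KEY and EUMETSAT_CONSUMER_SECRET environment variables (the same variables _gen_token() reads) and downloads its fixture data in setUpClass. A minimal sketch of a local runner for the suite follows; it assumes the working directory is containers/sat, and the helper script name run_sat_tests.py is hypothetical rather than part of the patch.

# run_sat_tests.py -- hypothetical helper, not part of the patch above.
# Checks the EUMETSAT credentials the tests rely on, then runs the suite.
import os
import sys
import unittest

REQUIRED_VARS = ("EUMETSAT_CONSUMER_KEY", "EUMETSAT_CONSUMER_SECRET")


def main() -> int:
    # _gen_token() in download_process_sat.py reads these variables directly,
    # so fail fast with a clear message if either one is missing.
    missing = [v for v in REQUIRED_VARS if not os.environ.get(v)]
    if missing:
        print(f"Missing environment variables: {', '.join(missing)}", file=sys.stderr)
        return 1

    # Discover the test_*.py modules in the current directory (assumed to be
    # containers/sat) and run them with a verbose text runner.
    suite = unittest.defaultTestLoader.discover(".", pattern="test_*.py")
    result = unittest.TextTestRunner(verbosity=2).run(suite)
    return 0 if result.wasSuccessful() else 1


if __name__ == "__main__":
    raise SystemExit(main())

Running python -m unittest test_download_process_sat -v from the same directory does the same job without the pre-flight credential check.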