feat(sat-container): Start adding tests
devsjc committed Aug 13, 2024
1 parent 4916735 commit aea9a31
Showing 3 changed files with 209 additions and 88 deletions.
2 changes: 1 addition & 1 deletion containers/sat/Containerfile
@@ -5,7 +5,7 @@ FROM quay.io/condaforge/miniforge3:latest AS build-venv
RUN apt -qq update && apt -qq install -y build-essential
RUN conda create -p /venv python=3.12
RUN /venv/bin/pip install --upgrade -q pip wheel setuptools
RUN conda install -p /venv -c conda-forge -y cartopy satpy[all] numpy
RUN conda install -p /venv -c conda-forge -y cartopy satpy[all]=0.50.0 numpy
ENV GDAL_CONFIG=/venv/bin/gdal-config

# Build the virtualenv
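The hunk above pins satpy to 0.50.0 rather than tracking the latest release. As a quick, purely illustrative sanity check (not part of this commit), the built image could be probed with /venv/bin/python to confirm the pin took effect:

import satpy

# Expect the version pinned in the Containerfile; update this check if the pin changes.
assert satpy.__version__ == "0.50.0", f"unexpected satpy version: {satpy.__version__}"
print("satpy", satpy.__version__, "is installed in the venv")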
182 changes: 95 additions & 87 deletions containers/sat/download_process_sat.py
@@ -2,11 +2,13 @@
Consolidates the old cli_downloader, backfill_hrv and backfill_nonhrv scripts.
"""

import argparse
import dataclasses
import datetime as dt
import itertools
import json
import traceback
import logging
import os
import pathlib
@@ -15,15 +17,17 @@
from multiprocessing import Pool, cpu_count
from typing import Literal

import diskcache as dc
import eumdac
import eumdac.cli
import numpy as np
import pandas as pd
import pyproj
import pyresample
import satpy.dataset.dataid
import xarray as xr
import yaml
import dask.delayed
import dask.distributed
import zarr
from ocf_blosc2 import Blosc2
from satpy import Scene
@@ -230,15 +234,12 @@ def process_scans(
"""
# Check zarr file exists for the year
zarr_path: pathlib.Path = folder.parent / start.strftime(sat_config.zarr_fmtstr[dstype])
zarr_times: list[dt.datetime] = []
if zarr_path.exists():
zarr_times: list[dt.datetime] = xr.open_zarr(zarr_path).sortby("time").time.values.tolist()
last_zarr_time: dt.datetime = zarr_times[-1]
zarr_times = xr.open_zarr(zarr_path).sortby("time").time.values.tolist()
log.debug(f"Zarr store already exists at {zarr_path} for {zarr_times[0]}-{zarr_times[-1]}")
else:
# Set dummy values for times already in zarr
last_zarr_time = dt.datetime(1970, 1, 1, tzinfo=dt.UTC)
zarr_times = [last_zarr_time, last_zarr_time]
log.debug(f"Zarr store does not exist at {zarr_path}, setting dummy times")
log.debug(f"Zarr store does not exist at {zarr_path}")

# Get native files in order
native_files: list[pathlib.Path] = list(folder.glob("*.nat"))
@@ -268,29 +269,42 @@
else:
log.debug(f"Creating new zarr store at {zarr_path}")
mode = "w"
concat_ds: xr.Dataset = xr.concat(datasets, dim="time")
_write_to_zarr(
xr.concat(datasets, dim="time"),
concat_ds,
zarr_path.as_posix(),
mode,
chunks={"time": 12},
chunks={
"time": 12,
},
)
datasets = []

log.info(f"Process loop [{dstype}]: {i+1}/{len(native_files)}")

# Consolidate zarr metadata
_rewrite_zarr_times(zarr_path.as_posix())
if pathlib.Path(zarr_path).exists():
_rewrite_zarr_times(zarr_path.as_posix())

return dstype


def _gen_token() -> eumdac.AccessToken:
"""Generated an aces token from environment variables."""
consumer_key: str = os.environ["EUMETSAT_CONSUMER_KEY"]
consumer_secret: str = os.environ["EUMETSAT_CONSUMER_SECRET"]
token = eumdac.AccessToken(credentials=(consumer_key, consumer_secret))

return token
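
In the spirit of the commit title, a minimal pytest-style test for the new _gen_token helper might look like the sketch below. It is illustrative only: it assumes the script is importable as a module named download_process_sat and stubs out eumdac.AccessToken so no real credentials or network access are needed.

import eumdac

import download_process_sat as dps  # assumed import path for the script above


def test_gen_token_reads_credentials_from_env(monkeypatch) -> None:
    """_gen_token should build the token from the two EUMETSAT_* environment variables."""
    recorded: dict = {}
    monkeypatch.setenv("EUMETSAT_CONSUMER_KEY", "test-key")
    monkeypatch.setenv("EUMETSAT_CONSUMER_SECRET", "test-secret")
    # Replace the real AccessToken with a stub that records the credentials it was given
    monkeypatch.setattr(eumdac, "AccessToken", lambda credentials: recorded.setdefault("creds", credentials))
    dps._gen_token()
    assert recorded["creds"] == ("test-key", "test-secret")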


def _convert_scene_to_dataarray(
scene: Scene,
band: str,
area: str,
calculate_osgb: bool = True,
) -> xr.DataArray:
"""Convertes a Scene with satellite data into a data array.
"""Converts a Scene with satellite data into a data array.
Args:
scene: The satpy.Scene containing the satellite data
@@ -314,14 +328,24 @@ def _convert_scene_to_dataarray(
log.debug("Finished resample")
scene = scene.crop(ll_bbox=GEOGRAPHIC_BOUNDS[area])
log.debug("Finished crop")
# Remove acq time from all bands because it is not useful, and can actually
# get in the way of combining multiple Zarr datasets.

# Update the dataarray attributes based off of the satpy scene attributes
data_attrs = {}
for channel in scene.wishlist:
# Remove acq time from all bands because it is not useful, and can actually
# get in the way of combining multiple Zarr datasets.
scene[channel] = scene[channel].drop_vars("acq_time", errors="ignore")
for attr in scene[channel].attrs:
new_name = channel["name"] + "_" + attr
data_attrs[new_name] = scene[channel].attrs[attr]
# Ignore the "area" and "_satpy_id" scene attributes as they are not serializable
# and their data is already present in other scene attrs anyway.
if attr not in ["area", "_satpy_id"]:
try:
serialized_value = json.dumps(scene[channel].attrs[attr])
data_attrs[new_name] = serialized_value
except Exception as e:
log.warning(f"Could not serialize scene attribute {new_name}: {e}")

dataset: xr.Dataset = scene.to_xarray_dataset()
dataarray = dataset.to_array()
log.debug("Converted to dataarray")
@@ -348,8 +372,8 @@ def _convert_scene_to_dataarray(
dataarray[name].attrs["coordinate_reference_system"] = "geostationary"
log.debug("Calculated OSGB")
# Round to the nearest 5 minutes
dataarray.attrs.update(data_attrs)
dataarray.attrs["end_time"] = pd.Timestamp(dataarray.attrs["end_time"]).round("5 min")
data_attrs["end_time"] = pd.Timestamp(dataarray.attrs["end_time"]).round("5 min").__str__()
dataarray.attrs = data_attrs

# Rename x and y to make clear the coordinate system they are in
dataarray = dataarray.rename({"x": "x_geostationary", "y": "y_geostationary"})
@@ -363,41 +387,7 @@
return dataarray


def _serialize_attrs(attrs: dict) -> dict:
"""Ensure each value of dict can be serialized.
This is required before saving to Zarr because Zarr represents attrs values in a
JSON file (.zmetadata).
The `area` field (which is a `pyresample.geometry.AreaDefinition` object gets turned
into a YAML string, which can be loaded again using
`area_definition = pyresample.area_config.load_area_from_string(data_array.attrs['area'])`
Returns attrs dict where every value has been made serializable.
"""
for key, value in attrs.items():
# Convert Dicts
if isinstance(value, dict):
# Convert np.float32 to Python floats (otherwise yaml.dump complains)
for inner_key in value:
inner_value = value[inner_key]
if isinstance(inner_value, np.floating):
value[inner_key] = float(inner_value)
attrs[key] = yaml.dump(value)
# Convert Numpy bools
if isinstance(value, bool | np.bool_):
attrs[key] = str(value)
# Convert area
if isinstance(value, pyresample.geometry.AreaDefinition):
attrs[key] = value.dump()
# Convert datetimes
if isinstance(value, dt.datetime):
attrs[key] = value.isoformat()

return attrs


def _rescale(dataarray: xr.DataArray, channels: list[Channel]) -> xr.DataArray | None:
def _rescale(dataarray: xr.DataArray, channels: list[Channel]) -> xr.DataArray:
"""Rescale Xarray DataArray so all values lie in the range [0, 1].
Warning: The original `dataarray` will be modified in-place.
@@ -420,11 +410,12 @@ def _rescale(dataarray: xr.DataArray, channels: list[Channel]) -> xr.DataArray |
"variable",
)

# For each channel, subtract the minimum and divide by the range
dataarray -= [c.minimum for c in channels]
dataarray /= [c.maximum - c.minimum for c in channels]
# Since the mins and maxes are approximations, clip the values to [0, 1]
dataarray = dataarray.clip(min=0, max=1)
dataarray = dataarray.astype(np.float32)
dataarray.attrs = _serialize_attrs(dataarray.attrs) # Must be serializable
return dataarray
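
A companion test for the rescaling logic could look like the following sketch. It is illustrative only: it assumes the script is importable as download_process_sat, that _rescale transposes to the (time, y_geostationary, x_geostationary, variable) layout used elsewhere in this file, and it uses SimpleNamespace stand-ins in place of the real Channel entries.

from types import SimpleNamespace

import numpy as np
import xarray as xr

import download_process_sat as dps  # assumed import path for the script above


def test_rescale_bounds_values_to_unit_interval() -> None:
    """_rescale should scale each channel by its range and clip the result to [0, 1]."""
    # Stand-ins for Channel entries; only the attributes _rescale reads are provided.
    channels = [
        SimpleNamespace(variable="VIS006", minimum=0.0, maximum=100.0),
        SimpleNamespace(variable="IR_039", minimum=200.0, maximum=300.0),
    ]
    # One timestep, a 1x2 pixel grid, two channels; includes values outside the channel ranges.
    data = np.array([[[[-5.0, 350.0], [50.0, 250.0]]]])
    da = xr.DataArray(
        data,
        dims=("time", "y_geostationary", "x_geostationary", "variable"),
        coords={"variable": [c.variable for c in channels]},
    )
    rescaled = dps._rescale(da, channels)
    assert rescaled.dtype == np.float32
    assert float(rescaled.min()) >= 0.0
    assert float(rescaled.max()) <= 1.0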


@@ -433,30 +424,45 @@ def _open_and_scale_data(
f: str,
dstype: Literal["hrv", "nonhrv"],
) -> xr.Dataset | None:
"""Opens a raw file and converts it to a normalised xarray dataset."""
"""Opens a raw file and converts it to a normalised xarray dataset.
Args:
zarr_times: List of times already in the zarr store.
f: Path to the file to open.
dstype: Type of data to process (hrv or nonhrv).
"""
# The reader is the same for each satellite as the sensor is the same
# * Hence "severi" in all cases
# * Hence "seviri" in all cases
scene = Scene(filenames={"seviri_l1b_native": [f]})
scene.load([c.variable for c in CHANNELS[dstype]])
da: xr.DataArray = _convert_scene_to_dataarray(
scene,
band=CHANNELS[dstype][0].variable,
area="RSS",
calculate_osgb=False,
)

# Rescale the data, update the attributes, save as dataset
attrs = _serialize_attrs(da.attrs)
da = _rescale(da, CHANNELS[dstype])
da.attrs.update(attrs)
da = da.transpose("time", "y_geostationary", "x_geostationary", "variable")
ds: xr.Dataset = da.to_dataset(name="data")
ds["data"] = ds.data.astype(np.float16)
try:
da: xr.DataArray = _convert_scene_to_dataarray(
scene,
band=CHANNELS[dstype][0].variable,
area="RSS",
calculate_osgb=False,
)
except Exception as e:
log.error(f"Error converting scene to dataarray: {e}")
return None

if ds.time.values[0] in zarr_times:
log.debug(f"Skipping: {ds.time.values[0]}")
# Don't proceed if the dataarray time is already present in the zarr store
if da.time.values[0] in zarr_times:
log.debug(f"Skipping: {da.time.values[0]}")
return None

# Rescale the data, save as dataset
try:
da = _rescale(da, CHANNELS[dstype])
except Exception as e:
log.error(f"Error rescaling dataarray: {e}")
return None

da = da.transpose("time", "y_geostationary", "x_geostationary", "variable")
ds: xr.Dataset = da.to_dataset(name="data", promote_attrs=True)
ds["data"] = ds["data"].astype(np.float16)

return ds


@@ -465,12 +471,12 @@ def _preprocess_function(xr_data: xr.Dataset) -> xr.Dataset:
attrs = xr_data.attrs
y_coords = xr_data.coords["y_geostationary"].values
x_coords = xr_data.coords["x_geostationary"].values
x_dataarray = xr.DataArray(
x_dataarray: xr.DataArray = xr.DataArray(
data=np.expand_dims(xr_data.coords["x_geostationary"].values, axis=0),
dims=["time", "x_geostationary"],
coords={"time": xr_data.coords["time"].values, "x_geostationary": x_coords},
)
y_dataarray = xr.DataArray(
y_dataarray: xr.DataArray = xr.DataArray(
data=np.expand_dims(xr_data.coords["y_geostationary"].values, axis=0),
dims=["time", "y_geostationary"],
coords={"time": xr_data.coords["time"].values, "y_geostationary": y_coords},
@@ -498,21 +504,25 @@ def _write_to_zarr(dataset: xr.Dataset, zarr_name: str, mode: str, chunks: dict)
extra_kwargs = mode_extra_kwargs[mode]
sliced_ds: xr.Dataset = dataset.isel(x_geostationary=slice(0, 5548)).chunk(chunks)
try:
sliced_ds.to_zarr(
write_job: dask.delayed.Delayed = sliced_ds.to_zarr(
store=zarr_name,
compute=True,
compute=False,
**extra_kwargs,
consolidated=True,
mode=mode,
**extra_kwargs,
)
write_job = write_job.persist()
dask.distributed.progress(write_job, notebook=False)
except Exception as e:
log.error(f"Error writing to zarr: {e}")
log.error(f"Error writing dataset to zarr store {zarr_name} with mode {mode}: {e}")
traceback.print_tb(e.__traceback__)
return None
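
The write path above now builds a delayed job (compute=False), persists it, and monitors it with a dask progress bar instead of writing eagerly. A standalone sketch of that pattern, with illustrative names and a throwaway local cluster rather than anything taken from this script:

import dask.distributed
import numpy as np
import xarray as xr

if __name__ == "__main__":
    client = dask.distributed.Client(processes=False)  # local stand-in for the real scheduler
    ds = xr.Dataset({"data": (("time", "x"), np.random.rand(4, 8))}).chunk({"time": 2})
    # compute=False returns a delayed write job instead of writing immediately...
    write_job = ds.to_zarr("example.zarr", mode="w", consolidated=True, compute=False)
    # ...which is then started in the background and monitored until it completes.
    write_job = write_job.persist()
    dask.distributed.progress(write_job, notebook=False)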


def _rewrite_zarr_times(output_name: str) -> None:
"""Rewrites the time coordinates in the given zarr store."""
# Combine time coords
ds = xr.open_zarr(output_name)
ds = xr.open_zarr(output_name, consolidated=False)

# Prevent numcodecs string error
# See https://github.com/pydata/xarray/issues/3476#issuecomment-1205346130
@@ -524,6 +534,10 @@ def _rewrite_zarr_times(output_name: str) -> None:
ds[v].encoding.clear()

del ds["data"]
if "x_geostationary_coordinates" in ds:
del ds["x_geostationary_coordinates"]
if "y_geostationary_coordinates" in ds:
del ds["y_geostationary_coordinates"]
# Need to remove these encodings to avoid chunking
del ds.time.encoding["chunks"]
del ds.time.encoding["preferred_chunks"]
@@ -541,7 +555,7 @@ def _rewrite_zarr_times(output_name: str) -> None:
data["metadata"]["time/.zarray"] = coord_data["metadata"]["time/.zarray"]
with open(f"{output_name}/.zmetadata", "w") as f:
json.dump(data, f)
zarr.consolidate_metadata(output_name)
# zarr.consolidate_metadata(output_name)
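
Since zarr.consolidate_metadata is now skipped in favour of the manual .zmetadata rewrite above, a small helper like this sketch (illustrative only, not part of the commit) can confirm the rewritten store still opens without relying on consolidated metadata:

import xarray as xr

def check_rewritten_store(zarr_path: str) -> None:
    """Open the store without consolidated metadata and report its time span."""
    ds = xr.open_zarr(zarr_path, consolidated=False)
    times = ds["time"].values
    print(f"{zarr_path}: {len(times)} timesteps from {times.min()} to {times.max()}")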


parser = argparse.ArgumentParser(
@@ -583,9 +597,6 @@ def _rewrite_zarr_times(output_name: str) -> None:
args = parser.parse_args()
folder: pathlib.Path = args.path / args.sat

# Create a reusable cache
cache = dc.Cache(folder / ".cache/{args.sat}")

log.info(f"{prog_start!s}: Running with args: {args}")
# Get config for desired satellite
sat_config = CONFIGS[args.sat]
@@ -599,17 +610,14 @@ def _rewrite_zarr_times(output_name: str) -> None:
freq=sat_config.cadence,
).tolist()

# Get average runtime from cache
secs_per_scan = cache.get("secs_per_scan", default=90)
# Estimate average runtime
secs_per_scan: int = 90
expected_runtime = pd.Timedelta(secs_per_scan * len(scan_times), "seconds")
log.info(f"Downloading {len(scan_times)} scans. Expected runtime: {expected_runtime!s}")

# Download data
# We only parallelize if we have a number of files larger than the cpu count
consumer_key: str = os.environ["EUMETSAT_CONSUMER_KEY"]
consumer_secret: str = os.environ["EUMETSAT_CONSUMER_SECRET"]
token = eumdac.AccessToken(credentials=(consumer_key, consumer_secret))

token = _gen_token()
results: list[pathlib.Path] = []
if len(scan_times) > cpu_count():
log.debug(f"Concurrency: {cpu_count()}")
@@ -653,5 +661,5 @@ def _rewrite_zarr_times(output_name: str) -> None:
new_average_secs_per_scan: int = int(
(secs_per_scan + (runtime.total_seconds() / len(scan_times))) / 2,
)
cache.set("secs_per_scan", new_average_secs_per_scan)
log.info(f"Completed archive for args: {args}. ({new_average_secs_per_scan} seconds per scan).")

(Diff for the third changed file did not load and is not shown.)
