diff --git a/pyproject.toml b/pyproject.toml index 513cdc78..5368202f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "returns == 0.23.0", "s3fs == 2024.9.0", "xarray == 2024.9.0", - "zarr == 2.18.2" + "zarr == 2.18.3" ] [dependency-groups] diff --git a/src/nwp_consumer/internal/entities/coordinates.py b/src/nwp_consumer/internal/entities/coordinates.py index 45a6d241..6019291e 100644 --- a/src/nwp_consumer/internal/entities/coordinates.py +++ b/src/nwp_consumer/internal/entities/coordinates.py @@ -37,7 +37,10 @@ import dataclasses import datetime as dt +import json +from importlib.metadata import PackageNotFoundError, version +import dask.array import numpy as np import pandas as pd import pytz @@ -46,6 +49,11 @@ from .parameters import Parameter +try: + __version__ = version("nwp-consumer") +except PackageNotFoundError: + __version__ = "v?" + @dataclasses.dataclass(slots=True) class NWPDimensionCoordinateMap: @@ -367,11 +375,57 @@ def default_chunking(self) -> dict[str, int]: that wants to cover the entire dimension should have a size equal to the dimension length. - It defaults to a single chunk per init time and step, and a single chunk - for each entire other dimension. + It defaults to a single chunk per init time and step, and 8 chunks + along each other dimension. These are purposefully small, to ensure + that chunk boundaries are not crossed when performing parallel writes. """ out_dict: dict[str, int] = { "init_time": 1, "step": 1, - } | {dim: len(getattr(self, dim)) for dim in self.dims if dim not in ["init_time", "step"]} + } | { + dim: len(getattr(self, dim)) // 8 if len(getattr(self, dim)) > 8 else 1 + for dim in self.dims + if dim not in ["init_time", "step"] + } + return out_dict + + + def as_zeroed_dataarray(self, name: str) -> xr.DataArray: + """Express the coordinates as an xarray DataArray. + + Data is populated with zeros and a default chunking scheme is applied. + + Args: + name: The name of the DataArray. + + See Also: + - https://docs.xarray.dev/en/stable/user-guide/io.html#distributed-writes + """ + # Create a dask array of zeros with the shape of the dataset + # * The values of this are ignored, only the shape and chunks are used + dummy_values = dask.array.zeros( # type: ignore + shape=list(self.shapemap.values()), + chunks=tuple([self.default_chunking()[k] for k in self.shapemap]), + ) + attrs: dict[str, str] = { + "produced_by": "".join(( + f"nwp-consumer {__version__} at ", + f"{dt.datetime.now(tz=dt.UTC).strftime('%Y-%m-%d %H:%M')}", + )), + "variables": json.dumps({ + p.value: { + "description": p.metadata().description, + "units": p.metadata().units, + } for p in self.variable + }), + } + # Create a DataArray object with the given coordinates and dummy values + da: xr.DataArray = xr.DataArray( + name=name, + data=dummy_values, + coords=self.to_pandas(), + attrs=attrs, + ) + return da + diff --git a/src/nwp_consumer/internal/entities/tensorstore.py b/src/nwp_consumer/internal/entities/tensorstore.py index 46113c1b..f0c3ce7c 100644 --- a/src/nwp_consumer/internal/entities/tensorstore.py +++ b/src/nwp_consumer/internal/entities/tensorstore.py @@ -7,18 +7,15 @@ This module provides a class for storing metadata about a Zarr store.
""" +import abc import dataclasses import datetime as dt -import json import logging import os import pathlib import shutil -from importlib.metadata import PackageNotFoundError, version from typing import Any -import dask.array -import numpy as np import pandas as pd import xarray as xr import zarr @@ -30,11 +27,6 @@ log = logging.getLogger("nwp-consumer") -try: - __version__ = version("nwp-consumer") -except PackageNotFoundError: - __version__ = "v?" - @dataclasses.dataclass(slots=True) class ParameterScanResult: @@ -53,11 +45,11 @@ class ParameterScanResult: @dataclasses.dataclass(slots=True) -class TensorStore: +class TensorStore(abc.ABC): """Store class for multidimensional data. This class is used to store data in a Zarr store. - Each store instance is associated with a single init time, + Each store instance has defined coordinates for the data, and is capable of handling parallel, region-based updates. """ @@ -82,16 +74,13 @@ def initialize_empty_store( model: str, repository: str, coords: NWPDimensionCoordinateMap, - overwrite_existing: bool = True, ) -> ResultE["TensorStore"]: """Initialize a store for a given init time. This method writes a blank dataarray to disk based on the input coordinates, which define the dimension labels and tick values of the output dataset object. - If the store already exists, it will be overwritten, unless the 'overwrite_existing' - flag is set to False. In this case, the existing store will be used only if its - coordinates are consistent with the expected coordinates. + .. note: If a store already exists at the expected path, it will be overwritten! The dataarray is 'blank' because it is written via:: @@ -117,7 +106,6 @@ def initialize_empty_store( This is also used as the name of the tensor. repository: The name of the repository providing the tensor data. coords: The coordinates of the store. - overwrite_existing: Whether to overwrite an existing store. Returns: An indicator of a successful store write containing the number of bytes written. @@ -134,103 +122,72 @@ def initialize_empty_store( ValueError( "Cannot initialize store with 'init_time' dimension coordinates not " "specified via a populated list. Check instantiation of " - "NWPDimensionCoordinateMap. " + "NWPDimensionCoordinateMap passed to this function. " f"Got: {coords.init_time} (not a list, or empty).", ), ) - store_range: str = f"{coords.init_time[0]:%Y%m%d%H}" - if len(coords.init_time) > 1: - store_range = f"{coords.init_time[0]:%Y%m%d%H}-{coords.init_time[-1]:%Y%m%d%H}" - store_path = pathlib.Path( - os.getenv("ZARRDIR", f"~/.local/cache/nwp/{repository}/{model}/data"), - ) / f"{store_range}.zarr" - - # * Define a set of chunks allowing for intermediate parallel writes - # NOTE: This is not the same as the final chunking of the dataset! - # Merely a chunksize that is small enough to allow for parallel writes - # to different regions of the init store. 
- intermediate_chunks: dict[str, int] = { - "init_time": 1, - "step": 1, - "variable": 1, - "latitude": coords.shapemap.get("latitude", 400) // 4, - "longitude": coords.shapemap.get("longitude", 400) // 8, - "values": coords.shapemap.get("values", 100), - } - # Create a dask array of zeros with the shape of the dataset - # * The values of this are ignored, only the shape and chunks are used - dummy_values = dask.array.zeros( # type: ignore - shape=list(coords.shapemap.values()), - chunks=tuple([intermediate_chunks[k] for k in coords.shapemap]), - ) - attrs: dict[str, str] = { - "produced_by": "".join(( - f"nwp-consumer {__version__} at ", - f"{dt.datetime.now(tz=dt.UTC).strftime('%Y-%m-%d %H:%M')}", - )), - "variables": json.dumps({ - p.value: { - "description": p.metadata().description, - "units": p.metadata().units, - } for p in coords.variable - }), - } - # Create a DataArray object with the given coordinates and dummy values - da: xr.DataArray = xr.DataArray( - name=model, - data=dummy_values, - coords=coords.to_pandas(), - attrs=attrs, - ) - encoding: dict[str, Any] ={ - "init_time": {"units": "nanoseconds since 1970-01-01"}, - "step": {"units": "hours"}, - } - match (os.path.exists(store_path), overwrite_existing): - case (True, False): - store_da: xr.DataArray = xr.open_dataarray(store_path, engine="zarr") - for dim in store_da.dims: - if dim not in da.dims: - return Failure( - ValueError( - "Cannot use existing store due to mismatched coordinates. " - f"Dimension '{dim}' in existing store not found in new store. " - "Use 'overwrite_existing=True' or move the existing store at " - f"'{store_path}' to a new location. ", - ), - ) - if not np.array_equal(store_da.coords[dim].values, da.coords[dim].values): - return Failure( - ValueError( - "Cannot use existing store due to mismatched coordinates. " - f"Dimension '{dim}' in existing store has different coordinate " - "values from specified. " - "Use 'overwrite_existing=True' or move the existing store at " - f"'{store_path}' to a new location.", - ), - ) - case (_, _): + zarrdir = os.getenv("ZARRDIR", f"~/.local/cache/nwp/{repository}/{model}/data") + store: zarr.storage.Store + try: + store = zarr.storage.DirectoryStore( + pathlib.Path( + "/".join((zarrdir, TensorStore.gen_store_filename(coords=coords))), + ).expanduser().as_posix(), + ) + if zarrdir.startswith("s3"): + import s3fs + log.debug("Attempting AWS connection using credential discovery") try: - # Write the dataset to a skeleton zarr file - # * 'compute=False' enables only saving metadata - # * 'mode="w"' overwrites any existing store - _ = da.to_zarr( - store=store_path, - compute=False, - mode="w", - consolidated=True, - encoding=encoding, + fs = s3fs.S3FileSystem( + anon=False, + client_kwargs={ + "region_name": os.getenv("AWS_REGION", "eu-west-1"), + }, ) - # Ensure the store is readable - store_da = xr.open_dataarray(store_path, engine="zarr") + store = s3fs.mapping.S3Map(zarrdir, fs, check=True, create=False) except Exception as e: - return Failure( - OSError( - f"Failed writing blank store to disk: {e}", - ), - ) + return Failure(OSError( + f"Unable to create file mapping for ZARRDIR '{zarrdir}'. " + "Ensure ZARRDIR environment variable is specified correctly, " + "and AWS credentials are discoverable by botocore. " + f"Error context: {e}", + )) + except Exception as e: + return Failure(OSError( + f"Unable to create Zarr Store at dir '{zarrdir}'. " + "Ensure ZARRDIR environment variable is specified correctly. 
" + f"Error context: {e}", + )) + + # Write the coordinates to a skeleton Zarr store + # * 'compute=False' enables only saving metadata + # * 'mode="w"' overwrites any existing store + log.info("Initialising zarr store at '%s'", store) + da: xr.DataArray = coords.as_zeroed_dataarray(name=model) + encoding = { + model: {"write_empty_chunks": False}, + "init_time": {"units": "nanoseconds since 1970-01-01"}, + "step": {"units": "hours"}, + } + try: + _ = da.to_zarr( + store=store, + compute=False, + mode="w", + consolidated=True, + encoding=encoding, + ) + # Ensure the store is readable + store_da = xr.open_dataarray(store, engine="zarr") + except Exception as e: + return Failure( + OSError( + f"Failed writing blank store to disk: {e}", + ), + ) + # Check the resultant array's coordinates can be converted back coordinate_map_result = NWPDimensionCoordinateMap.from_xarray(store_da) if isinstance(coordinate_map_result, Failure): @@ -244,13 +201,42 @@ def initialize_empty_store( return Success( cls( name=model, - path=store_path, + path=store.path, coordinate_map=coordinate_map_result.unwrap(), size_kb=0, encoding=encoding, ), ) + #def from_existing_store( + # model: str, + # repository: str, + # expected_coords: NWPDimensionCoordinateMap, + #) -> ResultE["TensorStore"]: + # """Create a TensorStore instance from an existing store.""" + # pass # TODO + + # for dim in store_da.dims: + # if dim not in da.dims: + # return Failure( + # ValueError( + # "Cannot use existing store due to mismatched coordinates. " + # f"Dimension '{dim}' in existing store not found in new store. " + # "Use 'overwrite_existing=True' or move the existing store at " + # f"'{store}' to a new location. ", + # ), + # ) + # if not np.array_equal(store_da.coords[dim].values, da.coords[dim].values): + # return Failure( + # ValueError( + # "Cannot use existing store due to mismatched coordinates. " + # f"Dimension '{dim}' in existing store has different coordinate " + # "values from specified. " + # "Use 'overwrite_existing=True' or move the existing store at " + # f"'{store}' to a new location.", + # ), + # ) + # --- Business logic methods --- # def write_to_region( self, @@ -376,83 +362,13 @@ def postprocess(self, options: PostProcessOptions) -> ResultE[pathlib.Path]: This creates a new store, as many of the postprocess options require modifications to the underlying file structure of the store. """ + # TODO: Implement postprocessing options if options.requires_postprocessing(): log.info("Applying postprocessing options to store %s", self.name) if options.validate: log.warning("Validation not yet implemented in efficient manner. 
Skipping option.") - store_da: xr.DataArray = xr.open_dataarray( - self.path, - engine="zarr", - ) - - if options.codec: - log.debug("Applying codec %s to store %s", options.codec.name, self.name) - self.encoding = self.encoding | {"compressor": options.codec.value} - - if options.rechunk: - store_da = store_da.chunk(chunks=self.coordinate_map.default_chunking()) - - if options.standardize_coordinates: - # Make the longitude values range from -180 to 180 - store_da = store_da.assign_coords({ - "longitude": ((store_da.coords["longitude"] + 180) % 360) - 180, - }) - # Find the index of the maximum value - idx: int = store_da.coords["longitude"].argmax().values - # Move the maximum value to the end, and do the same to the underlying data - store_da = store_da.roll( - longitude=len(store_da.coords["longitude"]) - idx - 1, - roll_coords=True, - ) - coordinates_result = NWPDimensionCoordinateMap.from_xarray(store_da) - match coordinates_result: - case Failure(e): - return Failure(e) - case Success(coords): - self.coordinate_map = coords - - if options.requires_rewrite(): - processed_path = self.path.parent / (self.path.name + ".processed") - try: - log.debug( - "Writing postprocessed store to %s", - processed_path, - ) - # Clear the encoding for any variables indexed as an 'object' type - # * e.g. Dimensions with string labels -> the variable dim - # * See https://github.com/sgkit-dev/sgkit/issues/991 - # * and https://github.com/pydata/xarray/issues/3476 - store_da.coords["variable"].encoding.clear() - _ = store_da.to_zarr( - store=processed_path, - mode="w", - encoding=self.encoding, - consolidated=True, - ) - self.path = processed_path - except Exception as e: - return Failure( - OSError( - f"Error encountered writing postprocessed store: {e}", - ), - ) - - if options.zip: - log.debug( - "Postprocessor: Zipping store to " - f"{self.path.with_suffix(".zarr.zip")}", - ) - try: - shutil.make_archive(self.path.name, "zip", self.path) - except Exception as e: - return Failure( - OSError( - f"Error encountered zipping store: {e}", - ), - ) - log.debug("Postprocessing complete for store %s", self.name) return Success(self.path) @@ -492,4 +408,17 @@ def missing_times(self) -> ResultE[list[dt.datetime]]: missing_times.append(pd.Timestamp(it).to_pydatetime().replace(tzinfo=dt.UTC)) return Success(missing_times) + @staticmethod + def gen_store_filename(coords: NWPDimensionCoordinateMap) -> str: + """Create a filename for the store. + + If the store only covers a single init_time, the filename is the init time. + Else, if it covers multiple init_times, the filename is the range of init times. + The extension is '.zarr'. + """ + store_range: str = coords.init_time[0].strftime("%Y%m%d%H") + if len(coords.init_time) > 1: + store_range = f"{coords.init_time[0]:%Y%m%d%H}-{coords.init_time[-1]:%Y%m%d%H}" + + return store_range + ".zarr" diff --git a/src/nwp_consumer/internal/entities/test_tensorstore.py b/src/nwp_consumer/internal/entities/test_tensorstore.py index 84db8106..03d8ecf5 100644 --- a/src/nwp_consumer/internal/entities/test_tensorstore.py +++ b/src/nwp_consumer/internal/entities/test_tensorstore.py @@ -65,13 +65,6 @@ class TestCase: options=PostProcessOptions(), should_error=False, ), - TestCase( - name="standardize_coordinates", - options=PostProcessOptions( - standardize_coordinates=True, - ), - should_error=False, - ), ] for t in tests: