feat(tensorstore): Move coordinate functions to coordinate
devsjc committed Nov 8, 2024
1 parent c705646 commit 52733e4
Showing 4 changed files with 164 additions and 188 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -29,7 +29,7 @@ dependencies = [
"returns == 0.23.0",
"s3fs == 2024.9.0",
"xarray == 2024.9.0",
-"zarr == 2.18.2"
+"zarr == 2.18.3"
]

[dependency-groups]
60 changes: 57 additions & 3 deletions src/nwp_consumer/internal/entities/coordinates.py
@@ -37,7 +37,10 @@

import dataclasses
import datetime as dt
+import json
+from importlib.metadata import PackageNotFoundError, version

+import dask.array
import numpy as np
import pandas as pd
import pytz
@@ -46,6 +49,11 @@

from .parameters import Parameter

+try:
+    __version__ = version("nwp-consumer")
+except PackageNotFoundError:
+    __version__ = "v?"


@dataclasses.dataclass(slots=True)
class NWPDimensionCoordinateMap:
@@ -367,11 +375,57 @@ def default_chunking(self) -> dict[str, int]:
        that wants to cover the entire dimension should have a size equal to the
        dimension length.
-        It defaults to a single chunk per init time and step, and a single chunk
-        for each entire other dimension.
+        It defaults to a single chunk per init time and step, and 8 chunks
+        for each entire other dimension. These are purposefully small, to ensure
+        that when performing parallel writes, chunk boundaries are not crossed.
        """
        out_dict: dict[str, int] = {
            "init_time": 1,
            "step": 1,
-        } | {dim: len(getattr(self, dim)) for dim in self.dims if dim not in ["init_time", "step"]}
+        } | {
+            dim: len(getattr(self, dim)) // 8 if len(getattr(self, dim)) > 8 else 1
+            for dim in self.dims
+            if dim not in ["init_time", "step"]
+        }

        return out_dict


+    def as_zeroed_dataarray(self, name: str) -> xr.DataArray:
+        """Express the coordinates as an xarray DataArray.
+
+        Data is populated with zeros and a default chunking scheme is applied.
+
+        Args:
+            name: The name of the DataArray.
+
+        See Also:
+            - https://docs.xarray.dev/en/stable/user-guide/io.html#distributed-writes
+        """
+        # Create a dask array of zeros with the shape of the dataset
+        # * The values of this are ignored, only the shape and chunks are used
+        dummy_values = dask.array.zeros(  # type: ignore
+            shape=list(self.shapemap.values()),
+            chunks=tuple([self.default_chunking()[k] for k in self.shapemap]),
+        )
+        attrs: dict[str, str] = {
+            "produced_by": "".join((
+                f"nwp-consumer {__version__} at ",
+                f"{dt.datetime.now(tz=dt.UTC).strftime('%Y-%m-%d %H:%M')}",
+            )),
+            "variables": json.dumps({
+                p.value: {
+                    "description": p.metadata().description,
+                    "units": p.metadata().units,
+                } for p in self.variable
+            }),
+        }
+        # Create a DataArray object with the given coordinates and dummy values
+        da: xr.DataArray = xr.DataArray(
+            name=name,
+            data=dummy_values,
+            coords=self.to_pandas(),
+            attrs=attrs,
+        )
+        return da
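
As a note on the revised default_chunking above: each non-index dimension is now split into roughly eight chunks, falling back to a single chunk when the dimension holds eight or fewer values, while init_time and step keep one chunk each. A minimal standalone sketch of that calculation, using made-up dimension lengths rather than a real NWPDimensionCoordinateMap:

# Hypothetical dimension lengths, for illustration only.
dim_lengths = {"init_time": 1, "step": 48, "variable": 3, "latitude": 721, "longitude": 1440}

chunks = {
    "init_time": 1,
    "step": 1,
} | {
    dim: length // 8 if length > 8 else 1
    for dim, length in dim_lengths.items()
    if dim not in ["init_time", "step"]
}

print(chunks)
# {'init_time': 1, 'step': 1, 'variable': 1, 'latitude': 90, 'longitude': 180}

Keeping init_time and step at one chunk each means a parallel write that targets a single (init_time, step) slab stays within its own chunk boundaries, as the updated docstring notes.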

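The new as_zeroed_dataarray is built for the distributed-write pattern linked in its docstring: write a zero-filled, chunk-aligned skeleton to a Zarr store once, then fill regions of that store in separate (potentially parallel) writes. A rough, self-contained sketch of that pattern follows; the dimension names, sizes, variable name and store path are invented for illustration and are not taken from this repository.

import dask.array
import numpy as np
import pandas as pd
import xarray as xr

# Invented coordinates standing in for a populated NWPDimensionCoordinateMap.
coords = {
    "init_time": pd.to_datetime(["2024-11-08T00:00"]),
    "step": pd.to_timedelta(np.arange(4), unit="h"),
    "latitude": np.linspace(90, -90, 19),
    "longitude": np.linspace(-180, 179, 36),
}
shape = tuple(len(v) for v in coords.values())
chunks = (1, 1, 19, 36)  # one chunk per (init_time, step), whole lat/lon plane per chunk

# Zero-filled skeleton: only its shape, dtype and chunking matter.
skeleton = xr.DataArray(
    data=dask.array.zeros(shape, chunks=chunks),
    coords=coords,
    dims=list(coords),
    name="temperature",
).to_dataset()

# Write dimensions, coordinates and attributes; no chunk values are computed.
skeleton.to_zarr("skeleton.zarr", mode="w", compute=False)

# Later, possibly from another worker, fill one (init_time, step) region with
# real values. The coordinates are already in the store, so only the data
# variable is written here.
region = {"init_time": slice(0, 1), "step": slice(0, 1)}
part = xr.Dataset(
    {
        "temperature": (
            ("init_time", "step", "latitude", "longitude"),
            np.full((1, 1, 19, 36), 281.5),
        ),
    },
)
part.to_zarr("skeleton.zarr", region=region)

With the chunking above, every (init_time, step) pair owns its own Zarr chunks, so concurrent region writes of this shape never touch the same chunk.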
