Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: generalize open_dataset for non-box meshes #13

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 63 additions & 22 deletions pymech/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import xarray as xr

from .neksuite import readnek
from .log import logger


def open_dataset(path, **kwargs):
Expand Down Expand Up @@ -50,6 +51,7 @@ def open_mfdataset(
parallel=False,
join="outer",
attrs_file=None,
box_mesh=True,
**kwargs,
):
"""Helper function for opening multiple files as an xarray_ dataset.
Expand Down Expand Up @@ -79,7 +81,7 @@ def open_mfdataset(
if isinstance(concat_dim, (str, xr.DataArray)) or concat_dim is None:
concat_dim = [concat_dim]

open_kwargs = dict()
open_kwargs = dict(box_mesh=box_mesh)

if parallel:
import dask
Expand Down Expand Up @@ -140,44 +142,55 @@ def open_mfdataset(
return combined


def _open_nek_dataset(path):
def _open_nek_dataset(path, box_mesh=True):
"""Interface for converting Nek field files into xarray_ datasets."""
field = readnek(path)
if isinstance(field, int):
raise IOError("Failed to load {}".format(path))

elements = field.elem
elem_stores = [_NekDataStore(elem) for elem in elements]
elem_dsets = [
xr.Dataset.load_store(store).set_coords(store.axes)
for store in elem_stores
]
if box_mesh:
DataStore = _NekDataStoreBox
coords = ('z', 'y', 'x')
else:
DataStore = _NekDataStoreGeneric
coords = ('zmesh', 'ymesh', 'xmesh')

elem_stores = [DataStore(elem) for elem in elements]
try:
elem_dsets = [
xr.Dataset.load_store(store).set_coords(coords)
for store in elem_stores
]
except ValueError:
logger.critical(
"Cannot load as a box mesh. Try passing the keyword argument: "
"`box_mesh=False` while opening the dataset."
)

1/0
# See: https://github.com/MITgcm/xmitgcm/pull/200
if xr.__version__ < '0.15.2':
ds = xr.combine_by_coords(elem_dsets)
if box_mesh:
if xr.__version__ < '0.15.2':
ds = xr.combine_by_coords(elem_dsets)
else:
ds = xr.combine_by_coords(elem_dsets, combine_attrs="drop")
else:
ds = xr.combine_by_coords(elem_dsets, combine_attrs="drop")
# This would work if we pre-sort elem_dsets and pass as a nested list /
# iterator
ds = xr.combine_nested(elem_dsets, concat_dim=('z', 'y', 'x'))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fails here. Find a way to "reshape" the list of element datasets.


ds.coords.update({"time": field.time})

return ds


class _NekDataStore(xr.backends.common.AbstractDataStore):
class _NekDataStoreGeneric(xr.backends.common.AbstractDataStore):
"""Xarray store for a Nek field element."""
def __init__(self, elem):
self.elem = elem
self.axes = ("z", "y", "x")

def meshgrid_to_dim(self, mesh):
"""Reverse of np.meshgrid. This method extracts one-dimensional
coordinates from a cubical array format for every direction

"""
dim = np.unique(np.round(mesh, 8))
return dim

def get_dimensions(self):
return self.axes

Expand All @@ -196,9 +209,6 @@ def get_variables(self):
elem = self.elem

data_vars = {
ax[2]: self.meshgrid_to_dim(elem.pos[0]), # x
ax[1]: self.meshgrid_to_dim(elem.pos[1]), # y
ax[0]: self.meshgrid_to_dim(elem.pos[2]), # z
"xmesh": (ax, elem.pos[0]),
"ymesh": (ax, elem.pos[1]),
"zmesh": (ax, elem.pos[2]),
Expand All @@ -221,3 +231,34 @@ def get_variables(self):
)

return data_vars


class _NekDataStoreBox(_NekDataStoreGeneric):
    """Element store for box meshes.

    Extends the variables produced by the parent store with one 1D
    coordinate array per axis, derived from the element's position
    arrays.

    """

    def meshgrid_to_dim(self, mesh):
        """Collapse a mesh array into its unique values.

        The values are rounded to 8 decimals before duplicates are
        dropped, so that nearly equal entries are merged into one.

        """
        rounded = np.round(mesh, 8)
        return np.unique(rounded)

    def get_variables(self):
        """Return the parent's variables plus a 1D entry for each axis."""
        axis_names = self.axes
        element = self.elem

        variables = super().get_variables()

        # Pairs of (index into elem.pos, index into self.axes): pos[0] is
        # stored under axes[2], pos[1] under axes[1] and pos[2] under
        # axes[0] (x, y and z respectively, per the original inline notes).
        index_pairs = ((0, 2), (1, 1), (2, 0))
        coords_1d = {
            axis_names[axis_index]: self.meshgrid_to_dim(
                element.pos[pos_index]
            )
            for pos_index, axis_index in index_pairs
        }
        variables.update(coords_1d)
        return variables