Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes preserving coordinates in regrid2 #716

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 76 additions & 22 deletions xcdat/regridder/regrid2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Any, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import xarray as xr

import xcdat as xc
from xcdat.axis import get_dim_keys
from xcdat.regridder.base import BaseRegridder, _preserve_bounds

Expand Down Expand Up @@ -105,8 +106,6 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
ds,
data_var,
output_data,
dst_lat_bnds,
dst_lon_bnds,
self._input_grid,
self._output_grid,
)
Expand Down Expand Up @@ -228,38 +227,90 @@ def _build_dataset(
ds: xr.Dataset,
data_var: str,
output_data: np.ndarray,
dst_lat_bnds,
dst_lon_bnds,
input_grid: xr.Dataset,
output_grid: xr.Dataset,
) -> xr.Dataset:
input_data_var = ds[data_var]
"""Build a new xarray Dataset with the given output data and coordinates.

Parameters
----------
ds : xr.Dataset
The input dataset containing the data variable to be regridded.
data_var : str
The name of the data variable in the input dataset to be regridded.
output_data : np.ndarray
The regridded data to be included in the output dataset.
input_grid : xr.Dataset
The input grid dataset containing the original grid information.
output_grid : xr.Dataset
The output grid dataset containing the new grid information.

output_coords: dict[str, xr.DataArray] = {}
output_data_vars: dict[str, xr.DataArray] = {}
Returns
-------
xr.Dataset
A new dataset containing the regridded data variable with updated
coordinates and attributes.
"""
dv_input = ds[data_var]

dims = list(input_data_var.dims)
output_coords = _get_output_coords(dv_input, output_grid)

output_da = xr.DataArray(
output_data,
dims=dims,
dims=dv_input.dims,
coords=output_coords,
attrs=ds[data_var].attrs.copy(),
name=data_var,
)

output_data_vars[data_var] = output_da

output_ds = xr.Dataset(
output_data_vars,
attrs=input_grid.attrs.copy(),
)

output_ds = output_da.to_dataset()
output_ds.attrs = input_grid.attrs.copy()
output_ds = _preserve_bounds(ds, output_grid, output_ds, ["X", "Y"])

return output_ds


def _get_output_coords(
dv_input: xr.DataArray, output_grid: xr.Dataset
) -> Dict[str, xr.DataArray]:
"""
Generate the output coordinates for regridding based on the input data
variable and output grid.

Parameters
----------
dv_input : xr.DataArray
The input data variable containing the original coordinates.
output_grid : xr.Dataset
The dataset containing the target grid coordinates.

Returns
-------
Dict[str, xr.DataArray]
A dictionary where keys are coordinate names and values are the
corresponding coordinates from the output grid or input data variable,
aligned with the dimensions of the input data variable.
"""
output_coords: Dict[str, xr.DataArray] = {}

# First get the X and Y axes from the output grid.
for key in ["X", "Y"]:
input_coord = xc.get_dim_coords(dv_input, key) # type: ignore
output_coord = xc.get_dim_coords(output_grid, key) # type: ignore

output_coords[str(input_coord.name)] = output_coord # type: ignore

# Get the remaining axes the input data variable (e.g., "time").
for dim in dv_input.dims:
if dim not in output_coords:
output_coords[str(dim)] = dv_input[dim]

# Sort the coords to align with the input data variable dims.
output_coords = {str(dim): output_coords[str(dim)] for dim in dv_input.dims}

return output_coords

Comment on lines +294 to +311
Copy link
Collaborator

@tomvothecoder tomvothecoder Dec 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My logic gets the X and Y axes from the output_grid via xc.get_dim_coords(). This function can map to axes via "axis" attr, "standard_name" attr, and accepted dim names (e.g., lat, lon).

For remaining axes, it just gets them directly from the input data variable (dv_input) like in your logic.


def _map_latitude(
src: np.ndarray, dst: np.ndarray
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
Expand Down Expand Up @@ -553,12 +604,15 @@ def _get_dimension(input_data_var, cf_axis_name):


def _get_bounds_ensure_dtype(ds, axis):
bounds = None

try:
name = ds.cf.bounds[axis][0]
except (KeyError, IndexError) as e:
raise RuntimeError(f"Could not determine {axis!r} bounds") from e
else:
bounds = ds[name]
bounds = ds.bounds.get_bounds(axis)
except KeyError:
pass

if bounds is None:
raise RuntimeError(f"Could not determine {axis!r} bounds")

if bounds.dtype != np.float32:
bounds = bounds.astype(np.float32)
Expand Down
Loading