Skip to content

Commit

Permalink
CFTime support for polyval (#6624)
Browse files Browse the repository at this point in the history
Co-authored-by: dcherian <deepak@cherian.net>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
  • Loading branch information
4 people authored May 31, 2022
1 parent 95a47af commit 4c92d52
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 14 deletions.
31 changes: 23 additions & 8 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from .options import OPTIONS, _get_keep_attrs
from .pycompat import is_duck_dask_array
from .types import T_DataArray
from .utils import is_dict_like
from .utils import is_dict_like, is_scalar
from .variable import Variable

if TYPE_CHECKING:
Expand Down Expand Up @@ -1887,6 +1887,15 @@ def polyval(coord: Dataset, coeffs: Dataset, degree_dim: Hashable) -> Dataset:
...


@overload
def polyval(
coord: Dataset | DataArray,
coeffs: Dataset | DataArray,
degree_dim: Hashable = "degree",
) -> Dataset | DataArray:
...


def polyval(
coord: Dataset | DataArray,
coeffs: Dataset | DataArray,
Expand Down Expand Up @@ -1953,15 +1962,21 @@ def _ensure_numeric(data: Dataset | DataArray) -> Dataset | DataArray:
"""
from .dataset import Dataset

def _cfoffset(x: DataArray) -> Any:
scalar = x.compute().data[0]
if not is_scalar(scalar):
# we do not get a scalar back on dask == 2021.04.1
scalar = scalar.item()
return type(scalar)(1970, 1, 1)

def to_floatable(x: DataArray) -> DataArray:
if x.dtype.kind == "M":
# datetimes
if x.dtype.kind in "MO":
# datetimes (CFIndexes are object type)
offset = (
np.datetime64("1970-01-01") if x.dtype.kind == "M" else _cfoffset(x)
)
return x.copy(
data=datetime_to_numeric(
x.data,
offset=np.datetime64("1970-01-01"),
datetime_unit="ns",
),
data=datetime_to_numeric(x.data, offset=offset, datetime_unit="ns"),
)
elif x.dtype.kind == "m":
# timedeltas
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
# This map_blocks call is for backwards compatibility.
# dask == 2021.04.1 does not support subtracting object arrays
# which is required for cftime
if is_duck_dask_array(array) and np.issubdtype(array.dtype, np.object):
if is_duck_dask_array(array) and np.issubdtype(array.dtype, object):
array = array.map_blocks(lambda a, b: a - b, offset, meta=array._meta)
else:
array = array - offset
Expand Down
97 changes: 92 additions & 5 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@
from xarray.core.pycompat import dask_version
from xarray.core.types import T_Xarray

from . import has_dask, raise_if_dask_computes, requires_dask
from . import (
has_cftime,
has_dask,
raise_if_dask_computes,
requires_cftime,
requires_dask,
)


def assert_identical(a, b):
Expand Down Expand Up @@ -1936,7 +1942,9 @@ def test_where_attrs() -> None:
assert actual.attrs == {}


@pytest.mark.parametrize("use_dask", [False, True])
@pytest.mark.parametrize(
"use_dask", [pytest.param(False, id="nodask"), pytest.param(True, id="dask")]
)
@pytest.mark.parametrize(
["x", "coeffs", "expected"],
[
Expand Down Expand Up @@ -2031,20 +2039,99 @@ def test_polyval(
pytest.skip("requires dask")
coeffs = coeffs.chunk({"degree": 2})
x = x.chunk({"x": 2})

with raise_if_dask_computes():
actual = xr.polyval(coord=x, coeffs=coeffs) # type: ignore
actual = xr.polyval(coord=x, coeffs=coeffs)

xr.testing.assert_allclose(actual, expected)


@requires_cftime
@pytest.mark.parametrize(
"use_dask", [pytest.param(False, id="nodask"), pytest.param(True, id="dask")]
)
@pytest.mark.parametrize("date", ["1970-01-01", "0753-04-21"])
def test_polyval_cftime(use_dask: bool, date: str) -> None:
import cftime

x = xr.DataArray(
xr.date_range(date, freq="1S", periods=3, use_cftime=True),
dims="x",
)
coeffs = xr.DataArray([0, 1], dims="degree", coords={"degree": [0, 1]})

if use_dask:
if not has_dask:
pytest.skip("requires dask")
coeffs = coeffs.chunk({"degree": 2})
x = x.chunk({"x": 2})

with raise_if_dask_computes(max_computes=1):
actual = xr.polyval(coord=x, coeffs=coeffs)

t0 = xr.date_range(date, periods=1)[0]
offset = (t0 - cftime.DatetimeGregorian(1970, 1, 1)).total_seconds() * 1e9
expected = (
xr.DataArray(
[0, 1e9, 2e9],
dims="x",
coords={"x": xr.date_range(date, freq="1S", periods=3, use_cftime=True)},
)
+ offset
)
xr.testing.assert_allclose(actual, expected)


def test_polyval_degree_dim_checks():
x = (xr.DataArray([1, 2, 3], dims="x"),)
def test_polyval_degree_dim_checks() -> None:
x = xr.DataArray([1, 2, 3], dims="x")
coeffs = xr.DataArray([2, 3, 4], dims="degree", coords={"degree": [0, 1, 2]})
with pytest.raises(ValueError):
xr.polyval(x, coeffs.drop_vars("degree"))
with pytest.raises(ValueError):
xr.polyval(x, coeffs.assign_coords(degree=coeffs.degree.astype(float)))


@pytest.mark.parametrize(
"use_dask", [pytest.param(False, id="nodask"), pytest.param(True, id="dask")]
)
@pytest.mark.parametrize(
"x",
[
pytest.param(xr.DataArray([0, 1, 2], dims="x"), id="simple"),
pytest.param(
xr.DataArray(pd.date_range("1970-01-01", freq="ns", periods=3), dims="x"),
id="datetime",
),
pytest.param(
xr.DataArray(np.array([0, 1, 2], dtype="timedelta64[ns]"), dims="x"),
id="timedelta",
),
],
)
@pytest.mark.parametrize(
"y",
[
pytest.param(xr.DataArray([1, 6, 17], dims="x"), id="1D"),
pytest.param(
xr.DataArray([[1, 6, 17], [34, 57, 86]], dims=("y", "x")), id="2D"
),
],
)
def test_polyfit_polyval_integration(
use_dask: bool, x: xr.DataArray, y: xr.DataArray
) -> None:
y.coords["x"] = x
if use_dask:
if not has_dask:
pytest.skip("requires dask")
y = y.chunk({"x": 2})

fit = y.polyfit(dim="x", deg=2)
evaluated = xr.polyval(y.x, fit.polyfit_coefficients)
expected = y.transpose(*evaluated.dims)
xr.testing.assert_allclose(evaluated.variable, expected.variable)


@pytest.mark.parametrize("use_dask", [False, True])
@pytest.mark.parametrize(
"a, b, ae, be, dim, axis",
Expand Down

0 comments on commit 4c92d52

Please sign in to comment.