From 4c92d52fffc17ecee067215dce27f554904c9bec Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Tue, 31 May 2022 19:16:03 +0200 Subject: [PATCH] CFTime support for polyval (#6624) Co-authored-by: dcherian Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian --- xarray/core/computation.py | 31 +++++++--- xarray/core/duck_array_ops.py | 2 +- xarray/tests/test_computation.py | 97 ++++++++++++++++++++++++++++++-- 3 files changed, 116 insertions(+), 14 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index ad1f937a4d0..de7934dafdd 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -31,7 +31,7 @@ from .options import OPTIONS, _get_keep_attrs from .pycompat import is_duck_dask_array from .types import T_DataArray -from .utils import is_dict_like +from .utils import is_dict_like, is_scalar from .variable import Variable if TYPE_CHECKING: @@ -1887,6 +1887,15 @@ def polyval(coord: Dataset, coeffs: Dataset, degree_dim: Hashable) -> Dataset: ... +@overload +def polyval( + coord: Dataset | DataArray, + coeffs: Dataset | DataArray, + degree_dim: Hashable = "degree", +) -> Dataset | DataArray: + ... + + def polyval( coord: Dataset | DataArray, coeffs: Dataset | DataArray, @@ -1953,15 +1962,21 @@ def _ensure_numeric(data: Dataset | DataArray) -> Dataset | DataArray: """ from .dataset import Dataset + def _cfoffset(x: DataArray) -> Any: + scalar = x.compute().data[0] + if not is_scalar(scalar): + # we do not get a scalar back on dask == 2021.04.1 + scalar = scalar.item() + return type(scalar)(1970, 1, 1) + def to_floatable(x: DataArray) -> DataArray: - if x.dtype.kind == "M": - # datetimes + if x.dtype.kind in "MO": + # datetimes (CFIndexes are object type) + offset = ( + np.datetime64("1970-01-01") if x.dtype.kind == "M" else _cfoffset(x) + ) return x.copy( - data=datetime_to_numeric( - x.data, - offset=np.datetime64("1970-01-01"), - datetime_unit="ns", - ), + data=datetime_to_numeric(x.data, offset=offset, datetime_unit="ns"), ) elif x.dtype.kind == "m": # timedeltas diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index e5a659a9ec9..5f09abdbcfd 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -435,7 +435,7 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float): # This map_blocks call is for backwards compatibility. # dask == 2021.04.1 does not support subtracting object arrays # which is required for cftime - if is_duck_dask_array(array) and np.issubdtype(array.dtype, np.object): + if is_duck_dask_array(array) and np.issubdtype(array.dtype, object): array = array.map_blocks(lambda a, b: a - b, offset, meta=array._meta) else: array = array - offset diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index ec8b5a5bc7c..1eaa772206e 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -26,7 +26,13 @@ from xarray.core.pycompat import dask_version from xarray.core.types import T_Xarray -from . import has_dask, raise_if_dask_computes, requires_dask +from . import ( + has_cftime, + has_dask, + raise_if_dask_computes, + requires_cftime, + requires_dask, +) def assert_identical(a, b): @@ -1936,7 +1942,9 @@ def test_where_attrs() -> None: assert actual.attrs == {} -@pytest.mark.parametrize("use_dask", [False, True]) +@pytest.mark.parametrize( + "use_dask", [pytest.param(False, id="nodask"), pytest.param(True, id="dask")] +) @pytest.mark.parametrize( ["x", "coeffs", "expected"], [ @@ -2031,13 +2039,51 @@ def test_polyval( pytest.skip("requires dask") coeffs = coeffs.chunk({"degree": 2}) x = x.chunk({"x": 2}) + with raise_if_dask_computes(): - actual = xr.polyval(coord=x, coeffs=coeffs) # type: ignore + actual = xr.polyval(coord=x, coeffs=coeffs) + + xr.testing.assert_allclose(actual, expected) + + +@requires_cftime +@pytest.mark.parametrize( + "use_dask", [pytest.param(False, id="nodask"), pytest.param(True, id="dask")] +) +@pytest.mark.parametrize("date", ["1970-01-01", "0753-04-21"]) +def test_polyval_cftime(use_dask: bool, date: str) -> None: + import cftime + + x = xr.DataArray( + xr.date_range(date, freq="1S", periods=3, use_cftime=True), + dims="x", + ) + coeffs = xr.DataArray([0, 1], dims="degree", coords={"degree": [0, 1]}) + + if use_dask: + if not has_dask: + pytest.skip("requires dask") + coeffs = coeffs.chunk({"degree": 2}) + x = x.chunk({"x": 2}) + + with raise_if_dask_computes(max_computes=1): + actual = xr.polyval(coord=x, coeffs=coeffs) + + t0 = xr.date_range(date, periods=1)[0] + offset = (t0 - cftime.DatetimeGregorian(1970, 1, 1)).total_seconds() * 1e9 + expected = ( + xr.DataArray( + [0, 1e9, 2e9], + dims="x", + coords={"x": xr.date_range(date, freq="1S", periods=3, use_cftime=True)}, + ) + + offset + ) xr.testing.assert_allclose(actual, expected) -def test_polyval_degree_dim_checks(): - x = (xr.DataArray([1, 2, 3], dims="x"),) +def test_polyval_degree_dim_checks() -> None: + x = xr.DataArray([1, 2, 3], dims="x") coeffs = xr.DataArray([2, 3, 4], dims="degree", coords={"degree": [0, 1, 2]}) with pytest.raises(ValueError): xr.polyval(x, coeffs.drop_vars("degree")) @@ -2045,6 +2091,47 @@ def test_polyval_degree_dim_checks(): xr.polyval(x, coeffs.assign_coords(degree=coeffs.degree.astype(float))) +@pytest.mark.parametrize( + "use_dask", [pytest.param(False, id="nodask"), pytest.param(True, id="dask")] +) +@pytest.mark.parametrize( + "x", + [ + pytest.param(xr.DataArray([0, 1, 2], dims="x"), id="simple"), + pytest.param( + xr.DataArray(pd.date_range("1970-01-01", freq="ns", periods=3), dims="x"), + id="datetime", + ), + pytest.param( + xr.DataArray(np.array([0, 1, 2], dtype="timedelta64[ns]"), dims="x"), + id="timedelta", + ), + ], +) +@pytest.mark.parametrize( + "y", + [ + pytest.param(xr.DataArray([1, 6, 17], dims="x"), id="1D"), + pytest.param( + xr.DataArray([[1, 6, 17], [34, 57, 86]], dims=("y", "x")), id="2D" + ), + ], +) +def test_polyfit_polyval_integration( + use_dask: bool, x: xr.DataArray, y: xr.DataArray +) -> None: + y.coords["x"] = x + if use_dask: + if not has_dask: + pytest.skip("requires dask") + y = y.chunk({"x": 2}) + + fit = y.polyfit(dim="x", deg=2) + evaluated = xr.polyval(y.x, fit.polyfit_coefficients) + expected = y.transpose(*evaluated.dims) + xr.testing.assert_allclose(evaluated.variable, expected.variable) + + @pytest.mark.parametrize("use_dask", [False, True]) @pytest.mark.parametrize( "a, b, ae, be, dim, axis",